@@ -973,7 +973,7 @@ template<typename T, int OPTIMIZER>
973973__launch_bounds__ (TH, 1 )
974974__global__ void kOptimizer32bit2State(T* g, T* p,
975975 float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
976- const float beta1, const float beta2, const float eps, const float weight_decay,
976+ const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
977977 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
978978{
979979
@@ -1742,7 +1742,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
17421742__launch_bounds__ (256 , 3 )
17431743__global__ void
17441744kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char * state1, unsigned char * state2,
1745- const float beta1, const float beta2,
1745+ const float beta1, const float beta2, const float beta3, const float alpha,
17461746 const float eps, const int step, const float lr,
17471747 float * __restrict__ const quantiles1, float * __restrict__ const quantiles2,
17481748 float * absmax1, float * absmax2,
@@ -2268,7 +2268,7 @@ template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __rest
22682268
22692269#define MM_DEQUANT_CONST 6 .200012e-05f // 1.0f/(127.0f*127.0f)
22702270
2271- template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16 (int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float * newRowStats, float * newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols , const int n)
2271+ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16 (int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half *__restrict__ const bias, const int numRows, const int numCols, const int n)
22722272{
22732273
22742274 // Strategy: To dequantize we need to load col/row statistics. This can be very expensive
@@ -3851,7 +3851,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(
38513851template __global__ void kTransformRowToFormat <256 , 8 , 32 , 32 *8 , 0 , COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
38523852template __global__ void kTransformRowToFormat <256 , 8 , 32 , 32 *8 , 1 , COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
38533853
3854- template __global__ void kdequant_mm_int32_fp16<4 , 128 , 512 >(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float * newRowStats, float * newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols , const int n);
3854+ template __global__ void kdequant_mm_int32_fp16<4 , 128 , 512 >(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
38553855
38563856template __global__ void kDoubleRowColQuant <64 , 4 , 16 , 64 *4 , 0 >(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
38573857template __global__ void kDoubleRowColQuant <64 , 4 , 16 , 64 *4 , 1 >(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
@@ -3903,11 +3903,11 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, half)
39033903MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16)
39043904
39053905template __global__ void kOptimizer32bit2State<float, ADAM>(float * g, float * p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3906- const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3906+ const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
39073907template __global__ void kOptimizer32bit2State <half, ADAM>(half* g, half* p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3908- const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3908+ const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
39093909template __global__ void kOptimizer32bit2State <hip_bfloat16, ADAM>(hip_bfloat16* g, hip_bfloat16* p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3910- const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3910+ const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
39113911
39123912
39133913#define MAKE_PreconditionStatic8bit1State (oname, gtype ) \
@@ -4068,7 +4068,7 @@ template __global__ void kDequantizeBlockwise<hip_bfloat16, 512, 64, 8, NF4>(flo
40684068
40694069#define MAKE_OptimizerStatic8bit2StateBlockwise (oname, gtype, block_size, num_per_thread ) \
40704070template __global__ void kOptimizerStatic8bit2StateBlockwise <gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char * state1, unsigned char * state2, \
4071- const float beta1, const float beta2, \
4071+ const float beta1, const float beta2, const float beta3, const float alpha, \
40724072 const float eps, const int step, const float lr, \
40734073 float * __restrict__ const quantiles1, float * __restrict__ const quantiles2, \
40744074 float * absmax1, float * absmax2, \
0 commit comments