Skip to content

Commit 3fabd1a

Browse files
ROCm: Fix compilation.
1 parent 94d6027 commit 3fabd1a

File tree

5 files changed

+167
-231
lines changed

5 files changed

+167
-231
lines changed

csrc/kernels.hip

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ template<typename T, int OPTIMIZER>
973973
__launch_bounds__(TH, 1)
974974
__global__ void kOptimizer32bit2State(T* g, T* p,
975975
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
976-
const float beta1, const float beta2, const float eps, const float weight_decay,
976+
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
977977
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
978978
{
979979

@@ -1742,7 +1742,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
17421742
__launch_bounds__(256, 3)
17431743
__global__ void
17441744
kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
1745-
const float beta1, const float beta2,
1745+
const float beta1, const float beta2, const float beta3, const float alpha,
17461746
const float eps, const int step, const float lr,
17471747
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
17481748
float* absmax1, float* absmax2,
@@ -2268,7 +2268,7 @@ template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __rest
22682268

22692269
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
22702270

2271-
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n)
2271+
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half *__restrict__ const bias, const int numRows, const int numCols, const int n)
22722272
{
22732273

22742274
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
@@ -3851,7 +3851,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(
38513851
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
38523852
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
38533853

3854-
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
3854+
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
38553855

38563856
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
38573857
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
@@ -3903,11 +3903,11 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, half)
39033903
MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16)
39043904

39053905
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3906-
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3906+
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
39073907
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3908-
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3908+
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
39093909
template __global__ void kOptimizer32bit2State<hip_bfloat16, ADAM>(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3910-
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3910+
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
39113911

39123912

39133913
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
@@ -4068,7 +4068,7 @@ template __global__ void kDequantizeBlockwise<hip_bfloat16, 512, 64, 8, NF4>(flo
40684068

40694069
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
40704070
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
4071-
const float beta1, const float beta2, \
4071+
const float beta1, const float beta2, const float beta3, const float alpha, \
40724072
const float eps, const int step, const float lr, \
40734073
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
40744074
float* absmax1, float* absmax2, \

csrc/kernels_hip.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
3030
template<typename T, int OPTIMIZER>
3131
__global__ void kOptimizer32bit2State(T* g, T* p,
3232
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
33-
const float beta1, const float beta2, const float eps, const float weight_decay,
33+
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
3434
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3535

3636
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
@@ -92,7 +92,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
9292

9393
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
9494
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
95-
const float beta1, const float beta2, const float eps, const int step, const float lr,
95+
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr,
9696
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
9797
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
9898

@@ -116,7 +116,7 @@ template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_s
116116

117117
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
118118
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
119-
half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
119+
half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
120120

121121
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
122122
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);

0 commit comments

Comments
 (0)