Skip to content

Commit f7b4430

Browse files
committed
uncomment 64 block size support in csrc
1 parent: d607127 · commit: f7b4430

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

csrc/kernels.hip

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3044,23 +3044,23 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
30443044
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
30453045
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
30463046
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
3047-
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
3047+
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
30483048

30493049
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
30503050
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
30513051
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
30523052
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
30533053
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
30543054
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
3055-
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
3055+
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
30563056

30573057
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
30583058
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
30593059
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
30603060
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
30613061
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
30623062
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
3063-
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
3063+
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
30643064

30653065
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
30663066
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
@@ -3069,23 +3069,23 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
30693069
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
30703070
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
30713071
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
3072-
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
3072+
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
30733073

30743074
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
30753075
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
30763076
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
30773077
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
30783078
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
30793079
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
3080-
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
3080+
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
30813081

30823082
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
30833083
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
30843084
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
30853085
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
30863086
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
30873087
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
3088-
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
3088+
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
30893089

30903090
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit)
30913091
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit)
@@ -3094,23 +3094,23 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
30943094
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
30953095
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
30963096
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
3097-
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
3097+
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
30983098

30993099
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4)
31003100
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4)
31013101
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
31023102
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
31033103
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
31043104
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
3105-
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
3105+
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
31063106

31073107
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4)
31083108
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4)
31093109
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
31103110
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
31113111
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
31123112
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
3113-
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
3113+
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
31143114

31153115
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
31163116
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);

csrc/ops.hip

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,8 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
5757
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n);
5858
else if(blocksize == 128)
5959
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n);
60-
//else if(blocksize == 64)
61-
// hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);
60+
else if(blocksize == 64)
61+
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);
6262

6363

6464
CUDA_CHECK_RETURN(hipPeekAtLastError());

0 commit comments

Comments
 (0)