Skip to content

Commit 6e2e4d2

Browse files
committed
Only enable block size 64 support on architectures with a warp size of 32
1 parent f7b4430 commit 6e2e4d2

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

csrc/kernels.hip

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3044,23 +3044,29 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
30443044
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
30453045
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
30463046
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
3047-
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
3047+
#if WARP_SIZE == 32
3048+
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
3049+
#endif
30483050

30493051
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
30503052
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
30513053
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
30523054
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
30533055
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
30543056
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
3055-
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
3057+
#if WARP_SIZE == 32
3058+
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
3059+
#endif
30563060

30573061
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
30583062
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
30593063
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
30603064
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
30613065
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
30623066
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
3063-
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
3067+
#if WARP_SIZE == 32
3068+
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
3069+
#endif
30643070

30653071
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
30663072
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
@@ -3069,23 +3075,29 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
30693075
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
30703076
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
30713077
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
3072-
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
3078+
#if WARP_SIZE == 32
3079+
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
3080+
#endif
30733081

30743082
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
30753083
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
30763084
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
30773085
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
30783086
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
30793087
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
3080-
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
3088+
#if WARP_SIZE == 32
3089+
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
3090+
#endif
30813091

30823092
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
30833093
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
30843094
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
30853095
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
30863096
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
30873097
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
3088-
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
3098+
#if WARP_SIZE == 32
3099+
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
3100+
#endif
30893101

30903102
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit)
30913103
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit)
@@ -3094,23 +3106,29 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
30943106
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
30953107
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
30963108
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
3097-
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
3109+
#if WARP_SIZE == 32
3110+
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
3111+
#endif
30983112

30993113
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4)
31003114
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4)
31013115
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
31023116
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
31033117
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
31043118
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
3105-
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
3119+
#if WARP_SIZE == 32
3120+
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
3121+
#endif
31063122

31073123
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4)
31083124
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4)
31093125
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
31103126
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
31113127
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
31123128
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
3113-
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
3129+
#if WARP_SIZE == 32
3130+
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
3131+
#endif
31143132

31153133
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
31163134
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);

csrc/ops.hip

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
5757
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n);
5858
else if(blocksize == 128)
5959
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n);
60-
else if(blocksize == 64)
60+
else if(blocksize == 64 && warpSize == 32)
6161
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);
6262

6363

0 commit comments

Comments
 (0)