You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Copy file name to clipboardExpand all lines: csrc/kernels.hip
+14-23Lines changed: 14 additions & 23 deletions
Original file line number
Diff line number
Diff line change
@@ -3543,20 +3543,22 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
3543
3543
#endif
3544
3544
}
3545
3545
3546
+
#define warp_size __AMDGCN_WAVEFRONT_SIZE
3547
+
// No of 4bit values processed by each thread
3546
3548
#define num_values_4bit 32
3547
3549
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
3548
3550
{
3549
3551
3550
3552
// per threadblock:
3551
-
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
3553
+
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
Copy file name to clipboardExpand all lines: csrc/ops.hip
+6-1Lines changed: 6 additions & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -901,7 +901,12 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
901
901
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
902
902
{
903
903
904
-
int num_blocks = (m+3)/4;
904
+
//warpsize - 32
905
+
int num_blocks = (m+3)/4;
906
+
//warpsize - 64
907
+
#if __AMDGCN_WAVEFRONT_SIZE == 64
908
+
num_blocks = (m+1)/2;
909
+
#endif
905
910
906
911
hipLaunchKernelGGL(( kgemm_4bit_inference_naive<T, 128, BITS>), dim3(num_blocks), dim3(128), 0, 0 , m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
0 commit comments