Skip to content

Commit 4aad810

Browse files
authored
Merge pull request ROCm#49 from ROCm/gemv_4bit_warpsize_64
Update 4bit gemm kernel for warpsize 64
2 parents 64bc947 + 44f6602 commit 4aad810

File tree

2 files changed

+20
-24
lines changed

2 files changed

+20
-24
lines changed

csrc/kernels.hip

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3543,20 +3543,22 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
35433543
#endif
35443544
}
35453545

3546+
#define warp_size __AMDGCN_WAVEFRONT_SIZE
3547+
// No of 4bit values processed by each thread
35463548
#define num_values_4bit 32
35473549
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
35483550
{
35493551

35503552
// per threadblock:
3551-
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
3553+
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
35523554
// 4 warps -> 4 loads per iter
3553-
// 1x32 * 32x4 -> 1x4 outputs per thread block
3554-
typedef hipcub::WarpReduce<float, 32> WarpReduce;
3555-
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
3555+
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
3556+
typedef hipcub::WarpReduce<float, warp_size> WarpReduce;
3557+
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warp_size];
35563558

3557-
const int warp_idx = threadIdx.x / 32;
3558-
const int warp_lane = threadIdx.x % 32;
3559-
const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
3559+
const int warp_idx = threadIdx.x / warp_size;
3560+
const int warp_lane = threadIdx.x % warp_size;
3561+
const int row_B = (THREADS/warp_size)*blockIdx.x + warp_idx;
35603562
const int num_values_8bit = num_values_4bit/2;
35613563
float local_C = 0.0f;
35623564

@@ -3571,8 +3573,8 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
35713573
__syncthreads();
35723574

35733575
// A: [1, K]
3574-
// B: [N, K]
3575-
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
3576+
// B: [M, K]
3577+
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warp_size*num_values_4bit)
35763578
{
35773579
int inner_idx_halved = inner_idx/2;
35783580
int offset_B = ldb*row_B;
@@ -3608,14 +3610,8 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
36083610
#pragma unroll
36093611
for(int k = 0; k < num_values_8bit/4; k++)
36103612
{
3611-
#if __CUDA_ARCH__ >= 800
3612-
local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax;
3613-
local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
3614-
#else
3615-
// bf16 multiplication not supported
3616-
local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax);
3617-
local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax);
3618-
#endif
3613+
local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax;
3614+
local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
36193615
}
36203616

36213617
if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K)
@@ -3645,12 +3641,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
36453641
#pragma unroll
36463642
for(int k = 0; k < num_values_4bit/4; k++)
36473643
{
3648-
#if __CUDA_ARCH__ >= 800
3649-
local_C += (float)(local_A[k]*local_B[k]);
3650-
#else
3651-
// bf16 multiplication not supported
3652-
local_C += ((float)local_A[k]*(float)local_B[k]);
3653-
#endif
3644+
local_C += (float)(local_A[k]*local_B[k]);
36543645
}
36553646
}
36563647
}

csrc/ops.hip

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,12 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
901901
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
902902
{
903903

904-
int num_blocks = (m+3)/4;
904+
//warpsize - 32
905+
int num_blocks = (m+3)/4;
906+
//warpsize - 64
907+
#if __AMDGCN_WAVEFRONT_SIZE == 64
908+
num_blocks = (m+1)/2;
909+
#endif
905910

906911
hipLaunchKernelGGL(( kgemm_4bit_inference_naive<T, 128, BITS>), dim3(num_blocks), dim3(128), 0, 0 , m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
907912
CUDA_CHECK_RETURN(hipPeekAtLastError());

0 commit comments

Comments
 (0)