@@ -3540,8 +3540,8 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
   float local_C = 0.0f;
 
   unsigned char local_B_4bit[num_values_8bit];
-  T local_B[num_values_4bit];
-  T local_A[num_values_4bit];
+  T local_B[num_values_4bit/4];
+  T local_A[num_values_4bit/4];
   __shared__ T quant_map[16];
   T local_absmax = T(0.0f);
 
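The two changed declarations above are the core of the commit: because the inner K-tile is now consumed in four chunks (second hunk below), local_A and local_B only ever hold num_values_4bit/4 values at a time. As a rough illustration, assuming num_values_4bit = 32 and T = half, each buffer drops from 32 x 2 B = 64 B to 8 x 2 B = 16 B of per-thread storage, which eases register pressure and makes spills to local memory less likely. A standalone sketch of the chunked loop follows the diff.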
@@ -3582,61 +3582,55 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
         local_B_4bit[j] = 0b01110111;
     }
 
-    #pragma unroll
-    for(int k = 0; k < num_values_8bit; k++)
-    {
-      #if __CUDA_ARCH__ >= 800
-        local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
-        local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
-      #else
-        // bf16 multipliation not supported
-        local_B[k*2] = T((float)quant_map[local_B_4bit[k] >> 4]*(float)local_absmax);
-        local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[k] & 0x0F]*(float)local_absmax);
-      #endif
-    }
-
-    if(inner_idx+num_values_4bit < K)
+    for(int i = 0; i < 4; i++)
     {
-      // this is also relatively important for performance
-      if(BITS==16)
-      {
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 1];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 2];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 3];
-      }
-      else
+      #pragma unroll
+      for(int k = 0; k < num_values_8bit/4; k++)
       {
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 0];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 1];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 2];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 3];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[4] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 4];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[5] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 5];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[6] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 6];
-        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[7] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 7];
+        #if __CUDA_ARCH__ >= 800
+          local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax;
+          local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
+        #else
+          // bf16 multipliation not supported
+          local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax);
+          local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax);
+        #endif
       }
 
-    }
-    else
-      #pragma unroll
-      for(int k = 0; k < num_values_4bit; k++)
-        if(inner_idx + k < K)
-          local_A[k] = A[inner_idx + k];
+      if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K)
+      {
+        // this is also relatively important for performance
+        if(BITS==16)
+        {
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + i];
+        }
         else
-          local_A[k] = T(0.0f);
+        {
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0];
+          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1];
+        }
 
+      }
+      else
+        #pragma unroll
+        for(int k = 0; k < num_values_4bit/4; k++)
+          if(inner_idx + (i*num_values_4bit/4) + k < K)
+            local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)];
+          else
+            local_A[k] = T(0.0f);
 
-    // accumulate in float; small performance hit for Ampere, but lower error for outputs
-    #pragma unroll
-    for(int k = 0; k < num_values_4bit; k++)
-    {
-      #if __CUDA_ARCH__ >= 800
-        local_C += (float)(local_A[k]*local_B[k]);
-      #else
-        // bf16 multipliation not supported
-        local_C += ((float)local_A[k]*(float)local_B[k]);
-      #endif
+
+      // accumulate in float; small performance hit for Ampere, but lower error for outputs
+      #pragma unroll
+      for(int k = 0; k < num_values_4bit/4; k++)
+      {
+        #if __CUDA_ARCH__ >= 800
+          local_C += (float)(local_A[k]*local_B[k]);
+        #else
+          // bf16 multipliation not supported
+          local_C += ((float)local_A[k]*(float)local_B[k]);
+        #endif
+      }
     }
   }
 
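For readers who want to see the new control flow in isolation, below is a minimal, self-contained CUDA sketch of the chunked dequantize-and-accumulate pattern this commit introduces. It is not the bitsandbytes kernel: the warp/thread tiling, the shared-memory quant_map, and the vectorized int4 loads of A are omitted, and the kernel name dequant_dot_chunked, the NUM_VALUES_4BIT constant, and the toy code book in main are illustrative assumptions. A single thread walks K in tiles, splits each tile into four chunks, decodes two 4-bit codes per byte through a 16-entry look-up table scaled by absmax, and accumulates the dot product in float, mirroring the added lines above.

#include <cuda_fp16.h>
#include <cstdio>
#include <vector>

constexpr int NUM_VALUES_4BIT = 32;                   // values per K-tile (assumed, illustrative)
constexpr int NUM_VALUES_8BIT = NUM_VALUES_4BIT / 2;  // packed bytes per K-tile

// Chunked 4-bit dequantize + dot product. One thread does all the work here;
// the real kernel distributes inner_idx across warps and threads.
__global__ void dequant_dot_chunked(const half* A, const unsigned char* B_4bit,
                                    const float* quant_map, float absmax,
                                    float* out, int K)
{
    // Chunk-sized buffers: only NUM_VALUES_4BIT/4 values are live at a time,
    // which is the register-pressure reduction the commit is after.
    half  local_A[NUM_VALUES_4BIT / 4];
    float local_B[NUM_VALUES_4BIT / 4];
    float local_C = 0.0f;

    for (int inner_idx = 0; inner_idx < K; inner_idx += NUM_VALUES_4BIT) {
        for (int i = 0; i < 4; i++) {  // four chunks per tile
            // 1) Unpack two 4-bit codes per byte and scale by absmax.
            #pragma unroll
            for (int k = 0; k < NUM_VALUES_8BIT / 4; k++) {
                int byte_idx = inner_idx / 2 + i * (NUM_VALUES_8BIT / 4) + k;
                // 0x77 mirrors the 0b01110111 padding byte in the diff; whether it decodes
                // to zero depends on the code book (it does in the toy one below).
                unsigned char packed = (byte_idx < K / 2) ? B_4bit[byte_idx] : 0x77;
                local_B[k * 2]     = quant_map[packed >> 4]   * absmax;
                local_B[k * 2 + 1] = quant_map[packed & 0x0F] * absmax;
            }
            // 2) Load the matching slice of A, zero-padding past K.
            #pragma unroll
            for (int k = 0; k < NUM_VALUES_4BIT / 4; k++) {
                int idx = inner_idx + i * (NUM_VALUES_4BIT / 4) + k;
                local_A[k] = (idx < K) ? A[idx] : __float2half(0.0f);
            }
            // 3) Accumulate the partial dot product in float for accuracy.
            #pragma unroll
            for (int k = 0; k < NUM_VALUES_4BIT / 4; k++)
                local_C += __half2float(local_A[k]) * local_B[k];
        }
    }
    if (threadIdx.x == 0 && blockIdx.x == 0)
        *out = local_C;
}

int main()
{
    const int K = 64;
    std::vector<half> hA(K, __float2half(1.0f));
    std::vector<unsigned char> hB(K / 2, 0x88);              // both nibbles -> code 8
    std::vector<float> hMap(16);
    for (int i = 0; i < 16; i++) hMap[i] = (i - 7) / 7.0f;   // toy signed code book

    half* dA; unsigned char* dB; float *dMap, *dOut;
    cudaMalloc(&dA, K * sizeof(half));
    cudaMalloc(&dB, K / 2);
    cudaMalloc(&dMap, 16 * sizeof(float));
    cudaMalloc(&dOut, sizeof(float));
    cudaMemcpy(dA, hA.data(), K * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), K / 2, cudaMemcpyHostToDevice);
    cudaMemcpy(dMap, hMap.data(), 16 * sizeof(float), cudaMemcpyHostToDevice);

    dequant_dot_chunked<<<1, 1>>>(dA, dB, dMap, 1.0f, dOut, K);
    float result = 0.0f;
    cudaMemcpy(&result, dOut, sizeof(float), cudaMemcpyDeviceToHost);
    printf("dot = %f\n", result);   // expect 64 * (1/7) ~= 9.142857
    return 0;
}

The trade-off is visible even in this toy: the total amount of data touched per tile is unchanged, but only a quarter of it is live in local_A and local_B at any point, so the buffers are far more likely to stay in registers instead of spilling to local memory.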