Skip to content

Commit b954474

Browse files
Add comment to explain division optimization
1 parent 32a60c5 commit b954474

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

csrc/kernels.cu

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -729,8 +729,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
729729
valid_items_load = min(TILE_SIZE, n - i);
730730
valid_items_store = valid_items_load;
731731
}
732+
733+
// Since blocksize will always be a power-of-2, we avoid more expensive
734+
// division by the blocksize and instead use a shift operation.
735+
// This is equivalent to (i+threadIdx.x*NUM_PER_TH)/blocksize.
732736
local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]);
733-
//local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]);
734737

735738
__syncthreads();
736739
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);
@@ -3579,9 +3582,13 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
35793582
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
35803583
{
35813584
const int inner_idx_halved = inner_idx/2;
3585+
3586+
// Since blocksize will always be a power-of-2, we avoid more expensive
3587+
// division by the blocksize and instead use a shift operation.
3588+
// This is equivalent to ((2*offset_B)+inner_idx)/blocksize.
35823589
const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize));
3583-
//int absidx = ((2*offset_B)+inner_idx)/blocksize;
3584-
local_absmax = __ldg(&(absmax[absidx]));
3590+
3591+
local_absmax = __ldg(&(absmax[absidx]));
35853592

35863593
if(row_B < M)
35873594
{

0 commit comments

Comments
 (0)