Skip to content

Commit ebb6797

Browse files
improve register usage of kInt8VectorQuant - especially for A100/H100
1 parent 73f02e8 commit ebb6797

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

csrc/kernels.cu

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "kernels.cuh"
77
#include "common.cuh"
8+
#include <cuda_fp16.h>
89
#include <cub/block/block_radix_sort.cuh>
910
#include <cub/warp/warp_reduce.cuh>
1011
#include <cub/block/block_load.cuh>
@@ -2141,7 +2142,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
21412142
template<typename T, int THREADS, int SPARSE_DECOMP>
21422143
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
21432144
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
2144-
using BlockReduceT = cub::BlockReduce<float, THREADS>;
2145+
using BlockReduceT = cub::BlockReduce<T, THREADS>;
21452146

21462147
// One block per row.
21472148
// Threads load column values in a striped arrangement.
@@ -2151,27 +2152,27 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
21512152
// We then do a blockwise reduction to determine the row's absmax.
21522153

21532154
__shared__ typename BlockReduceT::TempStorage temp_storage;
2154-
__shared__ float smem_row_absmax;
2155+
__shared__ T smem_row_absmax;
21552156

21562157
const int row_id = blockIdx.x;
2157-
const T* __restrict__ row_data = A + (row_id * cols);
2158+
const T* row_data = A + (row_id * cols);
21582159

21592160
// Threads will read the row values in a striped access pattern and find a local absmax.
2160-
float row_local_absmax = -FLT_MIN;
2161+
T row_local_absmax = -FLT_MIN;
21612162
for (int i = threadIdx.x; i < cols; i += THREADS) {
2162-
const float absval = fabsf(__ldcs(&(row_data[i])));
2163+
const T absval = fabsf(__ldcs(&(row_data[i])));
21632164

21642165
// For sparse decomposition, values outside of the threshold are not to be
21652166
// included when calculating the row's absmax.
21662167
if constexpr (SPARSE_DECOMP) {
2167-
row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
2168+
row_local_absmax = fmaxf(row_local_absmax, absval < T(threshold) ? absval : row_local_absmax);
21682169
} else {
21692170
row_local_absmax = fmaxf(row_local_absmax, absval);
21702171
}
21712172
}
21722173

21732174
// Reduce thread-local absmax across the block.
2174-
const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
2175+
const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
21752176
if (threadIdx.x == 0) {
21762177
// Save our block's absmax to shared memory for the quantization step.
21772178
rowStats[row_id] = smem_row_absmax = row_absmax;
@@ -2181,13 +2182,14 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
21812182
// Quantize row-wise.
21822183
const float scale = __fdividef(127.0f, smem_row_absmax);
21832184
for (int i = threadIdx.x; i < cols; i += THREADS) {
2185+
float val = row_data[i];
2186+
21842187
if constexpr (SPARSE_DECOMP) {
21852188
// For sparse decomposition, we do not want to quantize the outliers.
21862189
// Instead they're zeroed out.
2187-
float val = row_data[i];
21882190
out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0;
21892191
} else {
2190-
out[row_id * cols + i] = __float2int_rn(float(row_data[i]) * scale);
2192+
out[row_id * cols + i] = __float2int_rn(val * scale);
21912193
}
21922194
}
21932195
}

0 commit comments

Comments
 (0)