Commit 13a1c3d

add per_token_quant_bf16_int8 kernel (#939)
1 parent e540520 commit 13a1c3d

9 files changed: +411 −2 lines changed
lightllm-kernel/csrc/ops_bindings.cpp

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,8 @@ PYBIND11_MODULE(_C, m) {
     m.def("rmsnorm_align16_bf16", &rmsnorm_align16_bf16, "RMSNORM (CUDA)");
     m.def("pre_tp_norm_bf16", &pre_tp_norm_bf16, "PRE TP NORM (CUDA)");
     m.def("post_tp_norm_bf16", &post_tp_norm_bf16, "POST TP NORM (CUDA)");
-    m.def("per_token_quant_bf16_fp8", &per_token_quant_bf16_fp8, "PER TOKEN QUANT (CUDA)");
+    m.def("per_token_quant_bf16_fp8", &per_token_quant_bf16_fp8, "PER TOKEN QUANT FP8 (CUDA)");
+    m.def("per_token_quant_bf16_int8", &per_token_quant_bf16_int8, "PER TOKEN QUANT INT8 (CUDA)");
     m.def("add_norm_quant_bf16_fp8", &add_norm_quant_bf16_fp8, "ADD NORM QUANT FUSED (CUDA)");
     m.def("gelu_per_token_quant_bf16_fp8", &gelu_per_token_quant_bf16_fp8, "GELU QUANT FUSED (CUDA)");
     m.def("cutlass_scaled_mm", &cutlass_scaled_mm, "CUTLASS SCALED MM (CUDA)");

lightllm-kernel/csrc/quant/per_token_quantize_bf16.cu renamed to lightllm-kernel/csrc/quant/per_token_quantize_bf16_fp8.cu

File renamed without changes.
Lines changed: 338 additions & 0 deletions
@@ -0,0 +1,338 @@
#include "ops_common.h"
#include "reduce/sm70.cuh"


namespace lightllm {
namespace ops {

using namespace lightllm;

// CUDA kernel for per-token quantization from BF16 to INT8 (general case, arbitrary N)
template<int32_t TPB>
__global__ void device_per_token_quant_bf16_to_int8_general(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M,                    // Number of rows in the input tensor
    const int64_t N                     // Number of columns in the input tensor
) {
    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N; // Input pointer for the token
    int8_t* _output = output + bid * N;     // Output pointer for the token

    fp32_t* _scales = scales + bid;

    // Local storage for intermediate values
    int8_t local_int8;
    bf16_t local_bf16;

    extern __shared__ bf16_t workspace1[];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid; i < N; i += TPB) {
        local_bf16 = _input[i];
        workspace1[i] = local_bf16;

        // Track the absolute maximum so the scale is symmetric, matching the vectorized kernels below.
        fp32_t tmp = cvt_bf16_f32(local_bf16);
        local_max = fmaxf(local_max, fabsf(tmp));
    }

    // Reduce the maximum value across the block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid; i < N; i += TPB) {
        local_bf16 = workspace1[i];

        fp32_t tmp = cvt_bf16_f32(local_bf16);
        fp32_t x = tmp * inv_scale;
        local_int8 = float_to_int8_rn(x);

        _output[i] = local_int8;
    }

    if (tid == 0) {
        *_scales = scale;
    }
}

// CUDA kernel for per-token quantization from BF16 to INT8 (vectorized, N divisible by 8)
template<int32_t TPB>
__global__ void device_per_token_quant_bf16_to_int8_vpt(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M,                    // Number of rows in the input tensor
    const int32_t N                     // Number of columns in the input tensor
) {
    constexpr int32_t VPT = 8;

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N; // Input pointer for the token
    int8_t* _output = output + bid * N;     // Output pointer for the token

    fp32_t* _scales = scales + bid;

    // Local arrays for intermediate storage
    int8_t local_int8[VPT];
    bf16x2_t local_bf16[VPT / 2];

    extern __shared__ bf16x2_t workspace2[];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Load VPT BF16 elements from global memory (_input) into the local vector (local_bf16).
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        // Keep a copy in shared memory for the quantization pass.
        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace2 + (i >> 1));

        // Compute the absolute max over the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace2 + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[j]);

            int8_t a = float_to_int8_rn(x.x * inv_scale);
            int8_t b = float_to_int8_rn(x.y * inv_scale);

            local_int8[2 * j] = a;
            local_int8[2 * j + 1] = b;
        }

        vec_copy<sizeof(int8_t) * VPT>(local_int8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}


// CUDA kernel for per-token quantization from BF16 to INT8 (N fixed at compile time)
template<int32_t TPB, int32_t N>
__global__ void device_per_token_quant_bf16_to_int8(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M                     // Number of rows in the input tensor
) {
    constexpr int32_t VPT = 8;

    static_assert(N % 2 == 0, "N must be even.");
    static_assert(N % VPT == 0, "N must be a multiple of VPT.");

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N; // Input pointer for the token
    int8_t* _output = output + bid * N;     // Output pointer for the token

    fp32_t* _scales = scales + bid;

    // Local arrays for intermediate storage
    int8_t local_int8[VPT];
    bf16x2_t local_bf16[VPT / 2];

    __shared__ bf16x2_t workspace[N / 2];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Load VPT BF16 elements from global memory (_input) into the local vector (local_bf16).
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        // Keep a copy in shared memory for the quantization pass.
        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

        // Compute the absolute max over the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[j]);

            int8_t a = float_to_int8_rn(x.x * inv_scale);
            int8_t b = float_to_int8_rn(x.y * inv_scale);

            local_int8[2 * j] = a;
            local_int8[2 * j + 1] = b;
        }

        vec_copy<sizeof(int8_t) * VPT>(local_int8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}


void per_token_quant_bf16_int8(
    Tensor& output,
    const Tensor& input,
    Tensor& scales
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(input.scalar_type() == c10::kBFloat16, "Input must be BF16 type");

    Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous();
    Tensor contiguous_scales = scales.is_contiguous() ? scales : scales.contiguous();

    const int64_t M = input.size(0);
    const int64_t N = input.size(1);

    const int32_t blocks = M;

    switch (N) {
        case 16:
            device_per_token_quant_bf16_to_int8<128, 16>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 32:
            device_per_token_quant_bf16_to_int8<128, 32>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 64:
            device_per_token_quant_bf16_to_int8<128, 64>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 512:
            device_per_token_quant_bf16_to_int8<128, 512>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 1024:
            device_per_token_quant_bf16_to_int8<128, 1024>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 3200:
            device_per_token_quant_bf16_to_int8<128, 3200>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 4096:
            device_per_token_quant_bf16_to_int8<128, 4096>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 12800:
            device_per_token_quant_bf16_to_int8<256, 12800>
                <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        default: {
            static constexpr int TPB = 128;
            const int64_t shared_mem_size = N * sizeof(bf16_t);
            if (N % 8 == 0) {
                device_per_token_quant_bf16_to_int8_vpt<TPB>
                    <<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>(
                        PTR<bf16_t>(contiguous_input),
                        PTR<int8_t>(output),
                        PTR<fp32_t>(contiguous_scales),
                        M,
                        N
                    );
            } else {
                device_per_token_quant_bf16_to_int8_general<TPB>
                    <<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>(
                        PTR<bf16_t>(contiguous_input),
                        PTR<int8_t>(output),
                        PTR<fp32_t>(contiguous_scales),
                        M,
                        N
                    );
            }
        }
    }

    return;
}

} // namespace ops
} // namespace lightllm
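
For reference, the per-token INT8 quantization that all three kernels implement reduces to: scale = max(|x_row|) / 127, then q = round-to-nearest-even(x / scale) saturated to [-128, 127]. Below is a minimal PyTorch sketch of that math; the helper name per_token_quant_bf16_int8_ref and its layout are illustrative only and not part of this commit.

import torch

def per_token_quant_bf16_int8_ref(x: torch.Tensor):
    # x: (M, N) bfloat16. Returns int8 values of shape (M, N) and fp32 scales of shape (M,).
    assert x.dim() == 2 and x.dtype == torch.bfloat16
    xf = x.float()
    amax = xf.abs().amax(dim=1)                         # per-row absolute max (block reduction in the kernels)
    scale = amax / 127.0                                # kINT8Max = 127.0f
    q = torch.round(xf / (scale + 1e-7).unsqueeze(1))   # epsilon guard as in the kernels
    q = q.clamp(-128, 127).to(torch.int8)               # saturation, as in float_to_int8_rn
    return q, scale

Dequantizing with q.float() * scale.unsqueeze(1) should recover the input to within one quantization step per element.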

lightllm-kernel/include/ops_common.h

Lines changed: 6 additions & 0 deletions

@@ -32,6 +32,12 @@ void per_token_quant_bf16_fp8(
     Tensor& scales
 );

+void per_token_quant_bf16_int8(
+    Tensor& output,
+    const Tensor& input,
+    Tensor& scales
+);
+
 std::tuple<Tensor, Tensor> add_norm_quant_bf16_fp8(
     Tensor& X, const Tensor &R, const Tensor &W,
     const fp32_t eps

lightllm-kernel/include/utils.h

Lines changed: 6 additions & 0 deletions

@@ -68,6 +68,12 @@ __device__ inline bf16x2_t _float22bf162_rn(fp32x2_t val) {
     return bf16x2_t(low, high);
 }

+__device__ inline int8_t float_to_int8_rn(fp32_t x) {
+    uint32_t dst;
+    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
+    return reinterpret_cast<const int8_t&>(dst);
+}
+
 template <typename T>
 __host__ __device__ T Cdiv(T numerator, T denominator) {
     return (numerator + denominator - 1) / denominator;
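
The new float_to_int8_rn helper wraps the PTX instruction cvt.rni.sat.s8.f32: round to the nearest integer with ties to even, then saturate to the signed 8-bit range. As a rough host-side model of that behaviour (a Python sketch for intuition only, not part of the commit):

import numpy as np

def float_to_int8_rn_ref(x: float) -> int:
    r = np.rint(x)                      # round half to even, like .rni
    return int(np.clip(r, -128, 127))   # saturate to s8, like .sat

# Halfway cases round to the even integer; out-of-range values saturate.
assert float_to_int8_rn_ref(2.5) == 2
assert float_to_int8_rn_ref(3.5) == 4
assert float_to_int8_rn_ref(300.0) == 127
assert float_to_int8_rn_ref(-300.0) == -128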

lightllm-kernel/lightllm_kernel/ops/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -77,14 +77,15 @@
     allgather_register_graph_buffers,
     allgather_get_graph_buffer_ipc_meta,
 )
-from .quant import per_token_quant_bf16_fp8
+from .quant import per_token_quant_bf16_fp8, per_token_quant_bf16_int8
 from .gemm import cutlass_scaled_mm_bias_ls
 from .moe import grouped_topk
 from .attention import group8_int8kv_flashdecoding_stage1, group_int8kv_decode_attention

 __all__ = [
     "rmsnorm_bf16",
     "per_token_quant_bf16_fp8",
+    "per_token_quant_bf16_int8",
     "pre_tp_norm_bf16",
     "post_tp_norm_bf16",
     "add_norm_quant_bf16_fp8",
