1+ #include " ops_common.h"
2+ #include " reduce/sm70.cuh"
3+
4+
5+ namespace lightllm {
6+ namespace ops {
7+
using namespace lightllm;
9+
// CUDA kernel for per-token quantization from BF16 to INT8 (general path: handles arbitrary N one element at a time)
11+ template <int32_t TPB>
__global__ void device_per_token_quant_bf16_to_int8_general(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M,                    // Number of rows in the input tensor
    const int64_t N                     // Number of elements per row (token)
) {
    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N;  // Input pointer for the token
    int8_t* _output = output + bid * N;      // Output pointer for the token

    fp32_t* _scales = scales + bid;          // Output pointer for the token's scale
28+
    // Local variables for intermediate storage
    int8_t local_int8;
    bf16_t local_bf16;
32+
33+ extern __shared__ bf16_t workspace1[];
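    // Note: workspace1 lives in dynamic shared memory; the host launch passes
    // N * sizeof(bf16_t) as the shared-memory size, so the whole row can be cached
    // here and re-read in the quantization pass without touching global memory again.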
34+
    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid; i < N; i += TPB) {
        local_bf16 = _input[i];
        workspace1[i] = local_bf16;

        // Track the maximum absolute value seen by this thread.
        fp32_t tmp = cvt_bf16_f32(local_bf16);
        local_max = fmaxf(local_max, fabsf(tmp));
    }
43+
44+ // Reduce the maximum value across the block
45+ const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
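    // Every thread uses the same reduced_max below, so each thread derives an identical scale.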
46+
47+ // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);
51+
    for (int32_t i = tid; i < N; i += TPB) {
        local_bf16 = workspace1[i];

        fp32_t tmp = cvt_bf16_f32(local_bf16);
        fp32_t x = tmp * inv_scale;
        local_int8 = float_to_int8_rn(x);

        _output[i] = local_int8;
    }
61+
    if (tid == 0) {
        *_scales = scale;
    }
65+
66+ }
67+
// CUDA kernel for per-token quantization from BF16 to INT8 (vectorized path: N must be a multiple of VPT)
69+ template <int32_t TPB>
__global__ void device_per_token_quant_bf16_to_int8_vpt(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M,                    // Number of rows in the input tensor
    const int32_t N                     // Number of elements per row (token)
) {
    constexpr int32_t VPT = 8;          // Values processed per thread per iteration (8 bf16 = 16 bytes)

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N;  // Input pointer for the token
    int8_t* _output = output + bid * N;      // Output pointer for the token

    fp32_t* _scales = scales + bid;          // Output pointer for the token's scale

    // Local arrays for intermediate storage
    int8_t local_int8[VPT];
    bf16x2_t local_bf16[VPT / 2];
92+
93+ extern __shared__ bf16x2_t workspace2[];
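    // Note: workspace2 lives in dynamic shared memory (sized by the host as N * sizeof(bf16_t));
    // the row is cached here as bf16x2 pairs, so element index i maps to workspace2[i >> 1].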
94+
    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Load VPT BF16 elements from global memory (_input) into the local vector (local_bf16).
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        // Stash the loaded elements in shared memory for the second (quantization) pass.
        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace2 + (i >> 1));

        // Compute the max absolute value over the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }
110+
111+ // Reduce the maximum value across the block
112+ const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
113+
114+ // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);
118+
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Re-read the cached BF16 elements from shared memory.
        vec_copy<sizeof(bf16_t) * VPT>(workspace2 + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[j]);

            int8_t a = float_to_int8_rn(x.x * inv_scale);
            int8_t b = float_to_int8_rn(x.y * inv_scale);

            local_int8[2 * j] = a;
            local_int8[2 * j + 1] = b;
        }

        // Store the VPT INT8 results back to global memory.
        vec_copy<sizeof(int8_t) * VPT>(local_int8, _output + i);
    }
135+
    if (tid == 0) {
        *_scales = scale;
    }
139+ }
140+
141+
142+
// CUDA kernel for per-token quantization from BF16 to INT8 (N known at compile time; shared memory is statically sized)
144+ template <int32_t TPB, int32_t N>
__global__ void device_per_token_quant_bf16_to_int8(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M                     // Number of rows in the input tensor
) {
    constexpr int32_t VPT = 8;          // Values processed per thread per iteration (8 bf16 = 16 bytes)

    static_assert(N % 2 == 0, "N must be even.");
    static_assert(N % VPT == 0, "N must be a multiple of VPT.");

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N;  // Input pointer for the token
    int8_t* _output = output + bid * N;      // Output pointer for the token

    fp32_t* _scales = scales + bid;          // Output pointer for the token's scale

    // Local arrays for intermediate storage
    int8_t local_int8[VPT];
    bf16x2_t local_bf16[VPT / 2];

    // Since N is a compile-time constant, the row cache can live in statically sized shared memory.
    __shared__ bf16x2_t workspace[N / 2];
171+
    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Load VPT BF16 elements from global memory (_input) into the local vector (local_bf16).
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        // Stash the loaded elements in shared memory for the second (quantization) pass.
        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

        // Compute the max absolute value over the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }
187+
188+ // Reduce the maximum value across the block
189+ const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
190+
191+ // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);
195+
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Re-read the cached BF16 elements from shared memory.
        vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[j]);

            int8_t a = float_to_int8_rn(x.x * inv_scale);
            int8_t b = float_to_int8_rn(x.y * inv_scale);

            local_int8[2 * j] = a;
            local_int8[2 * j + 1] = b;
        }

        // Store the VPT INT8 results back to global memory.
        vec_copy<sizeof(int8_t) * VPT>(local_int8, _output + i);
    }
212+
    if (tid == 0) {
        *_scales = scale;
    }
216+ }
217+
218+
void per_token_quant_bf16_int8(
    Tensor& output,
    const Tensor& input,
    Tensor& scales
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(input.scalar_type() == c10::kBFloat16, "Input must be BF16 type");

    Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous();
    Tensor contiguous_scales = scales.is_contiguous() ? scales : scales.contiguous();

    const int64_t M = input.size(0);
    const int64_t N = input.size(1);
233+
234+ const int32_t blocks = M;
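    // Launch one thread block per token (row); each block computes that row's abs-max,
    // derives a single scale, and quantizes its N elements.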
235+
    switch (N) {
        case 16:
            device_per_token_quant_bf16_to_int8<128, 16>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 32:
            device_per_token_quant_bf16_to_int8<128, 32>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 64:
            device_per_token_quant_bf16_to_int8<128, 64>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 512:
            device_per_token_quant_bf16_to_int8<128, 512>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 1024:
            device_per_token_quant_bf16_to_int8<128, 1024>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 3200:
            device_per_token_quant_bf16_to_int8<128, 3200>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 4096:
            device_per_token_quant_bf16_to_int8<128, 4096>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 12800:
            device_per_token_quant_bf16_to_int8<256, 12800>
                <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        default: {
            static constexpr int TPB = 128;
            // Dynamic shared memory must hold one full row of BF16 values.
            const int64_t shared_mem_size = N * sizeof(bf16_t);
            if (N % 8 == 0) {
                device_per_token_quant_bf16_to_int8_vpt<TPB>
                    <<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>(
                        PTR<bf16_t>(contiguous_input),
                        PTR<int8_t>(output),
                        PTR<fp32_t>(contiguous_scales),
                        M,
                        N
                    );
            } else {
                device_per_token_quant_bf16_to_int8_general<TPB>
                    <<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>(
                        PTR<bf16_t>(contiguous_input),
                        PTR<int8_t>(output),
                        PTR<fp32_t>(contiguous_scales),
                        M,
                        N
                    );
            }
        }
    }
333+
    return;
335+ }
336+
337+ } // namespace ops
338+ } // namespace lightllm