
Commit a71b1b2

fix:add fabsf() to general kernel when compare max values (#943)
1 parent 13a1c3d commit a71b1b2

File tree: 4 files changed (+26 / -52 lines)

  lightllm-kernel/csrc/quant/per_token_quantize_bf16_fp8.cu
  lightllm-kernel/csrc/quant/per_token_quantize_bf16_int8.cu
  lightllm-kernel/test/quant/fp8_quant_test.py
  lightllm-kernel/test/quant/int8_quant_test.py
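Why the fabsf() matters: per-token quantization derives each row's scale from the largest magnitude in that row. The general kernels previously reduced over the signed values, so a token whose largest-magnitude element was negative got a scale based on its smaller positive maximum, and the negative entries clipped on conversion. The one-line fix wraps the compared value in fabsf(). Below is a minimal host-side sketch of the difference, with made-up values; the FP8_E4M3_MAX constant is the standard e4m3 maximum, not something read from this repository.

    // Sketch only: contrasts the old signed-max reduction with the fixed
    // abs-max reduction on an illustrative row.
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float row[4] = {0.1f, -3.0f, 0.5f, -0.2f};
        const float FP8_E4M3_MAX = 448.0f;   // largest finite e4m3 value

        float signed_max = -INFINITY;        // old: max over signed values
        float abs_max = 0.0f;                // new: max over absolute values
        for (int i = 0; i < 4; ++i) {
            signed_max = fmaxf(signed_max, row[i]);      // ends at 0.5
            abs_max    = fmaxf(abs_max, fabsf(row[i]));  // ends at 3.0
        }

        // A scale derived from the signed max cannot cover -3.0 without clipping.
        printf("scale(old)=%g  scale(new)=%g\n",
               signed_max / FP8_E4M3_MAX, abs_max / FP8_E4M3_MAX);
        return 0;
    }

With the old reduction the per-token scale for this row would come from 0.5 instead of 3.0, so the -3.0 entry saturates after quantization; with fabsf() the scale covers the full range of the row.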

lightllm-kernel/csrc/quant/per_token_quantize_bf16_fp8.cu

Lines changed: 11 additions & 24 deletions
@@ -13,7 +13,6 @@ __global__ void device_per_token_quant_bf16_to_fp8_general(
     const bf16_t* __restrict__ input,    // Input tensor in BF16 format
     fp8_e4m3_t* __restrict__ output,     // Output tensor in FP8 format
     fp32_t* __restrict__ scales,         // Output scales for each token
-    const int64_t M,                     // Number of rows in the input tensor
     const int64_t N
 ) {
     const int32_t bid = blockIdx.x;
@@ -38,7 +37,7 @@ __global__ void device_per_token_quant_bf16_to_fp8_general(
         workspace1[i] = local_bf16;
 
         fp32_t tmp = cvt_bf16_f32(local_bf16);
-        local_max = fmaxf(local_max, tmp);
+        local_max = fmaxf(local_max, fabsf(tmp));
     }
 
     // Reduce the maximum value across the block
@@ -71,7 +70,6 @@ __global__ void device_per_token_quant_bf16_to_fp8_vpt(
     const bf16_t* __restrict__ input,    // Input tensor in BF16 format
     fp8_e4m3_t* __restrict__ output,     // Output tensor in FP8 format
     fp32_t* __restrict__ scales,         // Output scales for each token
-    const int64_t M,                     // Number of rows in the input tensor
     const int32_t N
 ) {
     constexpr int32_t VPT = 8;
@@ -147,8 +145,7 @@ template<int32_t TPB, int32_t N>
 __global__ void device_per_token_quant_bf16_to_fp8(
     const bf16_t* __restrict__ input,    // Input tensor in BF16 format
     fp8_e4m3_t* __restrict__ output,     // Output tensor in FP8 format
-    fp32_t* __restrict__ scales,         // Output scales for each token
-    const int64_t M                      // Number of rows in the input tensor
+    fp32_t* __restrict__ scales          // Output scales for each token
 ) {
     constexpr int32_t VPT = 8;
 
@@ -243,71 +240,63 @@ void per_token_quant_bf16_fp8 (
             <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                 PTR<bf16_t>(contiguous_input),
                 PTR<fp8_e4m3_t>(output),
-                PTR<fp32_t>(contiguous_scales),
-                M
+                PTR<fp32_t>(contiguous_scales)
             );
             break;
         case 32:
             device_per_token_quant_bf16_to_fp8<128, 32>
             <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                 PTR<bf16_t>(contiguous_input),
                 PTR<fp8_e4m3_t>(output),
-                PTR<fp32_t>(contiguous_scales),
-                M
+                PTR<fp32_t>(contiguous_scales)
             );
             break;
         case 64:
             device_per_token_quant_bf16_to_fp8<128, 64>
             <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                 PTR<bf16_t>(contiguous_input),
                 PTR<fp8_e4m3_t>(output),
-                PTR<fp32_t>(contiguous_scales),
-                M
+                PTR<fp32_t>(contiguous_scales)
             );
             break;
         case 512:
             device_per_token_quant_bf16_to_fp8<128, 512>
             <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                 PTR<bf16_t>(contiguous_input),
                 PTR<fp8_e4m3_t>(output),
-                PTR<fp32_t>(contiguous_scales),
-                M
+                PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 1024:
            device_per_token_quant_bf16_to_fp8<128, 1024>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 3200:
            device_per_token_quant_bf16_to_fp8<128, 3200>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 4096:
            device_per_token_quant_bf16_to_fp8<128, 4096>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 12800:
            device_per_token_quant_bf16_to_fp8<256, 12800>
            <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        default: {
@@ -319,7 +308,6 @@ void per_token_quant_bf16_fp8 (
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
-               M,
                N
            );
        } else {
@@ -328,7 +316,6 @@ void per_token_quant_bf16_fp8 (
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
-               M,
                N
            );
        }
@@ -339,4 +326,4 @@ void per_token_quant_bf16_fp8 (
    }
 
 } // namespace ops
-} // namespace lightllm
+} // namespace lightllm
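For readers skimming the hunks above: each kernel in this family assigns one block per token (row). Threads walk the row accumulating a per-thread maximum, the block then reduces those maxima ("Reduce the maximum value across the block"), and the result becomes the per-token scale written to scales. The dropped M parameter fits the same picture: with one block per row, the row index comes from blockIdx.x and the row count from the grid, so the kernels never consult M. The sketch below shows that pattern with the corrected fabsf() in place; it is a simplified stand-in, not the repository's kernel, and the function name, block-size limit, shared-memory reduction, and the absmax/448 scale convention are all illustrative.

    // Simplified sketch of the per-token abs-max pattern (one block per row).
    // Launch with a power-of-two block size <= 256 and grid.x = number of rows.
    #include <cuda_bf16.h>
    #include <math.h>
    #include <stdint.h>

    __global__ void per_token_absmax_scale_sketch(
        const __nv_bfloat16* __restrict__ input,  // [rows, N], row-major
        float* __restrict__ scales,               // [rows], one scale per token
        const int64_t N
    ) {
        const int64_t row = blockIdx.x;           // one block handles one row
        const __nv_bfloat16* x = input + row * N;

        // Per-thread reduction over the row, on absolute values (the fix).
        float local_max = 0.0f;
        for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
            local_max = fmaxf(local_max, fabsf(__bfloat162float(x[i])));
        }

        // Block-wide max via shared memory (simplified).
        __shared__ float smem[256];
        smem[threadIdx.x] = local_max;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (threadIdx.x < stride) {
                smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + stride]);
            }
            __syncthreads();
        }

        // Map the largest |value| in the row to the FP8 e4m3 maximum (448).
        if (threadIdx.x == 0) {
            scales[row] = smem[0] / 448.0f;
        }
    }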

lightllm-kernel/csrc/quant/per_token_quantize_bf16_int8.cu

Lines changed: 11 additions & 24 deletions
@@ -13,7 +13,6 @@ __global__ void device_per_token_quant_bf16_to_int8_general(
     const bf16_t* __restrict__ input,    // Input tensor in BF16 format
     int8_t* __restrict__ output,         // Output tensor in INT8 format
     fp32_t* __restrict__ scales,         // Output scales for each token
-    const int64_t M,                     // Number of rows in the input tensor
     const int64_t N
 ) {
     const int32_t bid = blockIdx.x;
@@ -38,7 +37,7 @@ __global__ void device_per_token_quant_bf16_to_int8_general(
         workspace1[i] = local_bf16;
 
         fp32_t tmp = cvt_bf16_f32(local_bf16);
-        local_max = fmaxf(local_max, tmp);
+        local_max = fmaxf(local_max, fabsf(tmp));
     }
 
     // Reduce the maximum value across the block
@@ -71,7 +70,6 @@ __global__ void device_per_token_quant_bf16_to_int8_vpt(
     const bf16_t* __restrict__ input,    // Input tensor in BF16 format
     int8_t* __restrict__ output,         // Output tensor in INT8 format
     fp32_t* __restrict__ scales,         // Output scales for each token
-    const int64_t M,                     // Number of rows in the input tensor
     const int32_t N
 ) {
     constexpr int32_t VPT = 8;
@@ -145,8 +143,7 @@ template<int32_t TPB, int32_t N>
 __global__ void device_per_token_quant_bf16_to_int8(
     const bf16_t* __restrict__ input,    // Input tensor in BF16 format
     int8_t* __restrict__ output,         // Output tensor in INT8 format
-    fp32_t* __restrict__ scales,         // Output scales for each token
-    const int64_t M                      // Number of rows in the input tensor
+    fp32_t* __restrict__ scales          // Output scales for each token
 ) {
     constexpr int32_t VPT = 8;
 
@@ -239,71 +236,63 @@ void per_token_quant_bf16_int8 (
             <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                 PTR<bf16_t>(contiguous_input),
                 PTR<int8_t>(output),
-                PTR<fp32_t>(contiguous_scales),
-                M
+                PTR<fp32_t>(contiguous_scales)
             );
             break;
         case 32:
             device_per_token_quant_bf16_to_int8<128, 32>
             <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                 PTR<bf16_t>(contiguous_input),
                 PTR<int8_t>(output),
-                PTR<fp32_t>(contiguous_scales),
-                M
+                PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 64:
            device_per_token_quant_bf16_to_int8<128, 64>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<int8_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 512:
            device_per_token_quant_bf16_to_int8<128, 512>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<int8_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 1024:
            device_per_token_quant_bf16_to_int8<128, 1024>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<int8_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 3200:
            device_per_token_quant_bf16_to_int8<128, 3200>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<int8_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 4096:
            device_per_token_quant_bf16_to_int8<128, 4096>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<int8_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        case 12800:
            device_per_token_quant_bf16_to_int8<256, 12800>
            <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<int8_t>(output),
-               PTR<fp32_t>(contiguous_scales),
-               M
+               PTR<fp32_t>(contiguous_scales)
            );
            break;
        default: {
@@ -315,7 +304,6 @@ void per_token_quant_bf16_int8 (
                PTR<bf16_t>(contiguous_input),
                PTR<int8_t>(output),
                PTR<fp32_t>(contiguous_scales),
-               M,
                N
            );
        } else {
@@ -324,7 +312,6 @@ void per_token_quant_bf16_int8 (
                PTR<bf16_t>(contiguous_input),
                PTR<int8_t>(output),
                PTR<fp32_t>(contiguous_scales),
-               M,
                N
            );
        }
@@ -335,4 +322,4 @@ void per_token_quant_bf16_int8 (
    }
 
 } // namespace ops
-} // namespace lightllm
+} // namespace lightllm
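The INT8 file mirrors the FP8 one: the same fabsf() fix in the general kernel and the same removal of the now-unneeded M parameter; only the consumer of the abs-max differs, a symmetric INT8 scale instead of an FP8 one. Below is a minimal sketch of the quantize step such a scale feeds, under the common symmetric convention (scale = abs_max / 127); the rounding mode and clamp range here are assumptions, not read from the repository.

    // Sketch of symmetric per-token INT8 quantization given an fabsf()-based scale.
    // Assumes scale > 0 (i.e., the row is not all zeros).
    #include <math.h>
    #include <stdint.h>

    __device__ __forceinline__ int8_t quantize_symmetric_int8(float x, float scale) {
        float q = roundf(x / scale);             // scale = row_abs_max / 127
        q = fminf(fmaxf(q, -127.0f), 127.0f);    // clamp to the symmetric int8 range
        return (int8_t)q;
    }

With the pre-fix signed max, a row dominated by negative values would get a scale from its small positive maximum; x / scale would then overshoot -127 and the clamp would saturate, which shows up as a large error against the reference quantizer in the tests below.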

lightllm-kernel/test/quant/fp8_quant_test.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ class TestQuantBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
         self.tokens = [1024, 13325]
-        self.hiddenDims = [256, 511, 1023, 1024, 1025, 1032, 3200, 3201, 3208, 12800]
+        self.hiddenDims = [3, 256, 511, 1023, 1024, 1025, 1032, 3200, 3201, 3208, 12800]
         self.device = "cuda"
         self.dtype = torch.bfloat16
 
@@ -20,7 +20,7 @@ def test_accuracy(self):
                 with self.subTest(shape=[token, hiddenDim]):
                     input = torch.rand(size=[token, hiddenDim], device=self.device, dtype=self.dtype) - 0.5
                     y_real, scales_real = ops.scaled_fp8_quant(
-                        input.contiguous().cuda(self.device), scale=None, use_per_token_if_dynamic=True
+                        input.contiguous(), scale=None, use_per_token_if_dynamic=True
                     )
                     y_pred, scales_pred = per_token_quant_bf16_fp8(input)
                     self.assertTrue(

lightllm-kernel/test/quant/int8_quant_test.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ class TestQuantBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
         self.tokens = [1024, 13325]
-        self.hiddenDims = [256, 257, 511, 1023, 1024, 1025, 1032, 3200, 3201, 3208, 12800]
+        self.hiddenDims = [3, 256, 257, 511, 1023, 1024, 1025, 1032, 3200, 3201, 3208, 12800]
         self.device = "cuda:2"
         self.dtype = torch.bfloat16
         torch.cuda.set_device(self.device)
@@ -21,7 +21,7 @@ def test_accuracy(self):
                 with self.subTest(shape=[token, hiddenDim]):
                     input = torch.rand(size=[token, hiddenDim], device=self.device, dtype=self.dtype) - 0.5
                     y_real, scales_real, _ = ops.scaled_int8_quant(
-                        input.contiguous().cuda(self.device)
+                        input.contiguous()
                     )
                     y_pred, scales_pred = per_token_quant_bf16_int8(input)
                     self.assertTrue(
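The test changes are small but targeted: hiddenDim 3 is added so the accuracy check also covers a width that matches none of the templated cases in the launchers above and presumably routes to the general kernel named in the commit title. The redundant .cuda(self.device) call is also dropped, since the input tensor is already created on self.device.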
