
Commit f765a2a

[Quantization] Add per-expert global scaling factor for fp4 batched quantize (#1835)
## 📌 Description

Add per-expert global scaling factor for fp4 batched quantize.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
1 parent 40df947 commit f765a2a
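
Context for the change: before this commit, the batched NVFP4 quantize kernel read a single global scale factor (`SFScale[0]`) and applied it to every batch entry, so all experts shared one scale. With this commit, `SFScale` is indexed per batch (`SFScale[batchIdx]`), and the test computes one global scale per expert. Below is a minimal sketch of the per-expert scale computation, mirroring the updated test; the concrete constants 448.0 and 6.0 are the usual e4m3/e2m1 maxima and are an assumption here (the test imports them as `FLOAT8_E4M3_MAX` and `FLOAT4_E2M1_MAX`).

```python
# Sketch: per-expert global scale for NVFP4 batched quantization.
# Assumes x has shape (num_experts, m, n).
import torch

FLOAT8_E4M3_MAX = 448.0  # largest finite float8 e4m3 value (assumed constant)
FLOAT4_E2M1_MAX = 6.0    # largest finite float4 e2m1 value (assumed constant)

x = torch.randn(3, 256, 128, dtype=torch.bfloat16)

# Before this change: one scale for the whole batched tensor.
tensor_amax_global = torch.abs(x).max().to(torch.float32)
global_scale_shared = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax_global

# After this change: one scale per expert (per batch entry).
tensor_amax_per_expert = torch.abs(x).amax(dim=(1, 2)).to(torch.float32)
global_scale_per_expert = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax_per_expert
print(global_scale_per_expert.shape)  # torch.Size([3]) -- one entry per expert
```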

File tree

2 files changed: +5 −5 lines changed

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Lines changed: 1 addition & 1 deletion

@@ -795,7 +795,6 @@ __device__ inline void quantize_with_block_size_impl(int32_t numbatches, int32_t
   static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD;
   static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched.");

-  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
   bool isSfSwizzledLayout = layout == QuantizationSFLayout::SWIZZLED_128x4 ||
                             layout == QuantizationSFLayout::SWIZZLED_8x4;
   int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
@@ -810,6 +809,7 @@ __device__ inline void quantize_with_block_size_impl(int32_t numbatches, int32_t
   asm volatile("griddepcontrol.wait;");
   for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
     for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+      float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[batchIdx];
       for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
         std::optional<int> optionalBatchIdx = batchIdx;
         std::optional<int> optionalNumRows = numRows;
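
In effect, the `SFScaleVal` read moves from outside the loops (where `SFScale[0]` was reused for every batch) into the per-batch loop, so each expert's scale factors are derived from its own global scale. A rough Python rendering of that selection logic follows, as a conceptual sketch only; the real per-block FP4 conversion in the CUDA kernel is elided behind a stand-in `quantize_rows` callable.

```python
# Conceptual sketch of the indexing change (not the actual CUDA kernel).
# sf_scale is the optional per-expert global scale array passed to the kernel.
def select_sf_scale(sf_scale, batch_idx):
    # Old behaviour: sf_scale[0] reused for every batch.
    # New behaviour: one value per batch/expert.
    return 1.0 if sf_scale is None else float(sf_scale[batch_idx])

def quantize_batched(x, sf_scale, quantize_rows):
    """x: list of per-expert row blocks; quantize_rows: stand-in for the
    real per-block FP4 conversion, which is elided here."""
    out = []
    for batch_idx, rows in enumerate(x):
        sf_scale_val = select_sf_scale(sf_scale, batch_idx)  # was sf_scale[0]
        out.append(quantize_rows(rows, sf_scale_val))
    return out
```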

tests/utils/test_fp4_quantize.py

Lines changed: 4 additions & 4 deletions

@@ -18,7 +18,7 @@
 DTYPES = [torch.float16, torch.bfloat16]
 # The batch dimension doesn't need to be multiple of 128
 SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256)]
-BATCH_SHAPES = [(2, 128, 64), (3, 256, 128), (1, 120, 64)]
+BATCH_SHAPES = [(1, 256, 128), (2, 128, 64), (3, 256, 128), (1, 120, 64)]
 SEEDS = [42]
 CUDA_DEVICES = ["cuda:0"]
@@ -334,7 +334,7 @@ def test_nvfp4_batched_quantize(

     b, m, n = batch_shape
     x = torch.randn(batch_shape, dtype=dtype)
-    tensor_amax = torch.abs(x).max().to(torch.float32)
+    tensor_amax = torch.abs(x).amax(dim=(1, 2)).to(torch.float32)
     global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
     mask = None
     # Test the batched quantization
@@ -357,7 +357,7 @@ def test_nvfp4_batched_quantize(

     # Compare with single tensor quantization for each batch
     for i in range(b):
-        single_out, single_scale = fp4_quantize(x[i], global_scale, 16, False, True)
+        single_out, single_scale = fp4_quantize(x[i], global_scale[i], 16, False, True)
         if use_mask:
             torch.testing.assert_close(
                 out[i][: mask[i]], single_out[: mask[i]], rtol=1e-5, atol=1e-5
@@ -414,7 +414,7 @@ def test_silu_and_mul_nvfp4_batched_quantize(
     for i in range(b):
         x_silu_mul = silu_and_mul(x[i])
         single_out, single_scale = fp4_quantize(
-            x_silu_mul, global_scale, 16, False, True
+            x_silu_mul, global_scale[i], 16, False, True
         )
         torch.testing.assert_close(
             out[i][: mask[i]], single_out[: mask[i]], rtol=1e-5, atol=1e-5
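
The reference checks in both tests follow the same per-expert convention: each slice `x[i]` is quantized on its own with `global_scale[i]` and compared against the corresponding slice of the batched output. Below is a hedged sketch of that check; `quantize_fn` stands in for the test's `fp4_quantize(x, scale, 16, False, True)` call, and the batched result `out` is assumed to come from the library's batched quantize entry point, which is not shown in this diff.

```python
import torch

def check_against_single_tensor(x, out, global_scale, quantize_fn, mask=None):
    """Compare a batched quantization result against per-expert reference calls.

    x: (b, m, n) input, out: batched quantized output,
    global_scale: per-expert scale vector of length b,
    quantize_fn: stand-in for the test's fp4_quantize(x_i, scale_i, 16, False, True).
    """
    b = x.shape[0]
    for i in range(b):
        # Quantize one expert at a time with its own global scale.
        single_out, single_scale = quantize_fn(x[i], global_scale[i])
        rows = x.shape[1] if mask is None else mask[i]
        torch.testing.assert_close(
            out[i][:rows], single_out[:rows], rtol=1e-5, atol=1e-5
        )
```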
