
Commit 422c58b

[Bug NVIDIA#193] fix fp8 blockwise real quantization (NVIDIA#522)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Instead of deriving the FP8 scale from a pre-computed amax, compute the block scale on the fly for blockwise real quantization.

## Usage

See the sketch after this description.

## Testing

Covered by the new `test_fp8_with_amax_and_block_sizes` case in `tests/gpu/torch/quantization/test_qtensor_cuda.py` (diff below).

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

Related issue: NVIDIA#193.

---------

Signed-off-by: weimingc <[email protected]>
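A minimal sketch of the idea in the overview: for blockwise FP8, each block's scale is derived from that block's own amax rather than from a single per-tensor amax. The helper name `blockwise_fp8_scales`, the block size, and the plain-PyTorch reshape are illustrative assumptions, not the modelopt kernel; only the E4M3 divisor of 448 follows the diff in this commit.

```python
import torch

def blockwise_fp8_scales(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Illustrative only: one FP8 scale per contiguous block along the last dim."""
    # Group the last dimension into blocks of `block_size` elements.
    blocks = x.reshape(*x.shape[:-1], -1, block_size)  # (..., n_blocks, block_size)
    # Each block gets its own amax, computed on the fly.
    block_amax = blocks.abs().amax(dim=-1)
    # 448.0 is the max representable magnitude of FP8 E4M3 (as in the diff).
    return block_amax / 448.0

x = torch.randn(128, 1152)
print(blockwise_fp8_scales(x).shape)  # torch.Size([128, 9]) -- one scale per 128-wide block
```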
1 parent e4c5a68 commit 422c58b

File tree

2 files changed: +36 -1 lines changed


modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -563,11 +563,13 @@ def _real_quantize(self, inputs):
         if self._num_bits == (4, 3):
             # FP8 quantization
             # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks
+            # For blockwise quantization, amax will be recomputed in the kernel
+            use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1)
             outputs, _scale = FP8QTensor.quantize(
                 inputs,
                 axis=self._axis,
                 block_sizes=self._block_sizes,
-                scales=self.amax / 448.0 if self.amax is not None else None,
+                scales=self.amax / 448.0 if use_amax else None,
             )
             buffer_to_register["_scale"] = _scale
         elif self._num_bits == 8:
```
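The new `use_amax` flag only reuses the stored amax when it is compatible with the quantization layout. A minimal standalone sketch of that condition, using the hypothetical helper name `should_use_amax` (the diff inlines the expression):

```python
import torch

def should_use_amax(amax, block_sizes) -> bool:
    # Reuse a pre-computed amax unless it is a lone scalar paired with
    # blockwise quantization; in that case the kernel recomputes
    # per-block scales on the fly instead.
    return amax is not None and not (block_sizes and amax.numel() == 1)

print(should_use_amax(torch.tensor(1.5), {-1: 128}))         # False: scalar amax + blockwise
print(should_use_amax(torch.tensor(1.5), None))              # True: per-tensor quantization
print(should_use_amax(torch.tensor([1.0, 2.0]), {-1: 128}))  # True: amax is already non-scalar
```

Per the added comment in the diff, passing `scales=None` in the blockwise case lets `FP8QTensor.quantize` recompute the per-block amax itself.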

tests/gpu/torch/quantization/test_qtensor_cuda.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -569,3 +569,36 @@ def test_nvfp4_dequantize_fast(self, shape, input_dtype):
             f"Fast and standard dequantization differ: "
             f"max diff = {(dequant_fast - dequant_standard).abs().max()}"
         )
+
+    @pytest.mark.parametrize("device", ["cuda"])
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
+    @pytest.mark.parametrize(
+        ("input_shape", "block_sizes"),
+        [
+            ((128, 1152), {-1: 128}),
+            ((256, 256), {-1: 64, -2: 64}),  # 2D block sizes
+        ],
+    )
+    def test_fp8_with_amax_and_block_sizes(self, device, input_dtype, input_shape, block_sizes):
+        """Test FP8 quantization with both amax and block_sizes specified."""
+        quant_cfg = QuantizerAttributeConfig(
+            num_bits=(4, 3),
+            block_sizes=block_sizes,
+            fake_quant=False,
+        )
+        quantizer = TensorQuantizer(quant_cfg).to(device)
+
+        # Set a mock amax (scalar) - this was causing the bug
+        mock_amax = torch.tensor(1.5, device=device)
+        quantizer.amax = mock_amax
+
+        # Create input tensor
+        x = torch.randn(input_shape, dtype=input_dtype, device=device)
+
+        # QDQ
+        q_x = quantizer(x)
+        deq_x = quantizer(q_x)
+
+        assert torch.allclose(deq_x, x, rtol=1e-1, atol=1e-1)
+        assert hasattr(quantizer, "_scale")
+        assert quantizer._scale.numel() > 1
```
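Based on the new test, an end-to-end usage sketch of the fixed path. The import paths are assumptions inferred from the file layout in this commit, the behavior (second quantizer call dequantizing, registered `_scale` buffer) mirrors the test above, and a CUDA device is required.

```python
import torch
from modelopt.torch.quantization.config import QuantizerAttributeConfig  # assumed import path
from modelopt.torch.quantization.nn import TensorQuantizer               # assumed import path

# Real (not fake) FP8 blockwise quantization with a 128-wide block on the last dim.
cfg = QuantizerAttributeConfig(num_bits=(4, 3), block_sizes={-1: 128}, fake_quant=False)
quantizer = TensorQuantizer(cfg).cuda()

# A scalar amax, as in the original bug report: it must not override block scales.
quantizer.amax = torch.tensor(1.5, device="cuda")

x = torch.randn(128, 1152, dtype=torch.bfloat16, device="cuda")
q_x = quantizer(x)      # real-quantize to an FP8 QTensor
deq_x = quantizer(q_x)  # a second call dequantizes, as in the test's QDQ round trip

# With the fix, the registered scale is per-block, not a single amax-derived value.
print(quantizer._scale.numel() > 1)  # True
```

The new case itself can be run on a GPU machine with `pytest tests/gpu/torch/quantization/test_qtensor_cuda.py -k fp8_with_amax_and_block_sizes`.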
