Commit 8e3bfb5

Fix NF4 scale padding (#183)
Co-authored-by: ishan-modi <[email protected]>
1 parent: cafa7f6

File tree: 5 files changed, +18 −8 lines

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 3 additions & 1 deletion
@@ -575,7 +575,9 @@ def _real_quantize(self, inputs):
         ):
             # NF4 double quantization
             # Return real quantized tensor class and store scales inside the TensorQuantizer
-            outputs, scales = NF4QTensor.quantize(inputs, self._block_sizes[-1])
+            outputs, scales = NF4QTensor.quantize(
+                inputs, self._block_sizes[-1], self._block_sizes["scale_block_sizes"][-1]
+            )
             _scale, _double_scale, _scale_zeros = NF4QTensor.double_quantization(
                 scales,
                 self._block_sizes["scale_block_sizes"][-1],
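
For context, the nested block_sizes layout this call indexes into follows the shape the tests construct (the values below are illustrative, not defaults):

# Illustrative block_sizes config: weights quantized in blocks of 64 along
# the last dim, per-block scales grouped into blocks of 256.
block_sizes = {-1: 64, "scale_bits": 8, "scale_block_sizes": {-1: 256}}

weight_block = block_sizes[-1]                      # 64
scale_block = block_sizes["scale_block_sizes"][-1]  # 256, the new third argument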

modelopt/torch/quantization/qtensor/int4_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def _get_quant_maxbound(num_bits):
         return 2 ** (num_bits - 1) - 1

     @classmethod
-    def quantize(cls, input: torch.Tensor, block_size: int) -> torch.Tensor:
+    def quantize(cls, input: torch.Tensor, block_size: int) -> tuple:
         """Converting a tensor to a quantized format based on INT4 (AWQ) quantization.

         Args:

modelopt/torch/quantization/qtensor/nf4_tensor.py

Lines changed: 8 additions & 1 deletion
@@ -74,7 +74,9 @@ class NF4QTensor(BaseQuantizedTensor):
     """

     @classmethod
-    def quantize(cls, input: torch.Tensor, block_size: int) -> torch.Tensor:
+    def quantize(
+        cls, input: torch.Tensor, block_size: int, scale_block_size: int | None = None
+    ) -> tuple:
         """Converting a tensor to a quantized format based on NF4 double quantization.

         Args:
@@ -116,6 +118,8 @@ def quantize(cls, input: torch.Tensor, block_size: int) -> torch.Tensor:
         # | byte | byte | byte |
         packed_output_uint8 = quantized_output_uint8[::2] << 4 | quantized_output_uint8[1::2]

+        # pad the scales if needed
+        scales = reduce_block_padding(scales.view(-1), block_sizes={-1: scale_block_size})
         return cls(original_input.shape, original_input.dtype, packed_output_uint8), scales

     @classmethod
@@ -159,6 +163,9 @@ def dequantize(self, dtype: torch.dtype = None, **kwarg):
         double_scale = kwarg["double_scale"]
         scale_zeros = kwarg["scale_zeros"]

+        # unpad the scales if needed
+        scales = scales.view(-1)[: (self._quantized_data.numel() * 2) // block_sizes[-1]]
+
         if cuda_ext and self._quantized_data.is_cuda:
             # with a custom cuda kernel
             scales = _dequantize_scalers(scales, double_scale, scale_zeros, dtype).flatten()
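
A minimal sketch of the pad/unpad round trip these two hunks implement, assuming block_size=64 and scale_block_size=256; _pad_to_multiple is a hypothetical stand-in for modelopt's reduce_block_padding, written out only to illustrate the idea:

import torch

def _pad_to_multiple(scales: torch.Tensor, scale_block_size: int) -> torch.Tensor:
    """Right-pad a flattened scale tensor to a multiple of scale_block_size."""
    remainder = scales.numel() % scale_block_size
    if remainder == 0:
        return scales
    return torch.nn.functional.pad(scales, (0, scale_block_size - remainder))

# Quantize path: a 256x32 weight with block_size=64 yields 128 per-block
# scales, which must be padded up to the 256-wide scale blocks.
scales = torch.rand(128)
padded = _pad_to_multiple(scales, 256)
assert padded.numel() == 256

# Dequantize path: recover the true scale count from the packed data size,
# mirroring `scales.view(-1)[: (self._quantized_data.numel() * 2) // block_sizes[-1]]`.
packed_numel = (256 * 32) // 2          # two 4-bit NF4 codes per uint8 byte
num_scales = (packed_numel * 2) // 64   # elements // block_size
assert padded[:num_scales].numel() == 128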

tests/gpu/torch/quantization/test_qtensor_cuda.py

Lines changed: 6 additions & 3 deletions
@@ -30,7 +30,10 @@ class TestQTensor:
     )
     @pytest.mark.parametrize("device", ["cpu", "cuda"])
     @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
-    def test_qtensor(self, num_bits, block_sizes, device, input_dtype):
+    @pytest.mark.parametrize(
+        ("input_shape", "check_memory"), [((256, 64), True), ((256, 32), False)]
+    )
+    def test_qtensor(self, num_bits, block_sizes, device, input_dtype, input_shape, check_memory):
         nf4_attr_cfg = QuantizerAttributeConfig(
             num_bits=num_bits,
             block_sizes=block_sizes,
@@ -40,7 +43,7 @@ def test_qtensor(self, num_bits, block_sizes, device, input_dtype):

         # Original tensor
         base_mem = torch.cuda.memory_allocated("cuda")
-        x = torch.rand(256, 64).to(device).to(dtype=input_dtype)
+        x = torch.rand(input_shape).to(device).to(dtype=input_dtype)
         x_allocated = torch.cuda.memory_allocated("cuda")
         bf16_mem_usage = x_allocated - base_mem

@@ -51,7 +54,7 @@ def test_qtensor(self, num_bits, block_sizes, device, input_dtype):
         nf4_mem_usage = nf4_x_allocated - base_mem

         # Check the memory saving
-        if bf16_mem_usage > 0:
+        if bf16_mem_usage > 0 and check_memory:
             assert (nf4_mem_usage) / bf16_mem_usage < 0.3

         # De-quantize to origin dtype
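
Back-of-the-envelope arithmetic behind the < 0.3 threshold for the (256, 64) case, a sketch assuming NF4 packs two 4-bit codes per byte; the (256, 32) case presumably skips the check because scale padding inflates the relative footprint of small tensors:

elems = 256 * 64
bf16_bytes = elems * 2         # 2 bytes per bf16 element: 32768
nf4_bytes = elems // 2         # two 4-bit NF4 codes per byte: 8192
print(nf4_bytes / bf16_bytes)  # 0.25 before scales, leaving headroom under 0.3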

tests/gpu/torch/quantization/test_real_quantize_cuda.py

Lines changed: 0 additions & 2 deletions
@@ -47,7 +47,6 @@ def test_real_quantize(model_cls, config):
     config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {
         -1: 16,
         "scale_bits": 8,
-        "scale_block_sizes": {-1: 16},
     }
     if model_cls is SimpleConv or model_cls is SimpleConvLinear:
         pytest.skip(
@@ -102,7 +101,6 @@ def test_save_restore(model_cls, config):
     config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {
         -1: 16,
         "scale_bits": 8,
-        "scale_block_sizes": {-1: 16},
     }
     if model_cls is SimpleConv or model_cls is SimpleConvLinear:
         pytest.skip(
