From f786bcdf77f729bf854abda453d78df1823e4559 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Wed, 23 Jul 2025 15:04:05 +0000 Subject: [PATCH] Raise ValueError when nvfp4 pack tensor has odd number of columns This raises an explicit error for a case that previously threw an implicit error during the reshape op at the end of the function. Signed-off-by: Fynn Schmitt-Ulms --- .../quantized_compressors/nvfp4_quantized.py | 9 +++++---- .../quantized_compressors/test_nvfp4_quant.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5f348e91..81b81578 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -121,6 +121,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: m, n = x.shape device = x.device + if n % 2 != 0: + raise ValueError( + "tensor must have an even number of columns for nvfp4 compression" + ) + # Create lookup table for FP4 values to indices # Map the absolute values to 0-7 indices kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype) @@ -137,10 +142,6 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: # Reshape to prepare for packing pairs of values indices = indices.reshape(-1) - # Handle odd length by padding if necessary - if indices.numel() % 2 != 0: - indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)]) - # Reshape to pair consecutive elements indices = indices.reshape(-1, 2) diff --git a/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py b/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py index b7e2c82d..496374a8 100644 --- a/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +++ b/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import ( pack_fp4_to_uint8, @@ -41,3 +42,16 @@ def test_pack_unpack(): sign_bitx = torch.signbit(x) sign_bitout = torch.signbit(unpacked) assert torch.equal(sign_bitout, sign_bitx) + + +def test_pack_unpack_odd_dims(): + x = torch.Tensor( + [ + [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000], + [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000], + [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000], + ] + ) + + with pytest.raises((ValueError, torch._dynamo.exc.Unsupported)): + _packed = pack_fp4_to_uint8(x)