diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5f348e91..81b81578 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -121,6 +121,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: m, n = x.shape device = x.device + if n % 2 != 0: + raise ValueError( + "tensor must have an even number of columns for nvfp4 compression" + ) + # Create lookup table for FP4 values to indices # Map the absolute values to 0-7 indices kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype) @@ -137,10 +142,6 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: # Reshape to prepare for packing pairs of values indices = indices.reshape(-1) - # Handle odd length by padding if necessary - if indices.numel() % 2 != 0: - indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)]) - # Reshape to pair consecutive elements indices = indices.reshape(-1, 2) diff --git a/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py b/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py index b7e2c82d..496374a8 100644 --- a/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +++ b/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import ( pack_fp4_to_uint8, @@ -41,3 +42,16 @@ def test_pack_unpack(): sign_bitx = torch.signbit(x) sign_bitout = torch.signbit(unpacked) assert torch.equal(sign_bitout, sign_bitx) + + +def test_pack_unpack_odd_dims(): + x = torch.Tensor( + [ + [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000], + [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000], + [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000], + ] + ) + + with pytest.raises((ValueError, torch._dynamo.exc.Unsupported)): + _packed = pack_fp4_to_uint8(x)