Skip to content

Commit b8ddb63

Browse files
authored
Raise ValueError when nvfp4 pack tensor has odd number of columns (#402)
Raise an explicit ValueError when the tensor passed to pack_fp4_to_uint8 has an odd number of columns. Previously this case failed with an implicit error in the reshape op at the end of the function.
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 4370618 commit b8ddb63

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
140140
m, n = x.shape
141141
device = x.device
142142

143+
if n % 2 != 0:
144+
raise ValueError(
145+
"tensor must have an even number of columns for nvfp4 compression"
146+
)
147+
143148
# Create lookup table for FP4 values to indices
144149
# Map the absolute values to 0-7 indices
145150
kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)
@@ -155,10 +160,6 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
155160
# Reshape to prepare for packing pairs of values
156161
indices = indices.reshape(-1)
157162

158-
# Handle odd length by padding if necessary
159-
if indices.numel() % 2 != 0:
160-
indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])
161-
162163
# Reshape to pair consecutive elements
163164
indices = indices.reshape(-1, 2)
164165

tests/test_compressors/quantized_compressors/test_nvfp4_quant.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pytest
1516
import torch
1617
from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import (
1718
pack_fp4_to_uint8,
@@ -41,3 +42,16 @@ def test_pack_unpack():
4142
sign_bitx = torch.signbit(x)
4243
sign_bitout = torch.signbit(unpacked)
4344
assert torch.equal(sign_bitout, sign_bitx)
45+
46+
47+
def test_pack_unpack_odd_dims():
48+
x = torch.Tensor(
49+
[
50+
[-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000],
51+
[-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000],
52+
[1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000],
53+
]
54+
)
55+
56+
with pytest.raises((ValueError, torch._dynamo.exc.Unsupported)):
57+
_packed = pack_fp4_to_uint8(x)

0 commit comments

Comments
 (0)