Skip to content

Commit c2ffbac

Browse files
committed
Speed up nvfp4 pack/unpack w/ torch.compile
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 09b7ed4 commit c2ffbac

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def decompress_weight(
105105
return decompressed_weight
106106

107107

108+
@torch.compile(fullgraph=True)
108109
def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
109110
"""
110111
Packs a tensor with values in the fp4 range into uint8.
@@ -127,12 +128,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
127128

128129
# Find closest valid FP4 value index for each element
129130
abs_x = torch.abs(x)
130-
abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
131-
for i, val in enumerate(kE2M1):
132-
abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)
131+
abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8]
132+
abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n]
133133

134134
# Apply sign bit (bit 3) to get final 4-bit representation
135-
indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)
135+
indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)
136136

137137
# Reshape to prepare for packing pairs of values
138138
indices = indices.reshape(-1)
@@ -155,6 +155,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
155155
)
156156

157157
# reference: : https://github.com/vllm-project/vllm/pull/16362
158+
@torch.compile(fullgraph=True)
158159
def unpack_fp4_from_uint8(
159160
a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
160161
) -> torch.Tensor:

0 commit comments

Comments
 (0)