Skip to content

Commit 15177be

Browse files
authored
Speed up nvfp4 pack/unpack w/ torch.compile (#400)
* Speed up nvfp4 pack/unpack w/ torch.compile

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>

* Add `dynamic=True` to torch.compile call in nvfp4 packing

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>

---------

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 0679683 commit 15177be

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def decompress_weight(
123123
return decompressed_weight
124124

125125

126+
@torch.compile(fullgraph=True, dynamic=True)
126127
def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
127128
"""
128129
Packs a tensor with values in the fp4 range into uint8.
@@ -145,12 +146,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
145146

146147
# Find closest valid FP4 value index for each element
147148
abs_x = torch.abs(x)
148-
abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
149-
for i, val in enumerate(kE2M1):
150-
abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)
149+
abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8]
150+
abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n]
151151

152152
# Apply sign bit (bit 3) to get final 4-bit representation
153-
indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)
153+
indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)
154154

155155
# Reshape to prepare for packing pairs of values
156156
indices = indices.reshape(-1)
@@ -174,6 +174,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
174174

175175

176176
# reference: : https://github.com/vllm-project/vllm/pull/16362
177+
@torch.compile(fullgraph=True, dynamic=True)
177178
def unpack_fp4_from_uint8(
178179
a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
179180
) -> torch.Tensor:

0 commit comments

Comments (0)