@@ -123,6 +123,7 @@ def decompress_weight(
123
123
return decompressed_weight
124
124
125
125
126
+ @torch .compile (fullgraph = True , dynamic = True )
126
127
def pack_fp4_to_uint8 (x : torch .Tensor ) -> torch .Tensor :
127
128
"""
128
129
Packs a tensor with values in the fp4 range into uint8.
@@ -145,12 +146,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
145
146
146
147
# Find closest valid FP4 value index for each element
147
148
abs_x = torch .abs (x )
148
- abs_indices = torch .zeros_like (abs_x , dtype = torch .long )
149
- for i , val in enumerate (kE2M1 ):
150
- abs_indices = torch .where (torch .isclose (abs_x , val ), i , abs_indices )
149
+ abs_diff_x = torch .abs (abs_x .unsqueeze (- 1 ) - kE2M1 ) # [m, n, 8]
150
+ abs_indices = torch .argmin (abs_diff_x , dim = - 1 ) # [m, n]
151
151
152
152
# Apply sign bit (bit 3) to get final 4-bit representation
153
- indices = abs_indices + (torch .signbit (x ) << 3 ) .to (torch .long )
153
+ indices = abs_indices + (torch .signbit (x ).to (torch .long ) << 3 )
154
154
155
155
# Reshape to prepare for packing pairs of values
156
156
indices = indices .reshape (- 1 )
@@ -174,6 +174,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
174
174
175
175
176
176
# reference: https://github.com/vllm-project/vllm/pull/16362
177
+ @torch .compile (fullgraph = True , dynamic = True )
177
178
def unpack_fp4_from_uint8 (
178
179
a : torch .Tensor , m : int , n : int , dtype : Optional [torch .dtype ] = torch .bfloat16
179
180
) -> torch .Tensor :
0 commit comments