@@ -105,6 +105,7 @@ def decompress_weight(
105
105
return decompressed_weight
106
106
107
107
108
+ @torch .compile (fullgraph = True )
108
109
def pack_fp4_to_uint8 (x : torch .Tensor ) -> torch .Tensor :
109
110
"""
110
111
Packs a tensor with values in the fp4 range into uint8.
@@ -127,12 +128,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
127
128
128
129
# Find closest valid FP4 value index for each element
129
130
abs_x = torch .abs (x )
130
- abs_indices = torch .zeros_like (abs_x , dtype = torch .long )
131
- for i , val in enumerate (kE2M1 ):
132
- abs_indices = torch .where (torch .isclose (abs_x , val ), i , abs_indices )
131
+ abs_diff_x = torch .abs (abs_x .unsqueeze (- 1 ) - kE2M1 ) # [m, n, 8]
132
+ abs_indices = torch .argmin (abs_diff_x , dim = - 1 ) # [m, n]
133
133
134
134
# Apply sign bit (bit 3) to get final 4-bit representation
135
- indices = abs_indices + (torch .signbit (x ) << 3 ) .to (torch .long )
135
+ indices = abs_indices + (torch .signbit (x ).to (torch .long ) << 3 )
136
136
137
137
# Reshape to prepare for packing pairs of values
138
138
indices = indices .reshape (- 1 )
@@ -155,6 +155,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
155
155
)
156
156
157
157
# reference: https://github.com/vllm-project/vllm/pull/16362
158
+ @torch .compile (fullgraph = True )
158
159
def unpack_fp4_from_uint8 (
159
160
a : torch .Tensor , m : int , n : int , dtype : Optional [torch .dtype ] = torch .bfloat16
160
161
) -> torch .Tensor :
0 commit comments