diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5f348e91..419d47c6 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -71,7 +71,6 @@ def compress_weight( zero_point: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: - quantized_weight = quantize( x=weight, scale=scale, @@ -91,7 +90,6 @@ def decompress_weight( compressed_data: Dict[str, Tensor], quantization_args: Optional[QuantizationArgs] = None, ) -> torch.Tensor: - weight = compressed_data["weight_packed"] scale = compressed_data["weight_scale"] global_scale = compressed_data["weight_global_scale"] @@ -105,6 +103,7 @@ def decompress_weight( return decompressed_weight +@torch.compile(fullgraph=True, dynamic=True) def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: """ Packs a tensor with values in the fp4 range into uint8. @@ -127,12 +126,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: # Find closest valid FP4 value index for each element abs_x = torch.abs(x) - abs_indices = torch.zeros_like(abs_x, dtype=torch.long) - for i, val in enumerate(kE2M1): - abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices) + abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8] + abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n] # Apply sign bit (bit 3) to get final 4-bit representation - indices = abs_indices + (torch.signbit(x) << 3).to(torch.long) + indices = abs_indices + (torch.signbit(x).to(torch.long) << 3) # Reshape to prepare for packing pairs of values indices = indices.reshape(-1) @@ -154,7 +152,9 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 ) + # reference: : https://github.com/vllm-project/vllm/pull/16362 +@torch.compile(fullgraph=True, dynamic=True) def unpack_fp4_from_uint8( a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 ) -> torch.Tensor: