
Commit bdf4598

NVFP4 -> Use more of e4m3 range for block_scales (#2604)
stack-info: PR: #2604, branch: drisspg/stack/85
1 parent 6501fb8 commit bdf4598

File tree: 1 file changed (+8 −7 lines)


torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -723,18 +723,19 @@ def nvfp4_addmm(func, types, args, kwargs):
 
 
 def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
-    """Convert per-tensor amax to per-tensor scale.
-    Used to scale fp32 scales down to fp8 scales
+    """Convert per-tensor amax to per-tensor scale for NVFP4 quantization.
+
+    Divides by both F8E4M3_MAX and F4_E2M1_MAX to ensure block scales can utilize
+    the full FP8 E4M3 range (up to 448) when block_max equals tensor_max.
+    Without F4_E2M1_MAX, the maximum scale would only reach FP8_MAX / FP4_MAX.
 
     Args:
-        amax: Per-tensor amax tensor
+        amax: Per-tensor absolute maximum value from calibration
 
     Returns:
-        torch.Tensor: Per-tensor scale tensor
+        torch.Tensor: Per-tensor scale for two-level NVFP4 scaling
     """
-    return torch.clamp(amax / F8E4M3_MAX, min=E4M3_EPS, max=F8E4M3_MAX).to(
-        torch.float32
-    )
+    return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)
 
 
 def nvfp4_quantize(
```
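For context, a minimal sketch of the two-level scaling arithmetic this change enables. The constant values are the standard format maxima (448 for FP8 E4M3, as stated in the docstring; 6 for FP4 E2M1); the block-scale computation below is an illustrative assumption about how the per-tensor scale is consumed downstream, not code lifted from nvfp4_tensor.py.

```python
import torch

# Standard format maxima (assumed to match torchao's constants of the same name):
F8E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
F4_E2M1_MAX = 6.0   # largest FP4 E2M1 value

def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
    # New behavior from this commit: divide by both maxima so the derived
    # block scale can reach F8E4M3_MAX when block_max equals tensor_max.
    return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)

# Worked example (illustrative, not from nvfp4_tensor.py): in the worst case
# a block holds the tensor-wide max, and its encoded scale now lands exactly
# at the top of the E4M3 range.
tensor_amax = torch.tensor(100.0)
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)

block_amax = tensor_amax  # worst case: one block contains the tensor max
block_scale = block_amax / (per_tensor_scale * F4_E2M1_MAX)
print(block_scale.item())  # 448.0 == F8E4M3_MAX, full E4M3 range used
```

Under the previous clamp-based formula, per_tensor_scale would have been amax / F8E4M3_MAX, so the same block scale would top out at F8E4M3_MAX / F4_E2M1_MAX ≈ 74.7, which is the limitation the docstring's last sentence describes.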
