
Commit f75b360

cthimeta-codesync[bot] authored and committed
FP4 grouped API for torch (#4958)
Summary:
Pull Request resolved: #4958

X-link: https://github.com/facebookresearch/FBGEMM/pull/1979

We upgrade the FP4 grouped kernel with a new API, `f4f4bf16_grouped_mm`, that can be used in torch and vLLM. The kernel supports both MX and NV FP4; the variant is determined by the scale dtype (E4M3 for NVFP4 vs. E8M0 for MXFP4). The API largely matches the existing one we added for MXFP8. We also add unit tests for the new APIs.

Next steps:
- Full re-tune of the kernel
- Add other layouts to better support backwards

Reviewed By: q10

Differential Revision: D83171662

fbshipit-source-id: ba0abe5e1adf151e1b98b19b0abb14c9325d7966
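(Editor's note, not part of the commit: the summary says the FP4 variant is inferred from the scale dtype. The sketch below codifies that rule; the helper name is hypothetical, and it assumes block scales arrive as PyTorch float8 tensors, where torch.float8_e8m0fnu requires a recent PyTorch build.)

import torch

# Hypothetical sketch (not from this commit): infer which FP4 recipe a
# grouped-GEMM call should take from the dtype of its block scales,
# per the summary's E4M3-vs-E8M0 rule.
def fp4_variant_from_scale(scale: torch.Tensor) -> str:
    if scale.dtype == torch.float8_e8m0fnu:  # power-of-two scales -> MXFP4
        return "mxfp4"
    if scale.dtype == torch.float8_e4m3fn:  # E4M3 scales -> NVFP4
        return "nvfp4"
    raise ValueError(f"not an FP4 scale dtype: {scale.dtype}")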
1 parent: 5beb3e6 · commit: f75b360

20 files changed: +964 -90 lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py

Lines changed: 50 additions & 0 deletions
@@ -5540,3 +5540,53 @@ def calculate_group_max(
         USE_INT64=use_int64,
     )
     return out, tensor_idx
+
+
+def get_nvfp4_global_scales_naive(
+    xs: list[torch.Tensor], ws: list[torch.Tensor]
+) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
+    """
+    Get global scales for each tensor in xs and ws.
+    This is done "naively" (in a Python loop, not fused into a kernel); it is intended for unit tests and debugging.
+    """
+    global_scales = []
+    x_global_scales = []
+    w_global_scales = []
+
+    for x, w in zip(xs, ws):
+        # pyre-ignore
+        x_global_scale: torch.Tensor = (448.0 * 6.0) / torch.amax(
+            torch.abs(x.flatten()), dim=-1
+        ).to(torch.float32)
+        # pyre-ignore
+        w_global_scale: torch.Tensor = (448.0 * 6.0) / torch.amax(
+            torch.abs(w.flatten()), dim=-1
+        ).to(torch.float32)
+        # pyre-ignore
+        global_scale: torch.Tensor = 1 / (x_global_scale * w_global_scale)
+
+        global_scales.append(global_scale)
+        x_global_scales.append(x_global_scale)
+        w_global_scales.append(w_global_scale)
+
+    return global_scales, x_global_scales, w_global_scales
+
+
+def quantize_nvfp4_naive(
+    xs: list[torch.Tensor], global_scales: list[torch.Tensor]
+) -> tuple[
+    list[torch.Tensor],
+    list[torch.Tensor],
+]:
+    """
+    Quantize each tensor in xs to NVFP4 format.
+    This is done "naively" (one quantization kernel launch per group); it is largely intended for unit tests and debugging.
+    """
+    xqs, x_scales = zip(
+        *(
+            triton_scale_nvfp4_quant(x, global_scale)
+            for x, global_scale in zip(xs, global_scales)
+        )
+    )
+
+    return xqs, x_scales
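(Editor's note on the helpers above, not part of the diff: the 448.0 * 6.0 numerator is the FP8 E4M3 max normal (448) times the FP4 E2M1 max (6), so each global scale stretches a tensor's amax to the largest value-times-block-scale product NVFP4 can represent, and global_scale = 1 / (x_global_scale * w_global_scale) undoes both rescalings on the GEMM output. A minimal round-trip sketch, assuming the module imports under the file's path and a CUDA device is available:)

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    get_nvfp4_global_scales_naive,
    quantize_nvfp4_naive,
)

# Two groups with arbitrary shapes; NVFP4 scales 16-element blocks along the
# last dim, so keep K a multiple of 16.
xs = [torch.randn(128, 256, device="cuda", dtype=torch.bfloat16) for _ in range(2)]
ws = [torch.randn(512, 256, device="cuda", dtype=torch.bfloat16) for _ in range(2)]

# Per-group global scales, plus the combined output rescale factors.
global_scales, x_gs, w_gs = get_nvfp4_global_scales_naive(xs, ws)

# One triton_scale_nvfp4_quant launch per group: packed FP4 data + block scales.
xqs, x_scales = quantize_nvfp4_naive(xs, x_gs)
wqs, w_scales = quantize_nvfp4_naive(ws, w_gs)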
