6 changes: 3 additions & 3 deletions benchmarks/mx_formats/cast_bench.py
@@ -257,12 +257,12 @@ def run(
 
     elif mode == "dim0_nvfp4":
         to_nvfp4_reference_c = torch.compile(to_nvfp4_reference)
-        y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False)
+        y_d0, s_d0 = to_nvfp4_reference_c(x)
 
         for _ in range(2):
-            __ = to_nvfp4_reference_c(x, use_triton_kernel=False)
+            __ = to_nvfp4_reference_c(x)
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False),
+            lambda x: to_nvfp4_reference_c(x),
             x,
         )
         assert y_d0.dtype == torch.uint8
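
For context on the benchmark loop above: the reference path is compiled once, called twice to warm it up so compilation time is excluded, and then timed with benchmark_cuda_function_in_microseconds. A generic CUDA-event timing helper of that shape might look like the sketch below; this is an illustrative stand-in, not torchao's actual helper.

import torch

def benchmark_cuda_microseconds(fn, *args, iters: int = 10) -> float:
    # Illustrative stand-in for a microsecond-level benchmark helper:
    # average the wall time of `iters` launches using CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # elapsed_time() is in ms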
26 changes: 18 additions & 8 deletions torchao/prototype/mx_formats/kernels.py
@@ -1441,6 +1441,8 @@ def quantize_nvfp4_triton_kernel(
     N,
     USE_TENSOR_SCALE: tl.constexpr,
     MASK_SCALES: tl.constexpr,
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
 ):
     F4_E2M1_MAX = 6.0
     F8E4M3_MAX = 448.0
@@ -1449,8 +1451,8 @@ def quantize_nvfp4_triton_kernel(
     pid_m = tl.program_id(1)
     pid_n = tl.program_id(0)
 
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :]
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N)
         other = 0.0
@@ -1460,10 +1462,10 @@
     x = tl.load(
         x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
     )  # [128, 64]
-    x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+    x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16)  # [-1, 4, 16]
 
     # Compute block-wise scales
-    block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
+    block_amax = tl.max(x_blocks.abs(), axis=2)  # [-1, 4]
 
     if USE_TENSOR_SCALE:
         # Two-level scaling: quantize block scales with per-tensor scale
@@ -1513,9 +1515,13 @@
     )
 
     # Convert to FP4
-    x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+    x_fp4x2 = convert_fp32_to_fp4_packed(
+        x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split()
+    )
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = (
+        pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :]
+    )
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N // 2)
     else:
@@ -1537,7 +1543,7 @@ def triton_quantize_nvfp4(
         Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout.
 
     Note:
-        Since VLLM does not use dyanmo guards we need to make this a custom op
+        Since VLLM does not use dynamo guards we need to make this a custom op
         to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
     # reshape to 2d
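
The Note in the docstring above is the motivation for keeping this path behind a custom op: vLLM does not run with dynamo guards, so a compiled caller could otherwise reuse a graph whose MASK_SCALES was specialized for a different input shape. As a rough sketch of the pattern only (the op name, signature, and divisibility check below are illustrative assumptions, not torchao's actual registration), a Triton-backed quantize function can be wrapped with torch.library.custom_op so the flag is recomputed inside the op on every call:

import torch

@torch.library.custom_op("demo::quantize_nvfp4", mutates_args=())
def quantize_nvfp4_demo(x: torch.Tensor) -> torch.Tensor:
    M, N = x.shape
    # Assumed policy for illustration: mask only when the tensor does not
    # tile evenly. Recomputed on every call, so a compiled caller cannot
    # bake in a stale value when shapes change.
    mask_scales = (M % 128 != 0) or (N % 64 != 0)
    out = torch.empty(M, N // 2, dtype=torch.uint8, device=x.device)
    # ... launch the Triton kernel here with MASK_SCALES=mask_scales ...
    return out

@quantize_nvfp4_demo.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    M, N = x.shape
    return torch.empty(M, N // 2, dtype=torch.uint8, device=x.device)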
@@ -1571,6 +1577,8 @@ def triton_quantize_nvfp4(
         tensor_scale_ptr = per_tensor_scale
         use_tensor_scale = True
 
+    ROW_TILE_SIZE = 128
+    COL_TILE_SIZE = 64
     quantize_nvfp4_triton_kernel[grid](
         x,
         tensor_scale_ptr,
@@ -1582,6 +1590,8 @@
         N,
         USE_TENSOR_SCALE=use_tensor_scale,
         MASK_SCALES=MASK_SCALES,
+        ROW_TILE_SIZE=ROW_TILE_SIZE,
+        COL_TILE_SIZE=COL_TILE_SIZE,
     )
 
     # reshape back to original shape
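
To make the offset arithmetic above concrete: each program instance quantizes a ROW_TILE_SIZE x COL_TILE_SIZE = 128 x 64 input tile, which splits into 4 blocks of 16 values per row (the [128, 4] block amaxes), and because two FP4 values pack into one byte, the packed output tile is only COL_TILE_SIZE // 2 columns wide, which is why the store offsets and the mask compare against N // 2. A minimal host-side sketch of that bookkeeping (plain Python for illustration; the divisibility-based MASK_SCALES choice is an assumption):

import math

def nvfp4_tile_bookkeeping(M: int, N: int, row_tile: int = 128, col_tile: int = 64):
    # Grid is (num_col_tiles, num_row_tiles): pid_n = program_id(0), pid_m = program_id(1).
    grid = (math.ceil(N / col_tile), math.ceil(M / row_tile))
    scales_per_tile_row = col_tile // 16   # 16-element blocks -> 4 scales per tile row
    packed_cols_per_tile = col_tile // 2   # two FP4 values per stored uint8
    needs_mask = (M % row_tile != 0) or (N % col_tile != 0)  # assumed MASK_SCALES rule
    return grid, scales_per_tile_row, packed_cols_per_tile, needs_mask

# Example: a (1000, 96) input has partial tiles on both axes.
print(nvfp4_tile_bookkeeping(1000, 96))  # ((2, 8), 4, 32, True)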