diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py
index f4f635af1b..09a1fd0b1e 100644
--- a/benchmarks/mx_formats/cast_bench.py
+++ b/benchmarks/mx_formats/cast_bench.py
@@ -257,12 +257,12 @@ def run(
     elif mode == "dim0_nvfp4":
         to_nvfp4_reference_c = torch.compile(to_nvfp4_reference)
-        y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False)
+        y_d0, s_d0 = to_nvfp4_reference_c(x)
 
         for _ in range(2):
-            __ = to_nvfp4_reference_c(x, use_triton_kernel=False)
+            __ = to_nvfp4_reference_c(x)
 
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False),
+            lambda x: to_nvfp4_reference_c(x),
             x,
         )
         assert y_d0.dtype == torch.uint8
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 173d99f746..b733d7fee6 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1441,6 +1441,8 @@ def quantize_nvfp4_triton_kernel(
     N,
     USE_TENSOR_SCALE: tl.constexpr,
     MASK_SCALES: tl.constexpr,
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
 ):
     F4_E2M1_MAX = 6.0
     F8E4M3_MAX = 448.0
@@ -1449,8 +1451,8 @@ def quantize_nvfp4_triton_kernel(
 
     pid_m = tl.program_id(1)
     pid_n = tl.program_id(0)
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :]
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N)
         other = 0.0
@@ -1460,10 +1462,10 @@ def quantize_nvfp4_triton_kernel(
     x = tl.load(
         x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
     )  # [128, 64]
-    x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+    x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16)  # [-1, 4, 16]
 
     # Compute block-wise scales
-    block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
+    block_amax = tl.max(x_blocks.abs(), axis=2)  # [-1, 4]
 
     if USE_TENSOR_SCALE:
         # Two-level scaling: quantize block scales with per-tensor scale
@@ -1513,9 +1515,13 @@ def quantize_nvfp4_triton_kernel(
     )
 
     # Convert to FP4
-    x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+    x_fp4x2 = convert_fp32_to_fp4_packed(
+        x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split()
+    )
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = (
+        pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :]
+    )
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N // 2)
     else:
@@ -1537,7 +1543,7 @@ def triton_quantize_nvfp4(
         Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout.
 
     Note:
-        Since VLLM does not use dyanmo guards we need to make this a custom op
+        Since VLLM does not use dynamo guards we need to make this a custom op
         to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
     # reshape to 2d
@@ -1571,6 +1577,8 @@ def triton_quantize_nvfp4(
         tensor_scale_ptr = per_tensor_scale
         use_tensor_scale = True
 
+    ROW_TILE_SIZE = 128
+    COL_TILE_SIZE = 64
     quantize_nvfp4_triton_kernel[grid](
         x,
         tensor_scale_ptr,
@@ -1582,6 +1590,8 @@ def triton_quantize_nvfp4(
         N,
         USE_TENSOR_SCALE=use_tensor_scale,
         MASK_SCALES=MASK_SCALES,
+        ROW_TILE_SIZE=ROW_TILE_SIZE,
+        COL_TILE_SIZE=COL_TILE_SIZE,
     )
 
     # reshape back to original shape