From 821bd2b7985f26743ef7644a60e7380cb16e8c26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 07:41:27 -0700 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 22 ++++++++++-- torchao/testing/training/roofline_utils.py | 41 +++++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 4bf54538df..547b0a40e4 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -180,7 +180,7 @@ def get_gemm_times( scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) else: - assert False, "TODO add cutlass mx gemm here" + assert False, f"unsupported {float8_recipe_name=} {mx_recipe_name=}" def do_matmul(A, B): return torch._scaled_mm( @@ -233,6 +233,20 @@ def run( print(f"mx_recipe_name: {mx_recipe_name}") print(f"enable_fusion_modeling: {enable_fusion_modeling}") + assert mx_recipe_name in ( + # real mxfp8_cublas recipe + "mxfp8_cublas", + # real mxfp8_cublas_rceil recipe + "mxfp8_cublas_rceil", + # modeling of what mxfp8 with 32x32 block size and without gemm + # operand layout restrictions would look like + "mxfp8_32x32_flexible_gemm_layout", + # modeling of what mxfp8 with 32x32 block size for weight + "mxfp8_32x32_weight", + # real mxfp4_cutlass recipe + "mxfp4_cutlass", + ), f"unsupported {mx_recipe_name=}" + M, K, N = sympy.symbols("M K N") fp8_ovhd_time_sympy = get_float8_mem_sympy( @@ -309,7 +323,11 @@ def run( rb_fp8_gemm_ratio = -1 if do_benchmarks: - assert mx_recipe_name != "mxfp4_cutlass", "unsupported" + assert mx_recipe_name not in ( + "mxfp4_cutlass", + "mxfp8_32x32_flexible_gemm_layout", + "mxfp8_32x32_weight", + ), f"do_benchmarks unsupported with {mx_recipe_name=}" # TODO(future): make the bf16 gemm times exactly match the e2e # benchmarks, there is a slight deviation, probably related to gemm diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index f57705333a..6610654bf1 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -187,13 +187,52 @@ def get_tensor_memory_traffic_ovhd_s( else: assert False, "unsupported" + elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout": + # modeling the following: + # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense + # across dim0 and dim1 + # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in + # PyTorch right now) + # x_bf16 = ... + # kernel 1: x_bf16 -> x_mxfp8_dim0 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw] + + elif mx_recipe_name == "mxfp8_32x32_weight": + # modeling the following: + # 1. mxfp8 scaling with 32x32 weights, so the format makes sense + # across dim0 and dim1. input and grad_output still 1x32. 
+ + if tensor_role in ("input", "grad_output"): + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_bf16 -> x_mxfp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + + elif tensor_role == "weight": + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1 + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2 + + else: + assert False, "unsupported" + + res_bytes = [kernel_1_rw, kernel_2_rw] + else: assert mx_recipe_name in ( "mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil", "mxfp4_cutlass", - ), "unsupported" + ), f"unsupported {mx_recipe_name=}" # For now, assume that we can't profitably fuse kernel 1 and kernel 2 # x_bf16 = ... # kernel 1: x_bf16 -> x_mxfp8_dim0 From 5bd4e3b4ff6617d6bb7eec8b13f6be99b1aeb40d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 13:32:59 -0700 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- torchao/testing/training/roofline_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index 6610654bf1..e391a4d44b 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -207,6 +207,7 @@ def get_tensor_memory_traffic_ovhd_s( # across dim0 and dim1. input and grad_output still 1x32. if tensor_role in ("input", "grad_output"): + # TODO(future): update all of the mx rooflines to just read once # kernel 1: x_bf16 -> x_mxfp8_dim0 # kernel 2: x_bf16 -> x_mxfp8_dim1 if fuse_with_prev: From ea2d54f578ef0fb39d0556699429598419ce8927 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 14:09:19 -0700 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 106 +++++++++++++----- 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index fbfead161a..6c8113e8cb 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -38,6 +38,14 @@ ) import torchao +from torchao.prototype.mx_formats.config import ( + MXGemmKernelChoice, +) +from torchao.prototype.mx_formats.inference_workflow import ( + MXFPInferenceConfig, + NVFP4InferenceConfig, + NVFP4MMConfig, +) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, PerRow, @@ -80,40 +88,67 @@ def get_gemm_times( fast_accum: bool, recipe_name: Optional[str], ): - assert recipe_name in {"rowwise"}, ( - "Only support real benchmarks for 'rowwise' recipe for now" - ) device = torch.device("cuda") # bf16 time x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) - # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t() w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16) - e4m3_dtype = torch.float8_e4m3fn - if torch.version.hip and torch.cuda.is_available() and is_MI300(): - e4m3_dtype = torch.float8_e4m3fnuz - d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 - A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) - B = ( - torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) - .view(d2) - .t() - .contiguous() - .t() - ) + if 
recipe_name in ("mxfp4_cutlass", "nvfp4"): + d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16 + A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view( + d1 + ) + B = ( + torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8) + .t() + .contiguous() + .t() + .view(d2) + ) + else: + e4m3_dtype = torch.float8_e4m3fn + if torch.version.hip and torch.cuda.is_available() and is_MI300(): + e4m3_dtype = torch.float8_e4m3fnuz + d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 + A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) + B = ( + torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) + .view(d2) + .t() + .contiguous() + .t() + ) + if recipe_name == "rowwise": scale_a = torch.ones(M, 1, device=device) scale_b = torch.ones(1, N, device=device) + elif recipe_name == "mxfp8_cublas": + scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) + scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) + elif recipe_name == "mxfp4_cutlass": + scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) + scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) + elif recipe_name == "nvfp4": + scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn) + scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn) + else: assert False, "unsupported" def do_matmul(A, B): - return torch._scaled_mm( - A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum - ) + if recipe_name == "mxfp4_cutlass": + return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b) + if recipe_name == "nvfp4": + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False + ) + else: + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum + ) f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) @@ -259,12 +294,33 @@ def run( # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() - config = Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - # for now, use TORCH. In the future might be interesting - # to benchmark AUTO and FBGEMM. - kernel_preference=KernelPreference.TORCH, - ) + if recipe_name == "rowwise": + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + # for now, use TORCH. In the future might be interesting + # to benchmark AUTO and FBGEMM. 
+ kernel_preference=KernelPreference.TORCH, + ) + elif recipe_name == "mxfp8_cublas": + config = MXFPInferenceConfig( + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + ) + elif recipe_name == "mxfp4_cutlass": + config = MXFPInferenceConfig( + activation_dtype=torch.float4_e2m1fn_x2, + weight_dtype=torch.float4_e2m1fn_x2, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + ) + elif recipe_name == "nvfp4": + config = NVFP4InferenceConfig( + mm_config=NVFP4MMConfig.DYNAMIC, + use_dynamic_per_tensor_scale=False, + ) + else: + assert False, "unsupported" + m_fp8_dyn = copy.deepcopy(m_orig) quantize_(m_fp8_dyn, config) From b88850f0d83a7cac38b83868da00ddfaf2f9ab26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 17:44:34 -0700 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 6c8113e8cb..3365fba923 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -60,7 +60,7 @@ @torch.no_grad() -def get_gpu_kernel_time(m, x): +def get_gpu_kernel_time(m, x, trace_filename=None): # warm up for _ in range(2): __ = m(x) @@ -72,6 +72,12 @@ def get_gpu_kernel_time(m, x): for _ in range(n_iter): __ = m(x) torch.cuda.synchronize() + + # save a trace, if requested + if trace_filename is not None: + print(f"exporting trace to {trace_filename}") + prof.export_chrome_trace(trace_filename) + # get the gpu kernel time and aggregate it num_leaf_tensors = 1 + len(list(m.parameters())) ref_times = profiler_output_to_filtered_time_by_kernel_name( @@ -161,6 +167,7 @@ def run( do_benchmarks: bool = True, shape_gen_name: str = "pow2", n_limit: Optional[int] = None, + save_profile_traces: bool = False, ): """ Args: @@ -168,6 +175,7 @@ def run( * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep` * `n_limit (optional)`: if specified, only runs `n_limit` iterations + # `save_profile_traces (optional)`: if True, saves profiling traces """ config_table = [ ["GPU", torch.cuda.get_device_name(0)], @@ -289,7 +297,11 @@ def run( # get the bf16 gpu kernel time torch._dynamo.reset() m_bf16 = torch.compile(copy.deepcopy(m_orig)) - b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x) + + bf16_trace_filename = None + if save_profile_traces: + bf16_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_bf16.json" + b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, bf16_trace_filename) # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() @@ -325,7 +337,11 @@ def run( quantize_(m_fp8_dyn, config) m_fp8_dyn = torch.compile(m_fp8_dyn) - b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x) + + fp8_trace_filename = None + if save_profile_traces: + fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json" + b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename) r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) From 07749746ad0700c590d6b2f491b343e79218bcb5 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 18:51:11 -0700 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 54 +++++++++++++++++++++ torchao/prototype/mx_formats/kernels.py | 4 +- 
torchao/prototype/mx_formats/mx_tensor.py | 5 +- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 577112b16a..1a3631cc53 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -662,3 +662,57 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): rtol=0.0, msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}", ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") +@pytest.mark.parametrize("transpose", [False, True]) +@pytest.mark.parametrize( + "shape", + ( + (128, 64), + (1, 128, 64), + ), +) +def test_scale_shape_matches_qdata(transpose, shape): + if len(shape) == 3 and transpose: + pytest.skip("transpose not yet implemented for 3D MXTensor") + + block_size = 32 + + x_hp = torch.randn(*shape, device="cuda") + x = MXTensor.to_mx( + x_hp, + torch.float8_e4m3fn, + block_size, + ScaleCalculationMode.FLOOR, + ) + + if len(shape) == 2: + m_dim, k_dim = 0, 1 + if transpose: + x_hp = x_hp.t() + x = x.t() + m_dim, k_dim = 1, 0 + else: + assert len(shape) == 3, "unsupported" + m_dim, k_dim = 1, 2 + if transpose: + x_hp = x_hp.transpose(-2, -1) + x = x.transpose(-2, -1) + m_dim, k_dim = 2, 1 + + orig_m = x_hp.shape[m_dim] + expected_padded_m = orig_m + actual_padded_m = x.scale.shape[m_dim] + assert expected_padded_m == actual_padded_m, ( + f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}" + ) + + orig_k = x_hp.shape[k_dim] + expected_padded_k = orig_k // block_size + actual_padded_k = x.scale.shape[k_dim] + + assert expected_padded_k == actual_padded_k, ( + f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}" + ) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 69bb076b40..c69da4d076 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1264,7 +1264,7 @@ def triton_to_mxfp8_dim1( return ( output_col_major.t(), - col_scale.view(torch.float8_e8m0fnu), + col_scale.view(torch.float8_e8m0fnu).squeeze(-1), ) @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default) @@ -1293,7 +1293,7 @@ def triton_to_mxfp8_dim1_reference( scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu) return ( x_hp_d1_normalized.t(), - scale_e8m0_dim1.unsqueeze(-1), + scale_e8m0_dim1, ) @triton.jit diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 05c8fdc8e4..a5e50b2468 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -362,6 +362,7 @@ def to_dtype( # unpacking and unscaling if is_transposed: data_lp = data_lp.t() + scale_e8m0 = scale_e8m0.t() assert data_lp.is_contiguous() orig_shape = (orig_shape[1], orig_shape[0]) @@ -688,7 +689,7 @@ def _addmm_mx_dispatch( assert b._block_size == 32, f"Invalid block size {b._block_size}" a_scale = a.scale.view(M, K // a._block_size) - b_scale = b.scale.view(N, K // b._block_size) + b_scale = b.scale.t().view(N, K // b._block_size) a_scale_block = to_blocked(a_scale) b_scale_block = to_blocked(b_scale) @@ -759,7 +760,7 @@ def mx_t(func, types, args, kwargs): old = args[0] new = MXTensor( old.qdata.t(), - 
old.scale, + old.scale.t(), old._elem_dtype, old._block_size, old._orig_dtype, From 00c007cac41486c85d1dde1394c0a84672dbcfa4 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 17 Oct 2025 08:06:29 -0700 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 71 +++++- .../mx_formats/inference_workflow.py | 2 + torchao/prototype/mx_formats/mx_tensor.py | 130 +++++----- torchao/prototype/mx_formats/nvfp4_tensor.py | 172 +------------ torchao/prototype/mx_formats/utils.py | 226 ++++++++++++++++++ 5 files changed, 361 insertions(+), 240 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 1a3631cc53..0f22f2f8ae 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -5,6 +5,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import math + import pytest import torch from torch._inductor.utils import run_and_get_code @@ -22,6 +24,7 @@ ScaleCalculationMode, to_dtype, ) +from torchao.prototype.mx_formats.utils import from_blocked, to_blocked from torchao.quantization.utils import compute_error from torchao.utils import ( is_sm_at_least_89, @@ -388,6 +391,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): MXGemmKernelChoice.EMULATED, pack_fp6, None, + False, ) tensor_hp = tensor_mx.dequantize(torch.float) assert torch.all(torch.isnan(tensor_hp.flatten()[0:4])) @@ -645,8 +649,6 @@ def to_f8(x): not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" ) def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): - from torchao.prototype.mx_formats.utils import from_blocked, to_blocked - rows, cols = shape device = "cuda" if torch.cuda.is_available() else "cpu" @@ -716,3 +718,68 @@ def test_scale_shape_matches_qdata(transpose, shape): assert expected_padded_k == actual_padded_k, ( f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}" ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") +@pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2)) +@pytest.mark.parametrize("transpose", [False, True]) +@pytest.mark.parametrize( + "shape", + ( + (128, 64), + (1, 128, 64), + ), +) +def test_swizzle(elem_dtype, transpose, shape): + if len(shape) == 3 and transpose: + pytest.skip("transpose not yet implemented for 3D MXTensor") + + block_size = 32 + + x_hp = torch.randn(*shape, device="cuda") + x = MXTensor.to_mx( + x_hp, + elem_dtype, + block_size, + ScaleCalculationMode.FLOOR, + ) + + xs = MXTensor.to_mx( + x_hp, + elem_dtype, + block_size, + ScaleCalculationMode.FLOOR, + is_swizzled_scales=True, + ) + + if transpose: + x = x.t() + xs = xs.t() + + torch.testing.assert_close(x.qdata, xs.qdata, atol=0, rtol=0) + + if transpose: + leading_dims, M, K = x.shape[:-2], x.shape[-1], x.shape[-2] + xs_scale_unblocked = from_blocked( + xs.scale.t(), math.prod(leading_dims) * M, K // block_size + ) + xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size) + xs_scale_unblocked = xs_scale_unblocked.t() + else: + leading_dims, M, K = x.shape[:-2], x.shape[-2], x.shape[-1] + xs_scale_unblocked = from_blocked( + xs.scale, math.prod(leading_dims) * M, K // block_size + ) + xs_scale_unblocked = 
xs_scale_unblocked.view(*leading_dims, M, K // block_size) + + torch.testing.assert_close( + x.scale, + xs_scale_unblocked, + atol=0, + rtol=0, + ) + + x_dq = x.dequantize(x.dtype) + xs_dq = xs.dequantize(xs.dtype) + torch.testing.assert_close(x_dq, xs_dq, atol=0, rtol=0) diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 0cbe6e995f..8725c33b44 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -111,6 +111,7 @@ def _mx_inference_linear_transform( block_size=config.block_size, gemm_kernel_choice=config.gemm_kernel_choice, pack_fp6=False, + is_swizzled_scales=True, ) # Convert weight to MX Tensor @@ -121,6 +122,7 @@ def _mx_inference_linear_transform( gemm_kernel_choice=config.gemm_kernel_choice, pack_fp6=False, # TODO act_quant_kwargs=act_quant_kwargs, + is_swizzled_scales=True, ) module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index a5e50b2468..3ad7d5f268 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -17,6 +17,7 @@ * Zeros: N/A """ +import math from dataclasses import dataclass from typing import Optional, Union @@ -62,7 +63,12 @@ triton_f6_e3m2_to_scaled_bf16, unpack_uint4, ) -from torchao.prototype.mx_formats.utils import to_blocked +from torchao.prototype.mx_formats.utils import ( + _swizzle_aware_slice, + from_blocked, + hp_data_dims_to_swizzled_scale_dims_mx, + to_blocked, +) from torchao.quantization.quantize_.common import ( QuantizeTensorKwargs, ) @@ -86,6 +92,7 @@ class QuantizeTensorToMXKwargs(QuantizeTensorKwargs): scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED pack_fp6: bool = False + is_swizzled_scales: bool = False def _to_mx_rceil( @@ -142,6 +149,7 @@ def to_mx( block_size: int, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, pack_fp6: bool = False, + is_swizzled_scales: bool = False, ): """ Takes a high precision tensor and converts to MX scale and raw data, in @@ -321,13 +329,21 @@ def to_mx( # approach for fp4x2 in any case data_lp = data_lp.reshape(orig_shape) data_lp = f32_to_f4_unpacked(data_lp) - orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2] data_lp = pack_uint4(data_lp) else: raise AssertionError("unsupported") scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) scale_e8m0_biased = scale_e8m0_biased.squeeze(-1) + + # if user requested scale swizzling, do it here + if is_swizzled_scales: + leading_dims, M, K = orig_shape[:-2], orig_shape[-2], orig_shape[-1] + scale_shape = (math.prod(leading_dims) * M, K // block_size) + scale = to_blocked(scale_e8m0_biased.view(scale_shape)).flatten() + scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_mx(M, K) + scale_e8m0_biased = scale.view(*leading_dims, scale_M, scale_K) + return scale_e8m0_biased, data_lp @@ -479,6 +495,7 @@ class MXTensor(TorchAOBaseTensor): "_gemm_kernel_choice", "_pack_fp6", "act_quant_kwargs", + "_is_swizzled_scales", ] def __new__( @@ -491,6 +508,7 @@ def __new__( gemm_kernel_choice, pack_fp6, act_quant_kwargs, + is_swizzled_scales, ): new_size = qdata.size() if elem_dtype == torch.float4_e2m1fn_x2: @@ -526,31 +544,6 @@ def __new__( torch.float8_e5m2, torch.uint8, ), "unsupported" - if elem_dtype in ( - torch.float8_e4m3fn, - torch.float8_e5m2, - ): - 
target_numel = scale_e8m0_bits.numel() * block_size - elif elem_dtype == torch.float4_e2m1fn_x2: - assert qdata.dtype is torch.uint8 # fp4 - target_numel = scale_e8m0_bits.numel() * block_size / 2 - elif elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: - assert qdata.dtype is torch.uint8 # fp4 - target_numel = scale_e8m0_bits.numel() * block_size - if pack_fp6: - target_numel = 3 * target_numel // 4 - else: - raise AssertionError("unsupported") - if not issubclass( - torch._subclasses.fake_tensor.FakeTensor, - type(qdata), - ): - # this check is sometimes broken for FakeTensor - # TODO investigate - assert target_numel == qdata.numel(), f"{target_numel} != {qdata.numel()}" - - # `scale` has rank 1 and applies to a row-major memory layout of - # `qdata` self.qdata = qdata self.scale = scale_e8m0_bits self._elem_dtype = elem_dtype @@ -559,11 +552,12 @@ def __new__( self._gemm_kernel_choice = gemm_kernel_choice self._pack_fp6 = pack_fp6 self.act_quant_kwargs = act_quant_kwargs + self._is_swizzled_scales = is_swizzled_scales return self def __repr__(self): # TODO better elem dtype print for fp4 - return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501 + return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}, _is_swizzled_scales={self._is_swizzled_scales}" # noqa: E501 def _quantization_type(self): return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" @@ -571,9 +565,25 @@ def _quantization_type(self): def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: output_dtype = self.dtype + + scale = self.scale + if self._is_swizzled_scales: + is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1) + if is_transposed: + leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2] + scale = scale.transpose(-2, -1) + else: + leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1] + scale = from_blocked( + scale, math.prod(leading_dims) * M, K // self._block_size + ) + scale = scale.view(*leading_dims, M, K // self._block_size) + if is_transposed: + scale = scale.transpose(-2, -1) + return to_dtype( self.qdata, - self.scale, + scale, self._elem_dtype, self._block_size, output_dtype, @@ -591,9 +601,10 @@ def to_mx( gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, pack_fp6: bool = False, act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None, + is_swizzled_scales: bool = False, ): scale_e8m0_biased, data_lp = to_mx( - data_hp, elem_dtype, block_size, scaling_mode, pack_fp6 + data_hp, elem_dtype, block_size, scaling_mode, pack_fp6, is_swizzled_scales ) if isinstance(scale_e8m0_biased, DTensor): assert isinstance(data_lp, DTensor), "unsupported" @@ -608,6 +619,7 @@ def to_mx( gemm_kernel_choice, pack_fp6, act_quant_kwargs, + is_swizzled_scales, ) return DTensor.from_local( inner_mx_tensor, @@ -626,6 +638,7 @@ def to_mx( gemm_kernel_choice, pack_fp6, act_quant_kwargs, + is_swizzled_scales, ) # Do not force the MXTensor type on the returned tensor @@ -676,6 +689,7 @@ def _addmm_mx_dispatch( k.scaling_mode, k.gemm_kernel_choice, k.pack_fp6, + k.is_swizzled_scales, ) gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice) @@ -688,10 +702,17 @@ def _addmm_mx_dispatch( assert a._block_size == 32, f"Invalid block size {a._block_size}" assert b._block_size == 
32, f"Invalid block size {b._block_size}" - a_scale = a.scale.view(M, K // a._block_size) - b_scale = b.scale.t().view(N, K // b._block_size) - a_scale_block = to_blocked(a_scale) - b_scale_block = to_blocked(b_scale) + if a._is_swizzled_scales: + a_scale_block = a.scale + else: + a_scale = a.scale.view(M, K // a._block_size) + a_scale_block = to_blocked(a_scale) + + if b._is_swizzled_scales: + b_scale_block = b.scale.t() + else: + b_scale = b.scale.t().view(N, K // b._block_size) + b_scale_block = to_blocked(b_scale) if a._elem_dtype == torch.float8_e4m3fn: assert b._elem_dtype == torch.float8_e4m3fn @@ -767,6 +788,7 @@ def mx_t(func, types, args, kwargs): old._gemm_kernel_choice, old._pack_fp6, old.act_quant_kwargs, + old._is_swizzled_scales, ) return new @@ -811,6 +833,7 @@ def mx_view_op(func, types, args, kwargs): args[0]._gemm_kernel_choice, args[0]._pack_fp6, args[0].act_quant_kwargs, + args[0]._is_swizzled_scales, ) @@ -821,41 +844,7 @@ def mx_slice(func, types, args, kwargs): if step != 1: raise ValueError("Only support aten.slice with step=1") - M, K = x.shape[0], x.shape[1] - - # TODO why doesn't scale have shape? - scale_shaped = x.scale.view(M, K // x._block_size) - - if dim == 0: - # Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now - sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) - sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step).unsqueeze(-1) - elif dim == 1: - # Slicing along reduciton dim - if start is not None: - # Assert start is a multiple of block_size - assert start % x._block_size == 0, ( - f"Start index {start} must be a multiple of block_size {x._block_size}" - ) - - if end is not None: - # Assert end is a multiple of block_size - assert end % x._block_size == 0, ( - f"End index {end} must be a multiple of block_size {x._block_size}" - ) - - sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step) - - # Calculate which scale elements to keep - start_block = 0 if start is None else start // x._block_size - end_block = -1 if end is None else end // x._block_size - - # Slice the scale tensor accordingly - sliced_scale = aten.slice.Tensor(scale_shaped, 1, start_block, end_block, step) - else: - raise ValueError( - f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}" - ) + sliced_data, sliced_scale = _swizzle_aware_slice(x, dim, start, end, step) return return_and_correct_aliasing( func, @@ -870,6 +859,7 @@ def mx_slice(func, types, args, kwargs): x._gemm_kernel_choice, x._pack_fp6, x.act_quant_kwargs, + x._is_swizzled_scales, ), ) @@ -894,6 +884,7 @@ def mx_select(func, types, args, kwargs): assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor.scale.shape), ( "unsupported" ) + assert not old_mx_tensor._is_swizzled_scales, "unsupported" new_mx_tensor = old_mx_tensor.__class__( old_mx_tensor.qdata[index], old_mx_tensor.scale[index], @@ -903,5 +894,6 @@ def mx_select(func, types, args, kwargs): old_mx_tensor._gemm_kernel_choice, old_mx_tensor._pack_fp6, old_mx_tensor.act_quant_kwargs, + old_mx_tensor._is_swizzled_scales, ) return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor) diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index cf1b971744..2397270d5e 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. 
import math -import sys from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional @@ -26,6 +25,7 @@ tensor_size_hp_to_fp4x2, ) from torchao.prototype.mx_formats.utils import ( + _swizzle_aware_slice, from_blocked, hp_data_dims_to_swizzled_scale_dims_nvfp4, to_blocked, @@ -33,7 +33,7 @@ from torchao.quantization.quantize_.common import ( QuantizeTensorKwargs, ) -from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults +from torchao.utils import TorchAOBaseTensor, fill_defaults E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny @@ -393,173 +393,7 @@ def nvfp4_slice(func, types, args, kwargs): f"only rank 2 is supported for slice, got rank {len(x.shape)}" ) - M, K = x.shape[0], x.shape[1] - - # The scale manipulations below assume a flattened scale. For now, we - # flatten the scale, go through the calculations below, and then reshape - # it back to the format which matches the shape of `qdata`. - # TODO(future PR): update this - - if x._is_swizzled_scales: - scale_rows = M - scale_cols = K // x._block_size - n_row_blocks = ceil_div(scale_rows, 128) - n_col_blocks = ceil_div(scale_cols, 4) - elements_per_block = 32 * 16 # 512 elements - - if dim == 0: - # Row slicing - # Handle sys.maxsize (default slice end) - if end == sys.maxsize: - end = M - - # Check if start/end align with 128-row boundaries - if start is not None and start % 128 != 0: - raise RuntimeError( - f"Row slicing of NVFP4Tensor with swizzled scales requires " - f"start index to be a multiple of 128, got {start}" - ) - if end is not None and end != M and end % 128 != 0: - raise RuntimeError( - f"Row slicing of NVFP4Tensor with swizzled scales requires " - f"end index to be a multiple of 128 or equal to tensor size {M}, got {end}" - ) - - # Calculate which row blocks to keep - start_block = 0 if start is None else start // 128 - end_block = n_row_blocks if end is None or end >= M else end // 128 - - # The swizzled tensor has shape (n_row_blocks * n_col_blocks * 32 * 16,) - blocks_per_row = n_col_blocks - start_idx = start_block * blocks_per_row * elements_per_block - end_idx = ( - end_block * blocks_per_row * elements_per_block - if end_block < n_row_blocks - else None - ) - - sliced_scale = aten.slice.Tensor( - x.scale.flatten(), 0, start_idx, end_idx, 1 - ) - sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step) - - elif dim == 1: - # Column slicing - # Handle sys.maxsize (default slice end) - if end == sys.maxsize: - end = K - - # Check if start/end align with 64-column boundaries (4 scale columns * 16 block_size) - if start is not None and start % 64 != 0: - raise RuntimeError( - f"Column slicing of NVFP4Tensor with swizzled scales requires " - f"start index to be a multiple of 64, got {start}" - ) - if end is not None and end != K and end % 64 != 0: - raise RuntimeError( - f"Column slicing of NVFP4Tensor with swizzled scales requires " - f"end index to be a multiple of 64 or equal to tensor size {K}, got {end}" - ) - - # Also check FP4 packing alignment - if start is not None and start % 2 != 0: - raise RuntimeError(f"Start index {start} must be even for FP4 packing") - if end is not None and end != K and end % 2 != 0: - raise RuntimeError(f"End index {end} must be even for FP4 packing") - - # Calculate which column blocks to keep - start_scale_col = 0 if start is None else start // 16 - end_scale_col = scale_cols if end is None or end >= K else end // 16 - - start_col_block = start_scale_col // 4 - end_col_block = end_scale_col // 4 - - # Verify the end aligns with 
block boundary - if end_scale_col % 4 != 0: - raise RuntimeError( - f"Column slicing end index {end} does not align with scale block boundaries. " - f"End must result in a multiple of 4 scale columns (64 data columns)." - ) - - if start_col_block == 0 and end_col_block == n_col_blocks: - # Full width - no slicing needed - sliced_scale = x.scale - else: - # Extract specific column blocks from each row block - # Each row block in swizzled format contains n_col_blocks chunks of (32, 16) - elements_per_row_block = n_col_blocks * elements_per_block - - # Build list of slices to extract - slices_to_extract = [] - for row_block in range(n_row_blocks): - row_start = row_block * elements_per_row_block - col_start = row_start + start_col_block * elements_per_block - col_end = row_start + end_col_block * elements_per_block - slices_to_extract.append(x.scale.flatten()[col_start:col_end]) - - # Concatenate all the slices - sliced_scale = torch.cat(slices_to_extract, dim=0) - - # Slice the data tensor - packed_start = None if start is None else start // 2 - packed_end = None if end is None else end // 2 - sliced_data = aten.slice.Tensor( - x.qdata, dim, packed_start, packed_end, step - ) - - else: - raise ValueError( - f"NVFP4Tensor only supports slicing along dimensions 0 and 1, got dim={dim}" - ) - - else: - scale_shaped = x.scale.view(M, K // x._block_size) - - if dim == 0: - sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) - sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step) - - elif dim == 1: - if start is not None: - assert start % x._block_size == 0, ( - f"Start index {start} must be a multiple of block_size {x._block_size}" - ) - assert start % 2 == 0, ( - f"Start index {start} must be even for FP4 packing" - ) - - if end is not None and end != sys.maxsize: - assert end % x._block_size == 0, ( - f"End index {end} must be a multiple of block_size {x._block_size}" - ) - assert end % 2 == 0, f"End index {end} must be even for FP4 packing" - - packed_start = None if start is None else start // 2 - packed_end = None if end is None else end // 2 - sliced_data = aten.slice.Tensor( - x.qdata, dim, packed_start, packed_end, step - ) - - start_block = 0 if start is None else start // x._block_size - end_block = None if end is None else end // x._block_size - sliced_scale = aten.slice.Tensor( - scale_shaped, 1, start_block, end_block, step - ) - - sliced_scale = sliced_scale.flatten() - - # reshape at the end - sliced_M = sliced_data.shape[0] - # multiply by 2 to convert from bytes to num_elements - sliced_K = sliced_data.shape[1] * 2 - if x._is_swizzled_scales: - scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(sliced_M, sliced_K) - else: - # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1 - # scale element - scale_M = sliced_M - scale_K = sliced_K // x._block_size - sliced_scale = sliced_scale.view(scale_M, scale_K) + sliced_data, sliced_scale = _swizzle_aware_slice(x, dim, start, end, step) # Create result tensor result = NVFP4Tensor( diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 28a8526709..d96a8b48af 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import sys from typing import Tuple import torch @@ -21,6 +22,8 @@ Tensor = torch.Tensor +aten = torch.ops.aten + def ceil_div(a, b): return (a + b - 1) // b @@ -117,6 +120,22 @@ def hp_data_dims_to_swizzled_scale_dims_nvfp4( return scale_M, scale_K +def hp_data_dims_to_swizzled_scale_dims_mx( + hp_data_M, + hp_data_K, +) -> Tuple[int, int]: + """ + Given the `M` and `K` dimensions of a high precision contiguous tensor, + returns a 2d tuple of the dims of the swizzled mx scale corresponding to + that tensor. + """ + # a 128x128 unpacked or 128x64 packed qdata tile corresponds + # to a swizzled 32x16 scale tile + scale_M = ceil_div(hp_data_M, 128) * 32 + scale_K = ceil_div(hp_data_K, 128) * 16 + return scale_M, scale_K + + def _to_blocked_single(scales: Tensor) -> Tensor: """Assume that we have a 128x4 block of scales in K Major order @@ -158,6 +177,7 @@ def _to_mxfp8_dim1_kernel_wrapper( else: raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}") + is_swizzled_scales = False if isinstance(a_data, DTensor): assert isinstance(a_scale, DTensor) a_data_local = a_data.to_local() @@ -171,6 +191,7 @@ def _to_mxfp8_dim1_kernel_wrapper( gemm_kernel_choice, False, None, + is_swizzled_scales, ) mx_tensor = DTensor.from_local( inner, @@ -190,5 +211,210 @@ def _to_mxfp8_dim1_kernel_wrapper( gemm_kernel_choice, False, None, + is_swizzled_scales, ) return mx_tensor + + +def _swizzle_aware_slice( + x: torch.Tensor, dim, start, end, step +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Input: NVFP4Tensor or MXTensor + Output: sliced qdata and scale, does the right thing for unswizzled and swizzled scales + """ + + M, K = x.shape[0], x.shape[1] + + # The scale manipulations below assume a flattened scale. For now, we + # flatten the scale, go through the calculations below, and then reshape + # it back to the format which matches the shape of `qdata`. 
+ # TODO(future PR): update this + + if x._is_swizzled_scales: + scale_rows = M + scale_cols = K // x._block_size + n_row_blocks = ceil_div(scale_rows, 128) + n_col_blocks = ceil_div(scale_cols, 4) + elements_per_block = 32 * 16 # 512 elements + + if dim == 0: + # Row slicing + # Handle sys.maxsize (default slice end) + if end == sys.maxsize: + end = M + + # Check if start/end align with 128-row boundaries + if start is not None and start % 128 != 0: + raise RuntimeError( + f"Row slicing of NVFP4Tensor with swizzled scales requires " + f"start index to be a multiple of 128, got {start}" + ) + if end is not None and end != M and end % 128 != 0: + raise RuntimeError( + f"Row slicing of NVFP4Tensor with swizzled scales requires " + f"end index to be a multiple of 128 or equal to tensor size {M}, got {end}" + ) + + # Calculate which row blocks to keep + start_block = 0 if start is None else start // 128 + end_block = n_row_blocks if end is None or end >= M else end // 128 + + # The swizzled tensor has shape (n_row_blocks * n_col_blocks * 32 * 16,) + blocks_per_row = n_col_blocks + start_idx = start_block * blocks_per_row * elements_per_block + end_idx = ( + end_block * blocks_per_row * elements_per_block + if end_block < n_row_blocks + else None + ) + + sliced_scale = aten.slice.Tensor( + x.scale.flatten(), 0, start_idx, end_idx, 1 + ) + sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step) + + elif dim == 1: + # Column slicing + # Handle sys.maxsize (default slice end) + if end == sys.maxsize: + end = K + + # Check if start/end align with 64-column boundaries (4 scale columns * 16 block_size) + if start is not None and start % 64 != 0: + raise RuntimeError( + f"Column slicing of NVFP4Tensor with swizzled scales requires " + f"start index to be a multiple of 64, got {start}" + ) + if end is not None and end != K and end % 64 != 0: + raise RuntimeError( + f"Column slicing of NVFP4Tensor with swizzled scales requires " + f"end index to be a multiple of 64 or equal to tensor size {K}, got {end}" + ) + + # TODO(future PR): use torch.float4_e2m1fn_x2 for nvfp4 and mxfp4 + if x.qdata.dtype != torch.float8_e4m3fn: + # Also check FP4 packing alignment + if start is not None and start % 2 != 0: + raise RuntimeError( + f"Start index {start} must be even for FP4 packing" + ) + if end is not None and end != K and end % 2 != 0: + raise RuntimeError(f"End index {end} must be even for FP4 packing") + + # Calculate which column blocks to keep + start_scale_col = 0 if start is None else start // 16 + end_scale_col = scale_cols if end is None or end >= K else end // 16 + + start_col_block = start_scale_col // 4 + end_col_block = end_scale_col // 4 + + # Verify the end aligns with block boundary + if end_scale_col % 4 != 0: + raise RuntimeError( + f"Column slicing end index {end} does not align with scale block boundaries. " + f"End must result in a multiple of 4 scale columns (64 data columns)." 
+ ) + + if start_col_block == 0 and end_col_block == n_col_blocks: + # Full width - no slicing needed + sliced_scale = x.scale + else: + # Extract specific column blocks from each row block + # Each row block in swizzled format contains n_col_blocks chunks of (32, 16) + elements_per_row_block = n_col_blocks * elements_per_block + + # Build list of slices to extract + slices_to_extract = [] + for row_block in range(n_row_blocks): + row_start = row_block * elements_per_row_block + col_start = row_start + start_col_block * elements_per_block + col_end = row_start + end_col_block * elements_per_block + slices_to_extract.append(x.scale.flatten()[col_start:col_end]) + + # Concatenate all the slices + sliced_scale = torch.cat(slices_to_extract, dim=0) + + # Slice the data tensor + if x.qdata.dtype != torch.float8_e4m3fn: + packed_start = None if start is None else start // 2 + packed_end = None if end is None else end // 2 + else: + packed_start = start + packed_end = end + sliced_data = aten.slice.Tensor( + x.qdata, dim, packed_start, packed_end, step + ) + + else: + raise ValueError( + f"NVFP4Tensor only supports slicing along dimensions 0 and 1, got dim={dim}" + ) + + else: + scale_shaped = x.scale.view(M, K // x._block_size) + + if dim == 0: + sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) + sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step) + + elif dim == 1: + if start is not None: + assert start % x._block_size == 0, ( + f"Start index {start} must be a multiple of block_size {x._block_size}" + ) + assert start % 2 == 0, ( + f"Start index {start} must be even for FP4 packing" + ) + + if end is not None and end != sys.maxsize: + assert end % x._block_size == 0, ( + f"End index {end} must be a multiple of block_size {x._block_size}" + ) + assert end % 2 == 0, f"End index {end} must be even for FP4 packing" + + if x.qdata.dtype != torch.float8_e4m3fn: + packed_start = None if start is None else start // 2 + packed_end = None if end is None else end // 2 + else: + packed_start = start + packed_end = end + sliced_data = aten.slice.Tensor( + x.qdata, dim, packed_start, packed_end, step + ) + + start_block = 0 if start is None else start // x._block_size + end_block = None if end is None else end // x._block_size + sliced_scale = aten.slice.Tensor( + scale_shaped, 1, start_block, end_block, step + ) + + sliced_scale = sliced_scale.flatten() + + # reshape at the end + sliced_M = sliced_data.shape[0] + if x.qdata.dtype == torch.float8_e4m3fn: + sliced_K = sliced_data.shape[1] + else: + # multiply by 2 to convert from bytes to num_elements + sliced_K = sliced_data.shape[1] * 2 + if x._is_swizzled_scales: + if x._block_size == 16: + scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4( + sliced_M, sliced_K + ) + else: + assert x._block_size == 32, f"unexpected {x._block_size=}" + scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_mx( + sliced_M, sliced_K + ) + else: + # nvfp4: a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1 + # scale element + # mx: a 1x32 unpacked or 1x16 packed qdata tile corresponds to 1 + # scale element + scale_M = sliced_M + scale_K = sliced_K // x._block_size + sliced_scale = sliced_scale.view(scale_M, scale_K) + + return sliced_data, sliced_scale From cc61907c92aa9f53deecd9d0f50d20e303f3d7d1 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 17 Oct 2025 10:04:38 -0700 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- benchmarks/mx_formats/cast_bench.py | 6 +++--- torchao/prototype/mx_formats/kernels.py | 26 
+++++++++++++++++-------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index f4f635af1b..09a1fd0b1e 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -257,12 +257,12 @@ def run( elif mode == "dim0_nvfp4": to_nvfp4_reference_c = torch.compile(to_nvfp4_reference) - y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False) + y_d0, s_d0 = to_nvfp4_reference_c(x) for _ in range(2): - __ = to_nvfp4_reference_c(x, use_triton_kernel=False) + __ = to_nvfp4_reference_c(x) time_us = benchmark_cuda_function_in_microseconds( - lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False), + lambda x: to_nvfp4_reference_c(x), x, ) assert y_d0.dtype == torch.uint8 diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 173d99f746..b733d7fee6 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1441,6 +1441,8 @@ def quantize_nvfp4_triton_kernel( N, USE_TENSOR_SCALE: tl.constexpr, MASK_SCALES: tl.constexpr, + ROW_TILE_SIZE: tl.constexpr, + COL_TILE_SIZE: tl.constexpr, ): F4_E2M1_MAX = 6.0 F8E4M3_MAX = 448.0 @@ -1449,8 +1451,8 @@ def quantize_nvfp4_triton_kernel( pid_m = tl.program_id(1) pid_n = tl.program_id(0) - offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] - offs_n = pid_n * 64 + tl.arange(0, 64)[None, :] + offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None] + offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :] if MASK_SCALES: mask = (offs_m < M) & (offs_n < N) other = 0.0 @@ -1460,10 +1462,10 @@ def quantize_nvfp4_triton_kernel( x = tl.load( x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other ) # [128, 64] - x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16] + x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16) # [-1, 4, 16] # Compute block-wise scales - block_amax = tl.max(x_blocks.abs(), axis=2) # [128, 4] + block_amax = tl.max(x_blocks.abs(), axis=2) # [-1, 4] if USE_TENSOR_SCALE: # Two-level scaling: quantize block scales with per-tensor scale @@ -1513,9 +1515,13 @@ def quantize_nvfp4_triton_kernel( ) # Convert to FP4 - x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split()) - offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] - offs_n = pid_n * 32 + tl.arange(0, 32)[None, :] + x_fp4x2 = convert_fp32_to_fp4_packed( + x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split() + ) + offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None] + offs_n = ( + pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :] + ) if MASK_SCALES: mask = (offs_m < M) & (offs_n < N // 2) else: @@ -1537,7 +1543,7 @@ def triton_quantize_nvfp4( Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout. Note: - Since VLLM does not use dyanmo guards we need to make this a custom op + Since VLLM does not use dynamo guards we need to make this a custom op to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES` """ # reshape to 2d @@ -1571,6 +1577,8 @@ def triton_quantize_nvfp4( tensor_scale_ptr = per_tensor_scale use_tensor_scale = True + ROW_TILE_SIZE = 128 + COL_TILE_SIZE = 64 quantize_nvfp4_triton_kernel[grid]( x, tensor_scale_ptr, @@ -1582,6 +1590,8 @@ def triton_quantize_nvfp4( N, USE_TENSOR_SCALE=use_tensor_scale, MASK_SCALES=MASK_SCALES, + ROW_TILE_SIZE=ROW_TILE_SIZE, + COL_TILE_SIZE=COL_TILE_SIZE, ) # reshape back to original shape
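
A minimal, self-contained sketch (not part of the patch series) of the swizzled-scale MXTensor path this stack introduces: the `is_swizzled_scales` flag on `MXTensor.to_mx`, blocked scale storage via `to_blocked`, and un-swizzling on dequantize. The call shapes mirror `test_scale_shape_matches_qdata` and `test_swizzle` above; the import locations and the availability of a CUDA device with fp8 support are assumptions, so treat this as illustration rather than as the patches' own test code.

    import torch

    # assumed import locations; both names appear in the diffs above
    from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode

    block_size = 32
    x_hp = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)

    # quantize to mxfp8 with the scale stored in the blocked (swizzled) layout,
    # as MXFPInferenceConfig now requests for both activations and weights
    x_mx = MXTensor.to_mx(
        x_hp,
        torch.float8_e4m3fn,
        block_size,
        ScaleCalculationMode.FLOOR,
        is_swizzled_scales=True,
    )

    # dequantize un-swizzles the scale before applying it, so the roundtrip
    # matches the non-swizzled path bit-for-bit (see test_swizzle above)
    x_dq = x_mx.dequantize(torch.bfloat16)
    assert x_dq.shape == x_hp.shape

The design point being exercised: with `is_swizzled_scales=True` the scale is kept in the layout the scaled-gemm kernels expect, so `_addmm_mx_dispatch` can pass it through directly instead of calling `to_blocked` on every matmul, which is what the non-swizzled branch still does.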