diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 577112b16a..1a3631cc53 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -662,3 +662,57 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
         rtol=0.0,
         msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.parametrize("transpose", [False, True])
+@pytest.mark.parametrize(
+    "shape",
+    (
+        (128, 64),
+        (1, 128, 64),
+    ),
+)
+def test_scale_shape_matches_qdata(transpose, shape):
+    if len(shape) == 3 and transpose:
+        pytest.skip("transpose not yet implemented for 3D MXTensor")
+
+    block_size = 32
+
+    x_hp = torch.randn(*shape, device="cuda")
+    x = MXTensor.to_mx(
+        x_hp,
+        torch.float8_e4m3fn,
+        block_size,
+        ScaleCalculationMode.FLOOR,
+    )
+
+    if len(shape) == 2:
+        m_dim, k_dim = 0, 1
+        if transpose:
+            x_hp = x_hp.t()
+            x = x.t()
+            m_dim, k_dim = 1, 0
+    else:
+        assert len(shape) == 3, "unsupported"
+        m_dim, k_dim = 1, 2
+        if transpose:
+            x_hp = x_hp.transpose(-2, -1)
+            x = x.transpose(-2, -1)
+            m_dim, k_dim = 2, 1
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    actual_padded_m = x.scale.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    actual_padded_k = x.scale.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
+    )
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 7053225521..173d99f746 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1245,7 +1245,7 @@ def triton_to_mxfp8_dim1(
 
         return (
             output_col_major.t(),
-            col_scale.view(torch.float8_e8m0fnu),
+            col_scale.view(torch.float8_e8m0fnu).squeeze(-1),
         )
 
     @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
@@ -1274,7 +1274,7 @@ def triton_to_mxfp8_dim1_reference(
         scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
         return (
             x_hp_d1_normalized.t(),
-            scale_e8m0_dim1.unsqueeze(-1),
+            scale_e8m0_dim1,
         )
 
     @triton.jit
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 05c8fdc8e4..a5e50b2468 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -362,6 +362,7 @@ def to_dtype(
     # unpacking and unscaling
     if is_transposed:
         data_lp = data_lp.t()
+        scale_e8m0 = scale_e8m0.t()
         assert data_lp.is_contiguous()
         orig_shape = (orig_shape[1], orig_shape[0])
 
@@ -688,7 +689,7 @@ def _addmm_mx_dispatch(
         assert b._block_size == 32, f"Invalid block size {b._block_size}"
 
         a_scale = a.scale.view(M, K // a._block_size)
-        b_scale = b.scale.view(N, K // b._block_size)
+        b_scale = b.scale.t().view(N, K // b._block_size)
         a_scale_block = to_blocked(a_scale)
         b_scale_block = to_blocked(b_scale)
 
@@ -759,7 +760,7 @@ def mx_t(func, types, args, kwargs):
     old = args[0]
     new = MXTensor(
         old.qdata.t(),
-        old.scale,
+        old.scale.t(),
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
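
A minimal sketch of the invariant the new `test_scale_shape_matches_qdata` test enforces, assuming a CUDA device and that `MXTensor` and `ScaleCalculationMode` are importable as in the existing test module (the exact import paths are not shown in this diff and are an assumption):

```python
# Sketch only: import locations are assumed to mirror the existing test module.
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode

block_size = 32
x_hp = torch.randn(128, 64, device="cuda")
x = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, block_size, ScaleCalculationMode.FLOOR)

# One e8m0 scale per block of 32 elements along the reduction dim:
# qdata is (128, 64), so scale is (128, 64 // 32) == (128, 2).
assert x.scale.shape == (128, 64 // block_size)

# With this change, transposing the MXTensor also transposes its scale,
# keeping scale dims aligned with qdata dims in either orientation.
x_t = x.t()
assert x_t.scale.shape == (64 // block_size, 128)
```

This is also why `_addmm_mx_dispatch` now views `b.scale.t()` as `(N, K // block_size)`: once the scale follows the qdata orientation, the second operand's scale must be transposed back before being passed to the blocked-layout gemm path.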