54 changes: 54 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -662,3 +662,57 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
        rtol=0.0,
        msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
@pytest.mark.parametrize("transpose", [False, True])
@pytest.mark.parametrize(
    "shape",
    (
        (128, 64),
        (1, 128, 64),
    ),
)
def test_scale_shape_matches_qdata(transpose, shape):
    if len(shape) == 3 and transpose:
        pytest.skip("transpose not yet implemented for 3D MXTensor")

    block_size = 32

    x_hp = torch.randn(*shape, device="cuda")
    x = MXTensor.to_mx(
        x_hp,
        torch.float8_e4m3fn,
        block_size,
        ScaleCalculationMode.FLOOR,
    )

    if len(shape) == 2:
        m_dim, k_dim = 0, 1
        if transpose:
            x_hp = x_hp.t()
            x = x.t()
            m_dim, k_dim = 1, 0
    else:
        assert len(shape) == 3, "unsupported"
        m_dim, k_dim = 1, 2
        if transpose:
            x_hp = x_hp.transpose(-2, -1)
            x = x.transpose(-2, -1)
            m_dim, k_dim = 2, 1

    orig_m = x_hp.shape[m_dim]
    expected_padded_m = orig_m
    actual_padded_m = x.scale.shape[m_dim]
    assert expected_padded_m == actual_padded_m, (
        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}"
    )

    orig_k = x_hp.shape[k_dim]
    expected_padded_k = orig_k // block_size
    actual_padded_k = x.scale.shape[k_dim]

    assert expected_padded_k == actual_padded_k, (
        f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
    )
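For reference, here is a minimal sketch of the invariant this test pins down, assuming a CUDA device and the same MXTensor / ScaleCalculationMode names the test module already imports (the exact import path below is an assumption, not part of the diff):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode  # import path assumed

x_hp = torch.randn(128, 64, device="cuda")
x = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.FLOOR)
# One e8m0 scale per 32-element block along the last dim:
# qdata is (128, 64), so scale should be (128, 64 // 32) == (128, 2).
assert x.scale.shape == (128, 2)
# After t(), the scale follows qdata: qdata becomes (64, 128), scale (2, 128).
assert x.t().scale.shape == (2, 128)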
4 changes: 2 additions & 2 deletions torchao/prototype/mx_formats/kernels.py
@@ -1245,7 +1245,7 @@ def triton_to_mxfp8_dim1(

        return (
            output_col_major.t(),
-            col_scale.view(torch.float8_e8m0fnu),
+            col_scale.view(torch.float8_e8m0fnu).squeeze(-1),
        )

    @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
@@ -1274,7 +1274,7 @@ def triton_to_mxfp8_dim1_reference(
        scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
        return (
            x_hp_d1_normalized.t(),
-            scale_e8m0_dim1.unsqueeze(-1),
+            scale_e8m0_dim1,
        )

    @triton.jit
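Net effect of the two kernel changes: the Triton path drops the trailing singleton dimension from the scale it returns, and the reference path stops adding one, so the two now agree on the scale's rank. A tiny illustration (tensor sizes made up for the example, not taken from the kernel):

import torch

rows, cols, block_size = 128, 64, 32
# Previously the returned scale carried a trailing singleton dim ...
scale_with_singleton = torch.zeros(cols, rows // block_size, 1)
# ... which squeeze(-1) removes, matching what the reference path now returns directly.
assert scale_with_singleton.squeeze(-1).shape == (cols, rows // block_size)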
5 changes: 3 additions & 2 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -362,6 +362,7 @@ def to_dtype(
    # unpacking and unscaling
    if is_transposed:
        data_lp = data_lp.t()
+        scale_e8m0 = scale_e8m0.t()
        assert data_lp.is_contiguous()
        orig_shape = (orig_shape[1], orig_shape[0])

@@ -688,7 +689,7 @@ def _addmm_mx_dispatch(
        assert b._block_size == 32, f"Invalid block size {b._block_size}"

        a_scale = a.scale.view(M, K // a._block_size)
-        b_scale = b.scale.view(N, K // b._block_size)
+        b_scale = b.scale.t().view(N, K // b._block_size)
        a_scale_block = to_blocked(a_scale)
        b_scale_block = to_blocked(b_scale)

@@ -759,7 +760,7 @@ def mx_t(func, types, args, kwargs):
    old = args[0]
    new = MXTensor(
        old.qdata.t(),
-        old.scale,
+        old.scale.t(),
        old._elem_dtype,
        old._block_size,
        old._orig_dtype,
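To tie the mx_t and _addmm_mx_dispatch changes together: since t() now transposes the scale along with qdata, the b operand of a mm/addmm (typically a transposed weight) carries a (K // block_size, N) scale, and b.scale.t() recovers the (N, K // block_size) layout that the .view(...) call above expects. A small shape-bookkeeping sketch (sizes chosen for illustration, not part of the diff):

import torch

N, K, block_size = 256, 128, 32
scale = torch.zeros(N, K // block_size)   # scale of an (N, K) weight before .t()

scale_after_t = scale.t()                 # what b.scale looks like after mx_t
assert scale_after_t.shape == (K // block_size, N)

restored = scale_after_t.t()              # b.scale.t() in _addmm_mx_dispatch
assert restored.shape == (N, K // block_size)
restored.view(N, K // block_size)         # the view succeeds on this layout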