Draft
28 changes: 20 additions & 8 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -67,10 +67,14 @@ def cuda_kernel_profiler(kernel_pattern):
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize("emulate", [True, False])
# @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("bias", [True, False])
# @pytest.mark.parametrize("compile", [True, False])
# @pytest.mark.parametrize("emulate", [True, False])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("compile", [False])
@pytest.mark.parametrize("emulate", [False])
@torch.no_grad()
@skip_if_rocm(
"ROCm float4 gemm require gfx950"
@@ -93,7 +97,11 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
# TODO(future PR): investigate and fix this
pytest.skip("mxfp4 + compile currently does not work, low SQNR")

m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
# M, N, K = 16, 3072, 4096
# M, N, K = 1920, 3072, 256
M, N, K = 1920, 18432, 3072
# m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
m = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)

if emulate:
@@ -108,18 +116,22 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
gemm_kernel_choice=kernel_choice,
)
quantize_(m_mx, config=config)
print("m_mx:", m_mx)

if compile:
m_mx = torch.compile(m_mx, fullgraph=True)

x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
y_ref = m(x)
y_mx = m_mx(x)
with torch.inference_mode():
Contributor: is this what makes the test fail, without changes to the product code?
x = torch.randn(1, M, K, device="cuda", dtype=torch.bfloat16)
y_ref = m(x)
y_mx = m_mx(x)
sqnr = compute_error(y_ref, y_mx)
SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
assert sqnr >= SQNR_THRESHOLD, (
f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
)

raise Exception("stop")
# serialization
with tempfile.NamedTemporaryFile() as f:
torch.save(m_mx.state_dict(), f)
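Regarding the reviewer question above (whether wrapping the forward passes in torch.inference_mode() changes behavior, or whether the new 3D input shape does): a torch-only probe like the one below can show which aten ops actually reach __torch_dispatch__ for each combination of input rank and grad context. This is a sketch for illustration only, not part of the PR, and it does not exercise torchao's MXTensor; shapes and bias=False follow the test, and a CUDA device is assumed.

```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode

class LogAtenOps(TorchDispatchMode):
    """Print every aten op that reaches __torch_dispatch__."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("   ", func)
        return func(*args, **(kwargs or {}))

# Assumes a CUDA device, matching the test environment.
m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
for shape_label, x in [
    ("2d (old test input)", torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)),
    ("3d (new test input)", torch.randn(1, 128, 32, device="cuda", dtype=torch.bfloat16)),
]:
    for ctx_label in ("no_grad", "inference_mode"):
        ctx = torch.no_grad() if ctx_label == "no_grad" else torch.inference_mode()
        print(f"--- {shape_label}, {ctx_label} ---")
        with ctx, LogAtenOps():
            m(x)
```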
27 changes: 22 additions & 5 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -665,7 +665,10 @@ def _addmm_mx_dispatch(
The only difference is whether bias is None or not.
"""

out_starting_shape = a.shape[:-1]

if not isinstance(a, MXTensor):
a = a.reshape(-1, a.shape[-1])
assert b.act_quant_kwargs is not None, "weight-only quant not yet supported"
k = b.act_quant_kwargs
a = MXTensor.to_mx(
@@ -698,6 +701,7 @@ def _addmm_mx_dispatch(
"CUBLAS is the only supported kernel choice for MX FP8 operations"
)

print(f"scaled_mm info: {a.qdata.shape=}, {b.qdata.shape=}, {a_scale_block.shape=}, {b_scale_block.shape=}, a.qdata contig: {a.qdata.is_contiguous()} b.qdata contig: {b.qdata.is_contiguous()} a_scale_block contig: {a_scale_block.is_contiguous()}, b_scale_block contig: {b_scale_block.is_contiguous()}")
res = torch._scaled_mm(
a.qdata,
b.qdata,
@@ -706,6 +710,7 @@ def _addmm_mx_dispatch(
bias=bias,
out_dtype=torch.bfloat16,
)
print(f"after scaled mm {a.qdata.shape=}, {b.qdata.shape=}")
else:
assert a._elem_dtype == torch.float4_e2m1fn_x2
assert b._elem_dtype == torch.float4_e2m1fn_x2
@@ -717,7 +722,6 @@ def _addmm_mx_dispatch(
# TODO add optional bias to kernel
if bias is not None:
res = res + bias

else:
# emulated MX gemm
a_hp = a.dequantize(a._orig_dtype)
Expand All @@ -726,12 +730,17 @@ def _addmm_mx_dispatch(
assert a_hp.is_contiguous()
assert b_hp.t().is_contiguous()

# Call appropriate aten_op based on whether bias is provided
if bias is not None:
res = aten_op(bias, a_hp, b_hp) # addmm
if aten_op == aten.linear.default:
res = aten_op(a_hp, b_hp.t(), bias)
else:
res = aten_op(a_hp, b_hp) # mm
# Call appropriate aten_op based on whether bias is provided
if bias is not None:
res = aten_op(bias, a_hp, b_hp) # addmm
else:
res = aten_op(a_hp, b_hp) # mm

res = res.reshape(*out_starting_shape, res.shape[-1])

return res
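The out_starting_shape round-trip added in this hunk (flatten before the gemm, restore just before returning) mirrors the usual pattern for routing an N-D activation through a 2D-only matmul. A minimal torch-only sketch, with torch.mm standing in for the quantized gemm and shapes borrowed from one of the commented-out M, N, K settings in the test:

```python
import torch

# Shapes from the commented-out test setting M, N, K = 1920, 3072, 256.
x = torch.randn(1, 1920, 256)    # (batch, M, K) activation, like the new 3D test input
w = torch.randn(3072, 256)       # (N, K) weight, nn.Linear layout

out_starting_shape = x.shape[:-1]                       # (1, 1920)
x_2d = x.reshape(-1, x.shape[-1])                       # (1920, 256): what a 2D-only gemm accepts
res = torch.mm(x_2d, w.t())                             # stand-in for the quantized gemm call
res = res.reshape(*out_starting_shape, res.shape[-1])   # back to (1, 1920, 3072)

torch.testing.assert_close(res, torch.nn.functional.linear(x, w))
```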


@@ -752,6 +761,14 @@ def mx_addmm(func, types, args, kwargs):
b = args[2]
return _addmm_mx_dispatch(a, b, func, bias=bias)

@implements([aten.linear.default])
def mx_linear(func, types, args, kwargs):
Contributor: I think the reshape should live here, the addmm op is supposed to have 2d inputs
assert isinstance(args[0], torch.Tensor) and isinstance(args[1], MXTensor)
a = args[0]
b = args[1]
bias = args[2] if len(args) > 2 else None
return _addmm_mx_dispatch(a, b.t(), func, bias=bias)
Contributor: IMO either don't pass func into addmm_mx_dispatch, or we should rename it to something like addmm_or_linear_mx_dispatch
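For concreteness, one possible shape of the two review comments above: do the flatten/unflatten in the linear override so the helper only ever sees 2D inputs, and rename the shared helper. This is an illustrative sketch, not code from the PR; it assumes the surrounding mx_tensor.py context for implements/aten, and _addmm_or_linear_mx_dispatch is a hypothetical rename of the existing helper, not an existing function.

```python
# Sketch only; assumes the surrounding mx_tensor.py module (implements, aten, the dispatch helper).
@implements([aten.linear.default])
def mx_linear(func, types, args, kwargs):
    a = args[0]
    b = args[1]
    bias = args[2] if len(args) > 2 else None
    # Flatten/unflatten here so the helper only ever sees 2D inputs,
    # matching what aten.addmm / aten.mm expect.
    out_starting_shape = a.shape[:-1]
    res = _addmm_or_linear_mx_dispatch(  # hypothetical rename of _addmm_mx_dispatch
        a.reshape(-1, a.shape[-1]), b.t(), func, bias=bias
    )
    return res.reshape(*out_starting_shape, res.shape[-1])
```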



@implements([aten.t.default])
def mx_t(func, types, args, kwargs):