diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 556a3f8aff..6aea6e69a5 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -67,10 +67,14 @@ def cuda_kernel_profiler(kernel_pattern):
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
-@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize("emulate", [True, False])
+# @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+# @pytest.mark.parametrize("bias", [True, False])
+# @pytest.mark.parametrize("compile", [True, False])
+# @pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize("bias", [False])
+@pytest.mark.parametrize("compile", [False])
+@pytest.mark.parametrize("emulate", [False])
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
@@ -93,7 +97,11 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
         # TODO(future PR): investigate and fix this
         pytest.skip("mxfp4 + compile currently does not work, low SQNR")
 
-    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    # M, N, K = 16, 3072, 4096
+    # M, N, K = 1920, 3072, 256
+    M, N, K = 1920, 18432, 3072
+    # m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
 
     if emulate:
@@ -108,18 +116,22 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
         gemm_kernel_choice=kernel_choice,
     )
     quantize_(m_mx, config=config)
+    print("m_mx:", m_mx)
+
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)
 
-    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
-    y_ref = m(x)
-    y_mx = m_mx(x)
+    with torch.inference_mode():
+        x = torch.randn(1, M, K, device="cuda", dtype=torch.bfloat16)
+        y_ref = m(x)
+        y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
     SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
+    raise Exception("stop")
 
     # serialization
     with tempfile.NamedTemporaryFile() as f:
         torch.save(m_mx.state_dict(), f)
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 05c8fdc8e4..98dbe4efde 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -665,7 +665,10 @@ def _addmm_mx_dispatch(
     The only difference is whether bias is None or not.
     """
+    out_starting_shape = a.shape[:-1]
+
     if not isinstance(a, MXTensor):
+        a = a.reshape(-1, a.shape[-1])
         assert b.act_quant_kwargs is not None, "weight-only quant not yet supported"
         k = b.act_quant_kwargs
         a = MXTensor.to_mx(
@@ -698,6 +701,7 @@ def _addmm_mx_dispatch(
                 "CUBLAS is the only supported kernel choice for MX FP8 operations"
             )
 
+            print(f"scaled_mm info: {a.qdata.shape=}, {b.qdata.shape=}, {a_scale_block.shape=}, {b_scale_block.shape=}, a.qdata contig: {a.qdata.is_contiguous()} b.qdata contig: {b.qdata.is_contiguous()} a_scale_block contig: {a_scale_block.is_contiguous()}, b_scale_block contig: {b_scale_block.is_contiguous()}")
             res = torch._scaled_mm(
                 a.qdata,
                 b.qdata,
@@ -706,6 +710,7 @@ def _addmm_mx_dispatch(
                 bias=bias,
                 out_dtype=torch.bfloat16,
             )
+            print(f"after scaled mm {a.qdata.shape=}, {b.qdata.shape=}")
         else:
             assert a._elem_dtype == torch.float4_e2m1fn_x2
             assert b._elem_dtype == torch.float4_e2m1fn_x2
@@ -717,7 +722,6 @@ def _addmm_mx_dispatch(
             # TODO add optional bias to kernel
             if bias is not None:
                 res = res + bias
-
     else:
         # emulated MX gemm
         a_hp = a.dequantize(a._orig_dtype)
@@ -726,12 +730,17 @@ def _addmm_mx_dispatch(
         assert a_hp.is_contiguous()
         assert b_hp.t().is_contiguous()
 
-        # Call appropriate aten_op based on whether bias is provided
-        if bias is not None:
-            res = aten_op(bias, a_hp, b_hp)  # addmm
+        if aten_op == aten.linear.default:
+            res = aten_op(a_hp, b_hp.t(), bias)
         else:
-            res = aten_op(a_hp, b_hp)  # mm
+            # Call appropriate aten_op based on whether bias is provided
+            if bias is not None:
+                res = aten_op(bias, a_hp, b_hp)  # addmm
+            else:
+                res = aten_op(a_hp, b_hp)  # mm
+
+    res = res.reshape(*out_starting_shape, res.shape[-1])
+
     return res
 
 
@@ -752,6 +761,14 @@ def mx_addmm(func, types, args, kwargs):
     b = args[2]
     return _addmm_mx_dispatch(a, b, func, bias=bias)
 
 
+@implements([aten.linear.default])
+def mx_linear(func, types, args, kwargs):
+    assert isinstance(args[0], torch.Tensor) and isinstance(args[1], MXTensor)
+    a = args[0]
+    b = args[1]
+    bias = args[2] if len(args) > 2 else None
+    return _addmm_mx_dispatch(a, b.t(), func, bias=bias)
+
 @implements([aten.t.default])
 def mx_t(func, types, args, kwargs):