 from torchao.testing.utils import skip_if_rocm


-@skip_if_rocm("ROCm enablement in progress")
+@skip_if_rocm("ROCm not supported")
 def test_valid_scaled_grouped_mm_2d_3d():
     out_dtype = torch.bfloat16
     device = "cuda"
@@ -91,6 +91,7 @@ def test_valid_scaled_grouped_mm_2d_3d():
     assert torch.equal(b_t.grad, ref_b_t.grad)


+@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("m", [16, 17])
 @pytest.mark.parametrize("k", [16, 18])
 @pytest.mark.parametrize("n", [32, 33])
@@ -219,6 +220,7 @@ def compute_reference_forward(
     return output_ref


+@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
@@ -249,6 +251,7 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
     assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"


+@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
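For context, here is a minimal sketch of how a skip-on-ROCm decorator like the one applied above could work. It assumes the check keys off torch.version.hip (a version string on ROCm builds of PyTorch, None otherwise) and skips at call time via pytest.skip; the actual torchao.testing.utils.skip_if_rocm implementation may differ.

    # Hypothetical sketch, not the torchao implementation.
    import functools

    import pytest
    import torch


    def skip_if_rocm(message=None):
        """Skip the decorated test when running on a ROCm build of PyTorch."""

        def decorator(func):
            # functools.wraps preserves the test name so pytest
            # collection and reporting stay readable.
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # torch.version.hip is non-None only on ROCm builds.
                if torch.version.hip is not None:
                    pytest.skip(message or "test not supported on ROCm")
                return func(*args, **kwargs)

            return wrapper

        return decorator

Skipping at call time keeps the decorator independent of pytest's collection machinery, so it composes with the pytest.mark.parametrize stacks above regardless of decorator order.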