Skip to content

Commit ffaf572

Browse files
skip rocm for moe training tests (#2646)
1 parent 6e941c8 commit ffaf572

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from torchao.testing.utils import skip_if_rocm
3838

3939

40-
@skip_if_rocm("ROCm enablement in progress")
40+
@skip_if_rocm("ROCm not supported")
4141
def test_valid_scaled_grouped_mm_2d_3d():
4242
out_dtype = torch.bfloat16
4343
device = "cuda"
@@ -91,6 +91,7 @@ def test_valid_scaled_grouped_mm_2d_3d():
9191
assert torch.equal(b_t.grad, ref_b_t.grad)
9292

9393

94+
@skip_if_rocm("ROCm not supported")
9495
@pytest.mark.parametrize("m", [16, 17])
9596
@pytest.mark.parametrize("k", [16, 18])
9697
@pytest.mark.parametrize("n", [32, 33])
@@ -219,6 +220,7 @@ def compute_reference_forward(
219220
return output_ref
220221

221222

223+
@skip_if_rocm("ROCm not supported")
222224
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
223225
@pytest.mark.parametrize("num_experts", (1, 8, 16))
224226
def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
@@ -249,6 +251,7 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
249251
assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
250252

251253

254+
@skip_if_rocm("ROCm not supported")
252255
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
253256
@pytest.mark.parametrize("num_experts", (1, 8, 16))
254257
def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):

0 commit comments

Comments
 (0)