Skip to content

Commit a8e5788

Browse files
authored
[TEST] Check ttng.tc_gen5_mma only for CUDA (triton-lang#6543)
This fixes test failures on AMD gfx950 architecture.
1 parent bceed1f commit a8e5788

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

python/test/unit/language/test_matmul.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,10 @@ def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS
380380
rtol = 0.0001
381381
torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)
382382

383-
# Pipelining of dot_scaled requires tmem_copy to be used, which in turn
384-
# requires the scales to be in the blocked layout in global memory.
385-
assert out.asm["ttgir"].count("ttng.tc_gen5_mma") == 1
383+
if is_cuda():
384+
# Pipelining of dot_scaled requires tmem_copy to be used, which in turn
385+
# requires the scales to be in the blocked layout in global memory.
386+
assert out.asm["ttgir"].count("ttng.tc_gen5_mma") == 1
386387

387388

388389
def _knob_promote_lhs_to_tmem(monkeypatch):

0 commit comments

Comments
 (0)