
Commit 1b074b5

[TEST] Enable test_blocked_scale_mxfp and test_lhs_in_tmem_mxfp (#3682)
Related to #3307

Co-authored-by: Whitney Tsang <[email protected]>
1 parent: 95eb619

File tree

1 file changed: +9, -6 lines


python/test/unit/language/test_matmul.py

Lines changed: 9 additions & 6 deletions
@@ -468,8 +468,12 @@ def block_scale_mxfp_matmul( #
                     reason="Requires compute capability >= 10")
 def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device, monkeypatch):
     if is_xpu():
-        pytest.skip("FIXME: Fail RuntimeError on XPU")
-
+        if not torch.xpu.get_device_capability()["has_subgroup_matrix_multiply_accumulate"]:
+            pytest.skip("The device does not support MMA")
+        elif (BLOCK_M, BLOCK_N, BLOCK_K) == (128, 256, 256) and \
+                triton.runtime.driver.active.utils.get_device_properties(
+                    triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 196608:
+            pytest.xfail("Not enough shared memory")
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)
     elif BLOCK_K == 256:
@@ -494,7 +498,6 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
                                         b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
                                         NUM_STAGES=NUM_STAGES, USE_2D_SCALE_LOAD=USE_2D_SCALE_LOAD)
     ttgir = out.asm["ttgir"]
-    ptx = out.asm["ptx"]
 
     def flatten_scale(scale):
         num_chunk_m, num_chunk_k, _, _, _ = scale.shape
@@ -519,7 +522,7 @@ def flatten_scale(scale):
     if USE_2D_SCALE_LOAD:
         # Due to an issue in the coalescing pass, tmem_copy can not be generated for the 5D load.
         # The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914
-        assert "tcgen05.cp" in ptx
+        assert is_xpu() or "tcgen05.cp" in out.asm["ptx"]
     if NUM_STAGES > 1:
         if BLOCK_M == BLOCK_K and BLOCK_N == BLOCK_K:
             load_pipelined = ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") == 2
@@ -613,8 +616,8 @@ def lhs_in_tmem_kernel_mxfp( #
 @pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10,
                     reason="Requires compute capability >= 10")
 def test_lhs_in_tmem_mxfp(device, monkeypatch):
-    if is_xpu():
-        pytest.skip("FIXME: failed to legalize operation 'tt.dot_scaled' on XPU")
+    if is_xpu() and not torch.xpu.get_device_capability()["has_subgroup_matrix_multiply_accumulate"]:
+        pytest.skip("The device does not support MMA")
     _knob_promote_lhs_to_tmem(monkeypatch)
     M, N, K = 128, 64, 32
     torch.manual_seed(42)
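
For context, the new XPU guard follows the pattern sketched below. This is a minimal, illustrative sketch rather than the test itself: the capability key, the driver property query, and the 196608-byte (192 KiB) shared-memory threshold are taken from the diff above, while is_xpu() and guard_blocked_scale_mxfp() are hypothetical stand-ins for the helpers the test file actually uses.

import pytest
import torch
import triton


def is_xpu():
    # Hypothetical stand-in for the is_xpu() helper used by test_matmul.py:
    # true when the active Triton backend targets Intel XPU.
    return triton.runtime.driver.active.get_current_target().backend == "xpu"


def guard_blocked_scale_mxfp(BLOCK_M, BLOCK_N, BLOCK_K):
    # Skip when the XPU device lacks subgroup matrix-multiply-accumulate (MMA) support.
    if not torch.xpu.get_device_capability()["has_subgroup_matrix_multiply_accumulate"]:
        pytest.skip("The device does not support MMA")
    # The largest tested block shape needs at least 196608 bytes of shared local memory;
    # on devices with less, that configuration is marked as an expected failure.
    props = triton.runtime.driver.active.utils.get_device_properties(
        triton.runtime.driver.active.get_current_device())
    if (BLOCK_M, BLOCK_N, BLOCK_K) == (128, 256, 256) and props["max_shared_mem"] < 196608:
        pytest.xfail("Not enough shared memory")

Compared with the unconditional pytest.skip the tests carried before, a guard like this lets the tests run on XPU devices that support MMA and only skips or xfails on hardware that cannot execute them.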
