@@ -468,8 +468,12 @@ def block_scale_mxfp_matmul(  #
                     reason="Requires compute capability >= 10")
 def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device, monkeypatch):
     if is_xpu():
-        pytest.skip("FIXME: Fail RuntimeError on XPU")
-
+        if not torch.xpu.get_device_capability()["has_subgroup_matrix_multiply_accumulate"]:
+            pytest.skip("The device does not support MMA")
+        elif (BLOCK_M, BLOCK_N, BLOCK_K) == (128, 256, 256) and \
+                triton.runtime.driver.active.utils.get_device_properties(
+                    triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 196608:
+            pytest.xfail("Not enough shared memory")
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)
     elif BLOCK_K == 256:
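This hunk replaces a blanket XPU skip with two concrete gates: a hardware check for subgroup MMA (DPAS) support and a shared-local-memory budget check for the largest block shape. The two queries can be read in isolation; below is a minimal, self-contained sketch under the assumptions that a PyTorch build with XPU support and an active Triton XPU driver are present. The property key and the 196608-byte (192 KiB) threshold come straight from the hunk; the helper names are hypothetical.

import torch
import triton


def xpu_supports_mma():
    # Hypothetical helper: the XPU device-capability dict reports subgroup
    # matrix-multiply-accumulate support as a boolean entry.
    return bool(torch.xpu.get_device_capability()["has_subgroup_matrix_multiply_accumulate"])


def xpu_max_shared_mem():
    # Hypothetical helper: shared-local-memory size in bytes, as reported by
    # the active Triton driver for the current device.
    device = triton.runtime.driver.active.get_current_device()
    return triton.runtime.driver.active.utils.get_device_properties(device)["max_shared_mem"]

Note the asymmetry in the hunk itself: a missing hardware feature is a skip, while the shared-memory-bound (128, 256, 256) shape is an xfail, presumably because the latter is a resource limit on an otherwise supported operation rather than an unsupported one.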
@@ -494,7 +498,6 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
                                         b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
                                         NUM_STAGES=NUM_STAGES, USE_2D_SCALE_LOAD=USE_2D_SCALE_LOAD)
     ttgir = out.asm["ttgir"]
-    ptx = out.asm["ptx"]

     def flatten_scale(scale):
         num_chunk_m, num_chunk_k, _, _, _ = scale.shape
@@ -519,7 +522,7 @@ def flatten_scale(scale):
     if USE_2D_SCALE_LOAD:
         # Due to an issue in the coalescing pass, tmem_copy can not be generated for the 5D load.
         # The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914
-        assert "tcgen05.cp" in ptx
+        assert is_xpu() or "tcgen05.cp" in out.asm["ptx"]
     if NUM_STAGES > 1:
         if BLOCK_M == BLOCK_K and BLOCK_N == BLOCK_K:
             load_pipelined = ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") == 2
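Together with the previous hunk, this moves the PTX lookup inside the assertion and short-circuits it on XPU: `out.asm` only carries a "ptx" entry when the CUDA backend compiled the kernel, so the unconditional `ptx = out.asm["ptx"]` would fail before the assert ever ran on XPU. A small sketch of the guarding pattern, with a hypothetical helper name:

def assert_tmem_copy_used(asm, on_xpu):
    # Hypothetical helper: tcgen05.cp (the Blackwell tensor-memory copy
    # instruction) can only appear in PTX, which non-CUDA backends never
    # emit, so the short-circuit makes the check a no-op off CUDA.
    assert on_xpu or "tcgen05.cp" in asm["ptx"]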
@@ -613,8 +616,8 @@ def lhs_in_tmem_kernel_mxfp(  #
 @pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10,
                     reason="Requires compute capability >= 10")
 def test_lhs_in_tmem_mxfp(device, monkeypatch):
-    if is_xpu():
-        pytest.skip("FIXME: failed to legalize operation 'tt.dot_scaled' on XPU")
+    if is_xpu() and not torch.xpu.get_device_capability()["has_subgroup_matrix_multiply_accumulate"]:
+        pytest.skip("The device does not support MMA")
     _knob_promote_lhs_to_tmem(monkeypatch)
     M, N, K = 128, 64, 32
     torch.manual_seed(42)
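Both tests key off an `is_xpu()` predicate defined elsewhere in the test harness. A plausible sketch of such a helper, assuming the suite resolves the backend through the active Triton driver (the real implementation may differ):

import triton


def is_xpu():
    # Assumption: the active driver's target exposes a `backend` string,
    # which is "xpu" for Intel's Triton backend.
    return triton.runtime.driver.active.get_current_target().backend == "xpu"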