     is_hip_cdna4,
     is_hopper_or_newer,
     is_hopper,
+    is_xpu,
 )
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
@@ -55,8 +56,8 @@ def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
     ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
 ])
 @pytest.mark.parametrize("XBLOCK", [128, 256, 512, 1024, 2048])
-def test_copy_kernel(layout, XBLOCK):
-    inp = torch.randn(XBLOCK * 4 - 7, device="cuda")
+def test_copy_kernel(layout, XBLOCK, device):
+    inp = torch.randn(XBLOCK * 4 - 7, device=device)
     out = torch.empty_like(inp)

     copy_kernel[(4, )](out, inp, inp.numel(), XBLOCK, layout, num_warps=layout.warps_per_cta[0])
@@ -73,7 +74,7 @@ def tma_kernel(desc):
     alloc._keep_alive()


-@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper_or_newer(), reason="Requires Hopper")
 def test_tma():
     out = torch.ones((16, 16), dtype=torch.float16, device="cuda")
     layout = ttgl.NVMMASharedLayout(
@@ -112,9 +113,9 @@ def async_copy_mbarrier_kernel(out, inp, xnumel, XBLOCK: ttgl.constexpr, YBLOCK:
     ttgl.store(out + xindex * YBLOCK + yindex, val)


-@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere")
-def test_async_copy_mbarrier():
-    tensor_opts = dict(dtype=torch.float, device="cuda")
+@pytest.mark.xfail(not is_ampere_or_newer(), reason="Requires Ampere")
+def test_async_copy_mbarrier(device):
+    tensor_opts = dict(dtype=torch.float, device=device)
     out = torch.empty((32, 32), **tensor_opts)
     inp = torch.randn((20, 32), **tensor_opts)
     async_copy_mbarrier_kernel[(1, )](out, inp, inp.shape[0], XBLOCK=32, YBLOCK=32)
@@ -153,7 +154,7 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg
     ttgl.store(out + out_offs_m * N + out_offs_n, acc)


-@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper(), reason="Requires Hopper")
 @pytest.mark.parametrize("ASYNC", [True, False])
 def test_warpgroup_mma(ASYNC):
     torch.manual_seed(0)
@@ -168,7 +169,7 @@ def test_warpgroup_mma(ASYNC):
     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)


-@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4")
+@pytest.mark.xfail(not is_hip_cdna4(), reason="Requires CDNA4")
 @pytest.mark.parametrize("use_buffer_load", [True, False])
 def test_amd_direct_load_to_shared(use_buffer_load):

@@ -204,7 +205,7 @@ def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr):
     assert 'vmcnt(0)' in pgm.asm['amdgcn']


-@pytest.mark.skipif(not (is_hip_gfx11() or is_hip_gfx12()), reason="Requires RDNA3 or RDNA4")
+@pytest.mark.xfail(not (is_hip_gfx11() or is_hip_gfx12()), reason="Requires RDNA3 or RDNA4")
 @pytest.mark.parametrize("M, N, K", [(64, 64, 64)])
 @pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
 def test_amd_wmma(M, N, K, in_dtype):
@@ -270,6 +271,8 @@ def kernel(a_ptr, b_ptr, c_ptr, #
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("cdna_version", [3, 4])
 def test_amd_mfma(M, N, K, in_dtype, num_warps, cdna_version):
+    if is_xpu():
+        pytest.xfail("XPU does not support AMD MFMA")

     @gluon.jit
     def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, #
@@ -328,7 +331,7 @@ def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, #
     torch.testing.assert_close(ref, triton_output)


-@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4")
+@pytest.mark.xfail(not is_hip_cdna4(), reason="Requires CDNA4")
 @pytest.mark.parametrize("M, N, K, rhs_scale, mxfp_type, normal_type", [(32, 32, 128, rhs_scale, mxfp_type, normal_type)
                                                                         for rhs_scale in [True, False]
                                                                         for mxfp_type in ["e2m1"]
@@ -470,7 +473,7 @@ def make_finite(x, dtype):
     torch.testing.assert_close(z, z_ref, rtol=1e-5, atol=1e-5)


-def test_math_fast_expf():
+def test_math_fast_expf(device):

     @gluon.jit
     def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.constexpr):
@@ -484,13 +487,13 @@ def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.co
     num_warps = 4

     torch.manual_seed(0)
-    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device=device, dtype=torch.float32)
     y = torch.empty_like(x)
     fast_expf_kernel[(1, )](x, y, THREADS_PER_WARP, num_warps)
     torch.testing.assert_close(y, torch.exp(x), atol=1e-5, rtol=1e-4)


-def test_math_fast_dividef():
+def test_math_fast_dividef(device):

     @gluon.jit
     def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.constexpr):
@@ -505,7 +508,7 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
     num_warps = 4

     torch.manual_seed(0)
-    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device=device, dtype=torch.float32)
     y = torch.randn_like(x)
     z = torch.empty_like(x)
     y[y == 0] = 1.0
@@ -514,7 +517,7 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp


 @pytest.mark.xfail(reason="copy to tmem with scale layout is currently broken in Gluon.")
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
 def test_tmem_copy_2d():
     device = "cuda"

@@ -563,7 +566,7 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
         assert torch.equal(x[m * 32:(m + 1) * 32], z_tri[32 * i:32 * (i + 1), col_offset:(col_offset + 4)])


-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
 def test_tmem_subslice_block_m_64():

     @gluon.jit
@@ -643,7 +646,7 @@ def kernel(s_ptr, out_ptr):
     torch.testing.assert_close(out_ref, out_tri, atol=0, rtol=0)


-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
 def test_block_m_64_mma():

     @gluon.jit
@@ -734,7 +737,7 @@ def kernel(a_ptr, b_ptr, c_ptr, d_ptr):
     torch.testing.assert_close(d_ref, d_tri, rtol=0.08, atol=0)


-def test_slice_reinterpret():
+def test_slice_reinterpret(device):
     BLOCK = ttgl.constexpr(2048)
     SPLIT_BLOCK = ttgl.constexpr(BLOCK // 2)
     XBLOCK = ttgl.constexpr(32)
@@ -759,13 +762,13 @@ def kernel(in_ptr, out_ptr):
         value = smem_slice1.load(blocked)
         ttgl.store(ttgl.set_auto_layout(out_ptr + offs, blocked), value)

-    input = torch.randint(0, 100, (XBLOCK, YBLOCK), dtype=torch.int32, device="cuda")
+    input = torch.randint(0, 100, (XBLOCK, YBLOCK), dtype=torch.int32, device=device)
     output = torch.empty_like(input)
     kernel[(1, )](input, output)
     torch.testing.assert_close(input, output, atol=0, rtol=0)


-@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper_or_newer(), reason="Requires Hopper")
 def test_tma_slice():
     XBLOCK = YBLOCK = ttgl.constexpr(128)

@@ -802,7 +805,7 @@ def kernel(in_desc, out_desc):
 @pytest.mark.parametrize("swizzle", [32, 64, 128])
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("M, N, BLOCK_N", [(128, 128, 128), (256, 128, 64), (128, 128, 16)])
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
 def test_tmem_copy_no_scales(M, N, BLOCK_N, num_warps, swizzle):

     @gluon.jit
@@ -856,7 +859,7 @@ def early_return_kernel(x):
     return x


-def test_2d_tensor_early_return():
+def test_2d_tensor_early_return(device):
     warp_size = ttgl.constexpr(THREADS_PER_WARP)

     @gluon.jit
@@ -871,12 +874,12 @@ def kernel(N, out):
             x += early_return_kernel(x)
         ttgl.store(out, x.sum(0).sum(0))

-    out = torch.empty(1, dtype=torch.int32, device="cuda")
+    out = torch.empty(1, dtype=torch.int32, device=device)
     compiled_kernel = kernel.warmup(N=100, out=out, grid=(1, ))
     assert compiled_kernel.asm["llir"].count("define") == 1


-@pytest.mark.skipif(not is_hip_cdna3() and not is_hip_cdna4(), reason="Requires CDNA3 or CDNA4")
+@pytest.mark.xfail(not is_hip_cdna3() and not is_hip_cdna4(), reason="Requires CDNA3 or CDNA4")
 def test_inline_with_amdgpu_dialect():

     @gluon.jit
@@ -906,7 +909,8 @@ def kernel(x, y):
     {"offsets": [[0, 1], [0, 2], [0, 8], [0, 4], [0, 16], [0, 32], [2, 0], [1, 0], [4, 0], [8, 0], [16, 0], [32, 0]]}])
 @pytest.mark.parametrize("slice_m_offset, slice_n_offset, slice_m, slice_n", [(48, 16, 16, 16), (32, 48, 32, 16),
                                                                               (48, 32, 16, 32)])
-def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n):
+def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n,
+                                       device):
     m = 64
     n = 64
     num_warps = 1
@@ -945,8 +949,8 @@ def kernel(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, SLICE_M_OFFSET
         out_offs = offs_m_store[:, None] * SLICE_N + offs_n_store[None, :]
         ttgl.store(out_ptr + out_offs, out_data)

-    input = torch.arange(m * n, device="cuda").reshape(m, n).to(torch.int32)
-    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device="cuda")
+    input = torch.arange(m * n, device=device).reshape(m, n).to(torch.int32)
+    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device=device)
     ref_output = input[slice_m_offset:slice_m_offset + slice_m, slice_n_offset:slice_n_offset + slice_n]

     kernel[(1, )](input, output, m, n, slice_m_offset, slice_n_offset, slice_m, slice_n, num_warps=num_warps)