@@ -137,9 +137,10 @@ def check_type_supported(dtype, device):
         pytest.xfail("float64 not supported on current xpu hardware")
 
 
-def check_threads_supported(num_warps, threads_per_warp):
-    device = triton.runtime.driver.active.get_current_device()
-    props = triton.runtime.driver.active.utils.get_device_properties(device)
+def check_threads_supported(num_warps, threads_per_warp, device):
+    if device != "xpu":
+        return
+    props = triton.runtime.driver.active.utils.get_device_properties(triton.runtime.driver.active.get_current_device())
     if threads_per_warp not in props['sub_group_sizes']:
         pytest.xfail('unsupported warp size')
     if threads_per_warp * num_warps > props['max_work_group_size']:
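On non-XPU backends the helper now returns immediately, so the SYCL device-property query only runs where 'sub_group_sizes' and 'max_work_group_size' actually exist. A minimal standalone sketch of that gating, with a hard-coded props dict standing in for the real get_device_properties() call (the dict values and the second xfail message are illustrative assumptions; the diff truncates before the real message):

import pytest

# Illustrative stand-in for triton.runtime.driver.active.utils.get_device_properties();
# the real values come from the SYCL runtime at test time.
FAKE_XPU_PROPS = {'sub_group_sizes': [8, 16, 32], 'max_work_group_size': 1024}


def check_threads_supported(num_warps, threads_per_warp, device):
    if device != "xpu":
        return  # fixed warp sizes elsewhere; nothing to validate
    props = FAKE_XPU_PROPS
    if threads_per_warp not in props['sub_group_sizes']:
        pytest.xfail('unsupported warp size')
    if threads_per_warp * num_warps > props['max_work_group_size']:
        pytest.xfail('unsupported work-group size')  # assumed message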
@@ -2366,7 +2367,7 @@ def get_reduced_dtype(dtype_str, op):
                          [(64, 16), (4, THREADS_PER_WARP)] if is_xpu() else [(4, THREADS_PER_WARP)])
 def test_reduce1d(op, dtype_str, shape, num_ctas, num_warps, threads_per_warp, device):
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested
-    check_threads_supported(num_warps, threads_per_warp)
+    check_threads_supported(num_warps, threads_per_warp, device)
 
     # triton kernel
     @triton.jit
@@ -2475,7 +2476,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
                          [(64, 16), (4, THREADS_PER_WARP)] if is_xpu() else [(4, THREADS_PER_WARP)])
 def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, num_warps, threads_per_warp, device):
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested
-    check_threads_supported(num_warps, threads_per_warp)
+    check_threads_supported(num_warps, threads_per_warp, device)
 
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr,
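Both reduction tests simply forward pytest's device fixture, so a non-XPU run parametrized with an XPU-only warp shape skips the driver-utils query instead of crashing in it. A hedged usage sketch, reusing the helper sketched above (test name is hypothetical, device is a plain parameter here rather than the suite's conftest fixture, and THREADS_PER_WARP is an assumed default):

import pytest

THREADS_PER_WARP = 32  # assumed default; the real constant is backend-dependent


@pytest.mark.parametrize("device", ["cuda", "xpu"])
@pytest.mark.parametrize("num_warps, threads_per_warp", [(64, 16), (4, THREADS_PER_WARP)])
def test_reduce_sketch(num_warps, threads_per_warp, device):
    check_threads_supported(num_warps, threads_per_warp, device)  # may xfail only on xpu
    # ... launch the reduction kernel and compare against a numpy reference ...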