@@ -217,8 +217,7 @@ def __str__(self):
 
 
 def is_layout_applicable(layout) -> bool:
-    common_layouts = [BlockedLayout, SharedLayout]
-    if layout in common_layouts:
+    if isinstance(layout, (BlockedLayout, SharedLayout)):
         return True
     elif isinstance(layout, SliceLayout):
         return is_layout_applicable(layout.parent)
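As an aside on this hunk: `layout in [BlockedLayout, SharedLayout]` only matches when the object itself compares equal to one of the listed classes, while `isinstance(layout, (...))` accepts any instance of those classes or their subclasses. A minimal standalone sketch of the difference (the classes here are placeholders, not the test suite's layout types):

class Base:
    pass

class Sub(Base):
    pass

obj = Sub()

# Membership compares the object against each list entry with ==,
# so an instance never matches a list of classes.
print(obj in [Base, Sub])            # False

# isinstance matches instances of the listed classes, including subclasses.
print(isinstance(obj, (Base, Sub)))  # True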
@@ -1447,6 +1446,7 @@ def kernel(X, Y, Z):
                           for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']
                           for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']]))
 def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
+    check_type_supported(dtype_x_str, device)
     if is_interpreter():
         if dtype_x_str == 'float16':
             pytest.skip("Only test atomic float16 ops on GPU")
@@ -1523,6 +1523,7 @@ def kernel(X):
                           for num_ctas in num_ctas_list
                           for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
     shape0, shape1 = shape
     # triton kernel
 
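Both atomic-RMW tests now call the module's existing `check_type_supported(dtype_x_str, device)` guard before doing any work. A minimal sketch of the kind of check such a guard performs, assuming a CUDA device and a capability cutoff that is illustrative rather than the module's exact rule:

import pytest
import torch

def check_type_supported_sketch(dtype_str: str, device: str) -> None:
    # Skip dtypes the target device cannot execute; the real helper's
    # per-backend rules differ, this only shows the shape of the guard.
    if device == "cuda":
        major, _minor = torch.cuda.get_device_capability()
        if dtype_str == "bfloat16" and major < 8:
            pytest.skip("bfloat16 requires compute capability >= 8.0")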
@@ -2874,7 +2875,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
 
 
 @pytest.mark.parametrize("M", [32, 64, 128, 256])
-@pytest.mark.parametrize("src_layout", layouts)
+@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
 def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
 
     ir = f"""
@@ -3807,7 +3808,7 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
 
     if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":
         if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties(
-                torch.cuda.current_device())["max_shared_mem"] < 131072:
+                triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 131072:
             pytest.skip(
                 "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU."
             )
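The replacement queries the device through the active Triton driver instead of `torch.cuda`, so the shared-memory check also works on non-CUDA backends. A small sketch wrapping the two calls used above into a hypothetical helper (`shared_mem_per_sm` is not a name from the test suite):

import triton

def shared_mem_per_sm() -> int:
    # Ask the active driver for the current device and its properties,
    # rather than going through torch.cuda, to stay backend-agnostic.
    dev = triton.runtime.driver.active.get_current_device()
    props = triton.runtime.driver.active.utils.get_device_properties(dev)
    return props["max_shared_mem"]

The test above skips its largest configuration when this value is below 131072 bytes (128 KB per SM).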
@@ -6550,7 +6551,7 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0:
     ([128, 64], [256, 64], 0),
     ([128, 64], [128, 128], 1),
 ])
-def test_gather(src_shape, indices_shape, axis):
+def test_gather(src_shape, indices_shape, axis, device):
 
     def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
         output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
@@ -6562,8 +6563,8 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
 
         return output
 
-    src = torch.randn(src_shape, device='cuda')
-    indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')
+    src = torch.randn(src_shape, device=device)
+    indices = torch.randint(0, src.shape[axis], indices_shape, device=device)
     ref = torch.gather(src, axis, indices)
     result = triton_gather(src, axis, indices)
     torch.testing.assert_close(result, ref, rtol=0, atol=0)
@@ -6580,7 +6581,8 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
         "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>"
     ),
 ])
-def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path):
+def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path,
+                             device):
     if is_hip():
         pytest.skip("warp-local gather has issues on HIP")
65866588
@@ -6623,8 +6625,8 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
 \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>"""
         return re.sub(pat, repl, ir)
 
-    src = torch.randn(src_shape, device='cuda')
-    indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')
+    src = torch.randn(src_shape, device=device)
+    indices = torch.randint(0, src.shape[axis], indices_shape, device=device)
     ref = torch.gather(src, axis, indices)
 
     output, compiled = prepare_kernel(src, axis, indices)
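The gather tests, like the atomic tests above, now take `device` as a parameter instead of hard-coding `'cuda'`. That argument is supplied by a pytest fixture; a minimal conftest.py sketch of how such a fixture is typically wired up (the actual fixture in this repository may differ):

# conftest.py
import pytest

def pytest_addoption(parser):
    parser.addoption("--device", action="store", default="cuda")

@pytest.fixture
def device(request):
    # Tests that declare a `device` argument receive this value,
    # letting the same suite target CUDA, HIP, or other backends.
    return request.config.getoption("--device")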