@@ -6490,8 +6490,6 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
 ])
 def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path,
                              device):
-    if is_xpu():
-        pytest.skip("warp-local gather has issues on XPU")
     if is_hip():
         pytest.skip("warp-local gather has issues on HIP")
 
@@ -6517,13 +6515,13 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
 
     pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = "
     pat += str(axis)
-    pat += r" : i32} : \(tensor\<"
+    pat += r" : i32[, efficient_layout]*} : \(tensor\<"
     pat += src_spec
-    pat += r", (#[a-z]+[0-9]+)\>, tensor\<"
+    pat += r", (#[a-z]+[0-9]*)\>, tensor\<"
     pat += indices_spec
-    pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<"
+    pat += r", (#[a-z]+[0-9]*)\>\) -> tensor\<"
     pat += output_spec
-    pat += r", (#[a-z]+[0-9]+)\>"
+    pat += r", (#[a-z]+[0-9]*)\>"
 
     repl = r"""
     %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout>
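 
For reference, a minimal standalone sketch (not part of this commit) of why the loosened pattern still matches both spellings of the op: "[, efficient_layout]*" is a character class repeated zero or more times, so it accepts the old attribute dict "{axis = 0 : i32}" as well as the new "{axis = 0 : i32, efficient_layout}", and "[0-9]*" tolerates layout names without a numeric suffix. The 32x4 tensor specs are hypothetical stand-ins for src_spec/indices_spec/output_spec:

import re

# Pattern assembled the same way as in the test, for axis = 0.
pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = 0"
pat += r" : i32[, efficient_layout]*} : \(tensor\<32x4xf32"
pat += r", (#[a-z]+[0-9]*)\>, tensor\<32x4xi32"
pat += r", (#[a-z]+[0-9]*)\>\) -> tensor\<32x4xf32"
pat += r", (#[a-z]+[0-9]*)\>"

without_attr = ("%5 = tt.gather %3[%4] {axis = 0 : i32} : "
                "(tensor<32x4xf32, #blocked>, tensor<32x4xi32, #blocked>) "
                "-> tensor<32x4xf32, #blocked>")
with_attr = without_attr.replace("i32}", "i32, efficient_layout}")

assert re.search(pat, without_attr)  # old printed form still matches
assert re.search(pat, with_attr)     # new efficient_layout form matches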
@@ -6546,7 +6544,9 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
     temp_file.write_text(ir)
 
     kernel = triton.compile(str(temp_file))
-    assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in kernel.asm["llir"])
+    assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute"
+                                                            in kernel.asm["llir"]) or ("_Z17sub_group_shufflefj"
+                                                            in kernel.asm["llir"])
 
     kernel[(1, 1, 1)](src, indices, output)
 
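 
The added third disjunct covers the XPU backend: _Z17sub_group_shufflefj is the Itanium-mangled form of sub_group_shuffle(float, unsigned int), the SYCL sub-group shuffle builtin that a warp-local gather lowers to there. A hedged sketch of the same check factored out per backend (the names below are illustrative, not from the commit):

# Markers indicating the gather was lowered to a warp/sub-group shuffle
# rather than going through shared/local memory.
SHUFFLE_MARKERS = (
    "nvvm.shfl.sync.idx",        # NVIDIA warp-shuffle intrinsic
    "llvm.amdgcn.ds.bpermute",   # AMD lane-permute intrinsic
    "_Z17sub_group_shufflefj",   # mangled SYCL sub_group_shuffle(float, uint)
)

def uses_warp_shuffle(llir: str) -> bool:
    # Equivalent to the assert in the test, one marker per backend.
    return any(marker in llir for marker in SHUFFLE_MARKERS)

# Usage in the test would be: assert uses_warp_shuffle(kernel.asm["llir"])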