Skip to content

Commit b59bb9a

Browse files
Fix test_gather_warp_shuffle (#3088)
Co-authored-by: Lu,Chengjun <[email protected]>
1 parent 36b6dd2 commit b59bb9a

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

python/test/unit/language/test_core.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6490,8 +6490,6 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
64906490
])
64916491
def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path,
64926492
device):
6493-
if is_xpu():
6494-
pytest.skip("warp-local gather has issues on XPU")
64956493
if is_hip():
64966494
pytest.skip("warp-local gather has issues on HIP")
64976495

@@ -6517,13 +6515,13 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
65176515

65186516
pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = "
65196517
pat += str(axis)
6520-
pat += r" : i32} : \(tensor\<"
6518+
pat += r" : i32[, efficient_layout]*} : \(tensor\<"
65216519
pat += src_spec
6522-
pat += r", (#[a-z]+[0-9]+)\>, tensor\<"
6520+
pat += r", (#[a-z]+[0-9]*)\>, tensor\<"
65236521
pat += indices_spec
6524-
pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<"
6522+
pat += r", (#[a-z]+[0-9]*)\>\) -> tensor\<"
65256523
pat += output_spec
6526-
pat += r", (#[a-z]+[0-9]+)\>"
6524+
pat += r", (#[a-z]+[0-9]*)\>"
65276525

65286526
repl = r"""
65296527
%src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout>
@@ -6546,7 +6544,9 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
65466544
temp_file.write_text(ir)
65476545

65486546
kernel = triton.compile(str(temp_file))
6549-
assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in kernel.asm["llir"])
6547+
assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute"
6548+
in kernel.asm["llir"]) or ("_Z17sub_group_shufflefj"
6549+
in kernel.asm["llir"])
65506550

65516551
kernel[(1, 1, 1)](src, indices, output)
65526552

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -896,11 +896,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
896896
void mlir::triton::intel::populateConvertLayoutOpToLLVMPatterns(
897897
LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
898898
RewritePatternSet &patterns, PatternBenefit benefit) {
899-
// We prefer using the linear layout conversion, so it gets a higher benefit.
900-
// Eventually the LL conversion will subsume all of the others and be the only
901-
// one left.
899+
// We prefer using the Intel specific linear layout conversion, so it gets a
900+
// higher benefit. Eventually the LL conversion will subsume all of the others
901+
// and be the only one left.
902902
patterns.add<gpu::ConvertLayoutOpUsingLinearLayoutsConversion>(
903-
typeConverter, targetInfo, benefit.getBenefit() + 1);
903+
typeConverter, targetInfo, benefit.getBenefit() + 2);
904904
patterns.add<gpu::ConvertLayoutOpConversion>(typeConverter, targetInfo,
905-
benefit);
905+
benefit.getBenefit() + 1);
906+
mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo,
907+
patterns, benefit);
906908
}

0 commit comments

Comments
 (0)