Skip to content

Commit 7b2beae

Browse files
authored
[BACKEND] Fix inline asm bug for multiple packed <32bit output (#5273)
Resolves #5272. Fixes the logic for walking the result struct returned by LLVM InlineAsm when there are multiple sub-32-bit results, and adds a lit test covering that case.
1 parent 6d3ed0b commit 7b2beae

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,12 @@ struct ElementwiseInlineAsmOpConversion
325325
// asmResults is a flat struct; pack its values into
326326
// [return_value][op.getPackedElement()].
327327
SmallVector<SmallVector<Value>> ret(op->getNumResults());
328+
int structIdx = 0;
328329
for (int i = 0; i < op->getNumResults(); i++) {
329-
int structIdx = 0;
330330
for (int j = 0; j < op.getPackedElement(); j++) {
331331
Value val;
332332
if (asmRetTypes.size() > 1) {
333-
val =
334-
extract_val(asmResults, i * op.getPackedElement() + structIdx++);
333+
val = extract_val(asmResults, structIdx++);
335334
} else {
336335
val = asmResults;
337336
}

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,3 +1897,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
18971897
tt.return
18981898
}
18991899
}
1900+
1901+
// -----
1902+
1903+
// CHECK: inline_asm_pack
1904+
#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
1905+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
1906+
// check specifically for the case where asm has two results, pack > 1, and the result bitwidth is < 32
1907+
tt.func public @inline_asm_pack(%80: tensor<64x64xi8, #blocked>) attributes {noinline = false} {
1908+
// CHECK: llvm.inline_asm asm_dialect {{.*}} (vector<4xi8>) -> !llvm.struct<(vector<2xbf16>, vector<2xbf16>, vector<2xbf16>, vector<2xbf16>)>
1909+
%83:2 = tt.elementwise_inline_asm "" {constraints = "=r,=r,=r,=r,r", packed_element = 4 : i32, pure = true} %80 : tensor<64x64xi8, #blocked> -> tensor<64x64xbf16, #blocked>, tensor<64x64xbf16, #blocked>
1910+
tt.return
1911+
}
1912+
}

0 commit comments

Comments
 (0)