Commit dacd155

[AMD][MFMA] Use linear layout for batch-non-slowest 3D conversion (triton-lang#6371)
Please see the equivalent change for WMMA: triton-lang#6350

Signed-off-by: Ilya Veselov <[email protected]>
1 parent 9dc65e1 commit dacd155

1 file changed: +5 -2 lines changed


third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp

Lines changed: 5 additions & 2 deletions
@@ -230,8 +230,11 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   if (!sharedLayout)
     return Value();
   auto order = sharedLayout.getOrder();
-  assert((rank == 2 || order[2] == 0) &&
-         "expect batch to be the slowest dimension");
+
+  // Rely on the linear layout conversion logic in this case, since only slowest
+  // dimension for batch is supported here
+  if (rank != 2 && order.back() != 0)
+    return Value();
 
   auto elemTy = aTensorTy.getElementType();
   auto kWidth = encoding.getKWidth();
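
The guard in this hunk can be illustrated in isolation. The following is a minimal standalone sketch, not code from Triton: `legacyPathSupports` and `convertLegacy` are hypothetical names, `std::optional<int>` stands in for the MLIR `Value` type, and the rank is assumed to equal `order.size()`. It also assumes, consistent with the removed assert message, that `order` lists dimensions from fastest-varying to slowest-varying and that dimension 0 is the batch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper (not part of Triton) mirroring the new guard.
// Assumption: `order` lists dimension indices from fastest-varying to
// slowest-varying, so order.back() is the slowest dimension, and
// dimension 0 is the batch dimension.
bool legacyPathSupports(const std::vector<unsigned> &order) {
  assert(!order.empty() && "order must not be empty");
  const std::size_t rank = order.size(); // assumed equal to the tensor rank
  return rank == 2 || order.back() == 0;
}

// Hypothetical stand-in for convertLayout: an empty optional plays the role
// of the empty Value() returned in the diff.
std::optional<int> convertLegacy(const std::vector<unsigned> &order) {
  if (!legacyPathSupports(order))
    return std::nullopt; // caller is expected to use another lowering
  return 1;              // placeholder for a successfully converted value
}
```

Under these assumptions, a 3D order such as `{2, 0, 1}` (batch in the middle) yields an empty result instead of tripping an assertion, while `{2, 1, 0}` and any 2D order still take the existing path. For rank 3, `order.back() != 0` is the negation of the old `order[2] == 0` check, so the set of accepted layouts is unchanged; only the handling of rejected ones differs.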
