Skip to content

Commit 87970fa

Browse files
ThomasRaoux authored and loislo committed
[BACKEND] Fix condition to do a naive reshape with allow_reorder (triton-lang#6012)
Even though allow_reorder is set, a reshape is not always a no-op as we may swap a replicated value for a non-replicated one.
1 parent 1ba83ae commit 87970fa

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,20 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
147147
}
148148

149149
/// Decides whether viewing a tensor of `srcType` as `dstType` is "expensive",
/// i.e. cannot be treated as a free reinterpretation of the same data.
///
/// Both types are cast to RankedTensorType. The view is reported as expensive
/// when either:
///   * the linear layouts built from the two types disagree on the
///     free-variable mask of any entry, or
///   * the two types differ in their total number of elements per thread.
///
/// \param srcType type of the tensor before the view.
/// \param dstType type of the tensor after the view.
/// \returns true if the view is expensive, false otherwise.
bool isExpensiveView(Type srcType, Type dstType) {
  auto srcTensorTy = cast<RankedTensorType>(srcType);
  auto dstTensorTy = cast<RankedTensorType>(dstType);

  auto srcLayout =
      toLinearLayout(srcTensorTy.getShape(), srcTensorTy.getEncoding());
  auto dstLayout =
      toLinearLayout(dstTensorTy.getShape(), dstTensorTy.getEncoding());

  // Replicated values show up in the free-variable masks. If the old and the
  // new layout do not agree on those masks, the view could swap a replicated
  // value for a non-replicated one, so it is not a no-op.
  auto srcMasks = srcLayout.getFreeVariableMasks();
  auto dstMasks = dstLayout.getFreeVariableMasks();
  for (auto [srcEntry, dstEntry] : llvm::zip(srcMasks, dstMasks)) {
    // Entries are compared pairwise, so both sides must carry the same key.
    assert(srcEntry.first == dstEntry.first);
    if (srcEntry.second != dstEntry.second)
      return true;
  }

  // Masks agree: the view is expensive only if the per-thread element count
  // changes between the two types.
  return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
}
152166

test/TritonGPU/canonicalize.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blo
4040

4141
// -----
4242

43+
// Test that the convert doesn't get combined with the view if the resulting
// operation is an expensive view, which would require moving data across
// threads. Here the reshape carries allow_reorder, yet the convert_layout
// must still be preserved.
// The label names this test's function exactly so that it cannot anchor on
// another function whose name merely shares the same prefix.
// CHECK-LABEL: @test_canonicalize_convert_expensive_view2
// CHECK-SAME: (%[[ARG:.+]]: tensor<2xf32
// CHECK: %[[C:.+]] = ttg.convert_layout %[[ARG]]
// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder
// CHECK: tt.return %[[V]]
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
  tt.func @test_canonicalize_convert_expensive_view2(%arg0: tensor<2xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> tensor<2xf32, #blocked1> {
    %c = ttg.convert_layout %arg0 : tensor<2xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<2xf32, #blocked1>
    %r = tt.reshape %c allow_reorder : tensor<2xf32, #blocked1> -> tensor<2xf32, #blocked1>
    tt.return %r : tensor<2xf32, #blocked1>
  }
}
59+
60+
// -----
61+
4362
// test that the convert does get combined with the view when the resulting operation
4463
// is an efficient view.
4564
// CHECK-LABEL: @test_canonicalize_convert_view

0 commit comments

Comments
 (0)