
Commit e1e5e6c

[BACKEND] Improve backward convert_layout propagation for join op (#7065)
When propagating backward through a join op, we may end up inferring the same source layout the operand already has, even though we want a different destination layout. In this case we can stop the propagation and just change the destination layout. This could probably be generalized to more ops.
1 parent 71c98c8 commit e1e5e6c

2 files changed: 28 additions, 5 deletions

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 7 additions & 0 deletions

@@ -941,6 +941,13 @@ LogicalResult getConvertBackwardSlice(
         auto srcEncoding = inferSrcEncoding(definingOp, encoding);
         if (!srcEncoding)
           return failure();
+        // If the infered layout matches the original one we don't need to keep
+        // propagating.
+        if (auto operandType =
+                dyn_cast<RankedTensorType>(operand.get().getType())) {
+          if (srcEncoding == operandType.getEncoding())
+            continue;
+        }
         enqueue(operand, srcEncoding);
       }
       continue;
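
For intuition, here is a minimal, self-contained C++ sketch of the early-stop idea that the seven added lines implement. It is not the Triton code: Layout, Use, inferSrcLayout, and enqueue below are hypothetical stand-ins for the real encoding attributes, OpOperand, inferSrcEncoding, and the worklist used inside getConvertBackwardSlice.

// Hypothetical model of the backward-propagation early stop; none of these
// types or helpers are the real MLIR/TritonGPU APIs.
#include <deque>
#include <iostream>
#include <optional>
#include <string>
#include <utility>

using Layout = std::string; // stand-in for a TritonGPU encoding attribute

struct Use {
  Layout currentLayout; // layout the operand already carries in the IR
};

// Stand-in for inferSrcEncoding: which layout must the operand have so that
// the op's result ends up with dstLayout?
std::optional<Layout> inferSrcLayout(const std::string &op,
                                     const Layout &dstLayout) {
  if (op == "tt.join")
    return dstLayout + "_2d"; // pretend join simply drops the trailing dim
  return std::nullopt;        // unknown op: cannot propagate
}

int main() {
  std::deque<std::pair<Use, Layout>> worklist;
  auto enqueue = [&](const Use &use, const Layout &wanted) {
    worklist.push_back({use, wanted});
  };

  // The operand feeding tt.join already has the layout "blocked_2d".
  Use joinOperand{"blocked_2d"};

  // We want the join result in layout "blocked"; propagate backward.
  if (auto src = inferSrcLayout("tt.join", "blocked")) {
    // The fix: if the inferred source layout matches what the operand already
    // has, stop propagating instead of enqueueing more work.
    if (*src == joinOperand.currentLayout)
      std::cout << "layouts match, stop propagation\n";
    else
      enqueue(joinOperand, *src);
  }
  std::cout << "worklist size: " << worklist.size() << "\n";
}

In this toy run the inferred layout equals the operand's existing layout, so nothing is enqueued; that is the case the new check catches, letting the pass simply change the destination layout of the convert instead of continuing the backward slice.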

test/TritonGPU/combine.mlir

Lines changed: 21 additions & 5 deletions

@@ -3789,20 +3789,36 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [2, 16, 1], warpsPerCTA = [1, 1, 1], order = [2, 1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}>
-#linear = #ttg.linear<{register = [], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0]], warp = [], block = []}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"ttg.num-warps" = 1 : i32, ttg.target = "cuda:80"} {
   // CHECK-LABEL: join_forward
-  tt.func @join_forward(%arg0: tensor<2x16xf32, #linear>) -> tensor<2x16x2xf32, #blocked> {
-    // CHECK-LABEL: tt.join
-    // CHECK-LABEL: ttg.convert_layout
-    %0 = ttg.convert_layout %arg0 : tensor<2x16xf32, #linear> -> tensor<2x16xf32, #blocked1>
+  tt.func @join_forward(%arg0: tensor<2x16xf32, #blocked2>) -> tensor<2x16x2xf32, #blocked> {
+    // CHECK: tt.join
+    // CHECK: ttg.convert_layout
+    // CHECK: tt.return
+    %0 = ttg.convert_layout %arg0 : tensor<2x16xf32, #blocked2> -> tensor<2x16xf32, #blocked1>
     %1 = tt.join %0, %0 : tensor<2x16xf32, #blocked1> -> tensor<2x16x2xf32, #blocked>
     tt.return %1 : tensor<2x16x2xf32, #blocked>
   }
 }
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 32, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 32, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
+  // CHECK-LABEL: join_backward
+  tt.func @join_backward(%arg0: tensor<128x32xf16, #blocked>, %arg1: tensor<128x32xf16, #blocked>) -> tensor<128x32x2xf16, #blocked1> {
+    // CHECK: %[[JOIN:.*]] = tt.join
+    // CHECK: tt.return %[[JOIN]]
+    %0 = tt.join %arg0, %arg1 : tensor<128x32xf16, #blocked> -> tensor<128x32x2xf16, #blocked2>
+    %1 = ttg.convert_layout %0 : tensor<128x32x2xf16, #blocked2> -> tensor<128x32x2xf16, #blocked1>
+    tt.return %1 : tensor<128x32x2xf16, #blocked1>
+  }
+}
+// -----
+
 #linear = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 0], [32, 0]], block = []}>
 #linear1 = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[32, 0], [0, 0]], block = []}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
