Skip to content

Commit c1ed673

Browse files
authored
[BACKEND] Fix propagation through join op (#5987)
We were passing the wrong shape to `inferJoinOpEncoding`: the 3-D result shape instead of the 2-D source-operand shape (`getLhs()` for `tt.join`, `getOutLHS()` for `tt.split`).
1 parent bb78fae commit c1ed673

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ static Attribute inferDstEncoding(triton::ExpandDimsOp op, Attribute encoding) {
317317

318318
static Attribute inferDstEncoding(JoinOp op, Attribute srcEnc) {
319319
Attribute dstEnc;
320-
auto shape = op.getResult().getType().getShape();
320+
auto shape = op.getLhs().getType().getShape();
321321
if (srcEnc.getDialect()
322322
.getRegisteredInterface<DialectInferLayoutInterface>()
323323
->inferJoinOpEncoding(srcEnc, dstEnc, shape,
@@ -371,7 +371,7 @@ static Attribute inferSrcEncoding(JoinOp op, Attribute dstEnc) {
371371
static Attribute inferSrcEncoding(SplitOp op, Attribute dstEnc) {
372372
// Join is the inverse of split.
373373
Attribute srcEnc;
374-
auto shape = op.getSrc().getType().getShape();
374+
auto shape = op.getOutLHS().getType().getShape();
375375
if (dstEnc.getDialect()
376376
.getRegisteredInterface<DialectInferLayoutInterface>()
377377
->inferJoinOpEncoding(dstEnc, srcEnc, shape, /*loc=*/std::nullopt)

test/TritonGPU/combine.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3663,3 +3663,19 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
36633663
tt.return %130 : tensor<16x128xf32, #mma>
36643664
}
36653665
}
3666+
3667+
// -----
3668+
3669+
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [2, 16, 1], warpsPerCTA = [1, 1, 1], order = [2, 1, 0]}>
3670+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}>
3671+
#linear = #ttg.linear<{register = [], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0]], warp = [], block = []}>
3672+
module attributes {"ttg.num-warps" = 1 : i32, ttg.target = "cuda:80"} {
3673+
// CHECK-LABEL: join_forward
3674+
tt.func @join_forward(%arg0: tensor<2x16xf32, #linear>) -> tensor<2x16x2xf32, #blocked> {
3675+
// CHECK-LABEL: tt.join
3676+
// CHECK-LABEL: ttg.convert_layout
3677+
%0 = ttg.convert_layout %arg0 : tensor<2x16xf32, #linear> -> tensor<2x16xf32, #blocked1>
3678+
%1 = tt.join %0, %0 : tensor<2x16xf32, #blocked1> -> tensor<2x16x2xf32, #blocked>
3679+
tt.return %1 : tensor<2x16x2xf32, #blocked>
3680+
}
3681+
}

0 commit comments

Comments
 (0)