
Commit 919c9d2

[TritonGPU] Fix layout error after hoisting convert over ext/broadcast (#7058)
Fixes pytorch/pytorch#154933.

In that issue, `hoistConvertOnTopOfExtOrBroadcast` produces an invalid graph and errors out: a `tt.expand_dims` expects an input with a blocked layout, but its actual input, a `tt.make_range`, has a linear layout.

`hoistConvertOnTopOfExtOrBroadcast` works like this:

1. Find a backward slice from the convert op, stopping at any extension/broadcast ops.
2. From each ext/broadcast boundary op found in step 1, find a backward slice from _that_ op.

In step 1 and in each iteration of step 2, `getConvertBackwardSlice` returns `failure()` if the graph traversal identifies two conflicting layout assignments for the same value. The bug is that two separate `getConvertBackwardSlice` iterations from step 2 may identify conflicting layout assignments for the same value, and this case was previously not checked.
1 parent b6dabff · commit 919c9d2
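The fix adds a compatibility check: if a step-2 slice proposes a layout for a value that conflicts with a layout already recorded, the hoist bails out. Below is a minimal standalone sketch of that check, not Triton code: `Value` and `Encoding` are plain-string stand-ins for `mlir::Value` and the layout encoding attribute, and `isCompatible` mirrors the loop added to RemoveLayoutConversions.cpp in the diff further down.

// Standalone sketch of the conflict check (placeholder types, not MLIR).
#include <cstdio>
#include <string>
#include <unordered_map>

using Value = std::string;    // stand-in for mlir::Value
using Encoding = std::string; // stand-in for a layout encoding attribute
using LayoutMap = std::unordered_map<Value, Encoding>;

// Returns false if `tempLayout` assigns some value a layout that differs
// from the one already recorded in `layout`; in that case the hoist must bail out.
bool isCompatible(const LayoutMap &layout, const LayoutMap &tempLayout) {
  for (const auto &[val, enc] : tempLayout) {
    auto it = layout.find(val);
    if (it != layout.end() && it->second != enc)
      return false;
  }
  return true;
}

int main() {
  // Layout gathered by the slice from step 1.
  LayoutMap layout = {{"%cast0", "#blockedX_slice"}};
  // Layout proposed by a later slice rooted at an ext/broadcast op (step 2):
  // it wants %cast0 in a different layout, so the hoist must not proceed.
  LayoutMap tempLayout = {{"%cast0", "#blockedY_slice"}};
  std::printf("compatible: %s\n",
              isCompatible(layout, tempLayout) ? "yes" : "no");
  return 0;
}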

2 files changed: +54 −1 lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 13 additions & 0 deletions
@@ -1399,6 +1399,19 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
       return;
     LogicalResult result = getRematerializableSlice(
         op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
+
+    // If a value is already assigned to a _different_ layout,
+    // we cannot propagate past this op (as it would conflict with
+    // an already-assigned layout).
+    for (auto [val, enc] : tempLayout) {
+      auto preexistingLayout = layout.find(val);
+      if (preexistingLayout != layout.end() &&
+          preexistingLayout->second != enc) {
+        result = failure();
+        break;
+      }
+    }
+
     // If we can rematerialize the rest of the ext slice we can ignore this
     // ext as it won't need a convert.
     if (result.succeeded()) {

test/TritonGPU/combine.mlir

Lines changed: 41 additions & 1 deletion
@@ -2059,6 +2059,44 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+// Minimal repro for https://github.com/pytorch/pytorch/issues/154933
+//
+// Check that if, during hoisting conversions over ext and broadcast ops,
+// we see multiple different layouts assigned to the same value, then we
+// skip propagation of that layout.
+
+// CHECK-LABEL: @hoist_on_ext_broadcast_mismatch
+#blockedX = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blockedY = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @hoist_on_ext_broadcast_mismatch(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) -> tensor<4x1xi64, #blockedY> attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %cast0 = arith.extsi %0 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>> to tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %2 = tt.expand_dims %cast0 {axis = 1 : i32} : tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>> -> tensor<4x1xi64, #blockedX>
+    %3 = tt.addptr %1, %cast0 : tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>, tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %4 = tt.load %3 : tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %5 = tt.reshape %4 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>> -> tensor<4x1xi32, #blockedX>
+    // CHECK: arith.extsi
+    %6 = arith.extsi %5 : tensor<4x1xi32, #blockedX> to tensor<4x1xi64, #blockedX>
+    %7 = arith.addi %2, %6 : tensor<4x1xi64, #blockedX>
+    // for loop prevents fully hoisting the conversion.
+    %8 = scf.for %arg2 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<4x1xi32, #blockedX>) : i32 {
+      scf.yield %5 : tensor<4x1xi32, #blockedX>
+    }
+    // CHECK: ttg.convert_layout
+    %9 = arith.extsi %8 : tensor<4x1xi32, #blockedX> to tensor<4x1xi64, #blockedX>
+    %10 = arith.addi %7, %9 : tensor<4x1xi64, #blockedX>
+    %11 = ttg.convert_layout %10 : tensor<4x1xi64, #blockedX> -> tensor<4x1xi64, #blockedY>
+    tt.return %11 : tensor<4x1xi64, #blockedY>
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}>
@@ -2525,7 +2563,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 
 // CHECK-LABEL: double_remat
 // CHECK: %[[res:.*]] = ttg.convert_layout
-// CHECK-NEXT: tt.return %[[res]]
+// CHECK: tt.broadcast %[[res]]
+// CHECK-NOT: ttg.convert_layout
+// CHECK: tt.return
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 2], order = [2, 1, 0]}>
