Skip to content

Commit cb24881

Browse files
authored
[RemoveLayoutConversions]: Fix hoistConvertOnTopOfExtOrBroadcast (#4692)
Fix hoisting of layout convert operation. Fixes issue #4691. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 86ba555 commit cb24881

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

test/TritonIntelGPU/combine.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,6 +2042,44 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-

// -----

// Minimal repro for https://github.com/pytorch/pytorch/issues/154933
//
// Check that if, during hoisting conversions over ext and broadcast ops,
// we see multiple different layouts assigned to the same value, then we
// skip propagation of that layout.

// CHECK-LABEL: @hoist_on_ext_broadcast_mismatch
#blockedX = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blockedY = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu"} {
  tt.func public @hoist_on_ext_broadcast_mismatch(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) -> tensor<4x1xi64, #blockedY> {
    %c1_i32 = arith.constant 1 : i32
    %c4_i32 = arith.constant 4 : i32
    %c0_i32 = arith.constant 0 : i32
    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>>
    %cast0 = arith.extsi %0 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>> to tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>>
    %1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>
    %2 = tt.expand_dims %cast0 {axis = 1 : i32} : tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>> -> tensor<4x1xi64, #blockedX>
    %3 = tt.addptr %1, %cast0 : tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>, tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>>
    %4 = tt.load %3 : tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>
    %5 = tt.reshape %4 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>> -> tensor<4x1xi32, #blockedX>
    // CHECK: arith.extsi
    %6 = arith.extsi %5 : tensor<4x1xi32, #blockedX> to tensor<4x1xi64, #blockedX>
    %7 = arith.addi %2, %6 : tensor<4x1xi64, #blockedX>
    // for loop prevents fully hoisting the conversion.
    %8 = scf.for %arg2 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<4x1xi32, #blockedX>) : i32 {
      scf.yield %5 : tensor<4x1xi32, #blockedX>
    }
    // CHECK: ttg.convert_layout
    %9 = arith.extsi %8 : tensor<4x1xi32, #blockedX> to tensor<4x1xi64, #blockedX>
    %10 = arith.addi %7, %9 : tensor<4x1xi64, #blockedX>
    %11 = ttg.convert_layout %10 : tensor<4x1xi64, #blockedX> -> tensor<4x1xi64, #blockedY>
    tt.return %11 : tensor<4x1xi64, #blockedY>
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}>

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,19 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
    return;
  LogicalResult result = getRematerializableSlice(
      op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);

  // If a value is already assigned to a _different_ layout,
  // we cannot propagate past this op (as it would conflict with
  // an already-assigned layout).
  for (auto [val, enc] : tempLayout) {
    auto preexistingLayout = layout.find(val);
    if (preexistingLayout != layout.end() &&
        preexistingLayout->second != enc) {
      result = failure();
      break;
    }
  }

  // If we can rematerialize the rest of the ext slice we can ignore this
  // ext as it won't need a convert.
  if (result.succeeded()) {

0 commit comments

Comments (0)