
Commit 4199e4e

[RemoveLayoutConversion]: Destroy 'ttg.convert_layout' operations unless they have a user (#4880)
The fix prevents the pass from incorrectly removing layout conversion operations that are still being used. Fixes issue #4866. Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent eedd1ce commit 4199e4e
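
The essence of the change is a guard around the pass's final erase loop: an operation queued for deletion is only destroyed once nothing uses its results anymore. A minimal sketch of that pattern follows (illustrative only; it uses MLIR's Operation::use_empty(), whereas the actual diff below checks op->getUsers().empty()):

    // Sketch: erase the ops collected in opToDelete in reverse order, but keep
    // any op whose results are still consumed elsewhere (e.g. a
    // 'ttg.convert_layout' that still has a user).
    for (Operation *op : llvm::reverse(opToDelete)) {
      if (op->use_empty()) // only erase ops with no remaining users
        op->erase();
    }

Erasing an operation that still has uses is invalid in MLIR, so the guard keeps such conversions in place; the full change is in RemoveLayoutConversions.cpp below.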

File tree

2 files changed: +62 -7 lines changed


test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir

Lines changed: 44 additions & 0 deletions
@@ -286,3 +286,47 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+// COM: Fix for issue #4866
+
+// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: #[[BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
+  tt.func public @test_4866(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i64) {
+    %c1_i32 = arith.constant 1 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #blocked>
+    %cst_0 = arith.constant dense<5.000000e-01> : tensor<16x32xf32, #blocked1>
+    %c64_i64 = arith.constant 64 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c16_i32 = arith.constant 16 : i32
+    %0 = tt.make_tensor_ptr %arg0, [%arg2, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c32_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf16, #blocked2>>
+    %1 = tt.make_tensor_ptr %arg1, [%arg2, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c32_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf32, #blocked2>>
+    %2:2 = scf.for %arg3 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %1) -> (!tt.ptr<tensor<16x32xf16, #blocked2>>, !tt.ptr<tensor<16x32xf32, #blocked2>>) : i32 {
+      // CHECK: scf.for {{.*}}
+      // CHECK: [[LOAD_RES:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<16x32xf16, #[[BLOCKED1]]>>
+      // CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD_RES]] : tensor<16x32xf16, #[[BLOCKED1]]> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>>
+      // CHECK: [[DOT_RES:%.*]] = tt.dot %cst_0, [[CONV1]], %cst : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<16x32xf32, #[[BLOCKED]]>
+      // CHECK: [[CONV2:%.*]] = ttg.convert_layout [[DOT_RES]] : tensor<16x32xf32, #[[BLOCKED]]> -> tensor<16x32xf32, #[[BLOCKED1]]>
+      // CHECK: tt.store {{.*}}, [[CONV2]] : !tt.ptr<tensor<16x32xf32, #[[BLOCKED1]]>>
+      %3 = tt.load %arg4 : !tt.ptr<tensor<16x32xf16, #blocked2>>
+      %4 = ttg.convert_layout %3 : tensor<16x32xf16, #blocked2> -> tensor<16x32xf16, #blocked1>
+      %5 = ttg.convert_layout %cst : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>>
+      %6 = ttg.convert_layout %4 : tensor<16x32xf16, #blocked1> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>>
+      %7 = ttg.convert_layout %cst_0 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #blocked3>
+      %8 = tt.dot %5, %6, %7 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<16x32xf32, #blocked3>
+      %9 = ttg.convert_layout %8 : tensor<16x32xf32, #blocked3> -> tensor<16x32xf32, #blocked1>
+      %10 = ttg.convert_layout %9 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #blocked2>
+      tt.store %arg5, %10 : !tt.ptr<tensor<16x32xf32, #blocked2>>
+      scf.yield %arg4, %arg5 : !tt.ptr<tensor<16x32xf16, #blocked2>>, !tt.ptr<tensor<16x32xf32, #blocked2>>
+    }
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 18 additions & 7 deletions
@@ -240,6 +240,16 @@ void LayoutPropagation::initAnchorLayout() {
       }
     }
   });
+
+  LLVM_DEBUG({
+    DBGS() << "Anchors: \n";
+    for (auto [v, info] : layouts) {
+      DBGS().indent(2) << "Value: " << v << "\n";
+      DBGS().indent(2) << "Encodings (" << info.encodings.size() << "):\n";
+      for (Attribute encoding : info.encodings)
+        DBGS().indent(4) << encoding << "\n";
+    }
+  });
 }
 
 void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info,
@@ -337,12 +347,10 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
         return isMMAorMMADerived;
       };
       if (llvm::all_of(info.encodings, checkMMAorMMADerived)) {
+        SmallVector<Value> valuesToChange{storeOp.getPtr(), storeOp.getValue()};
         if (storeOp.getMask())
-          setEncoding({storeOp.getPtr(), storeOp.getValue(), storeOp.getMask()},
-                      info, changed, user);
-        else
-          setEncoding({storeOp.getPtr(), storeOp.getValue()}, info, changed,
-                      user);
+          valuesToChange.emplace_back(storeOp.getMask());
+        setEncoding(valuesToChange, info, changed, user);
       }
       continue;
     }
@@ -481,8 +489,11 @@ void LayoutPropagation::rewriteRegion(Region &region) {
       }
     }
   }
-  for (Operation *op : llvm::reverse(opToDelete))
-    op->erase();
+
+  for (Operation *op : llvm::reverse(opToDelete)) {
+    if (op->getUsers().empty())
+      op->erase();
+  }
 }
 
 void LayoutPropagation::map(Value old, Value newV) {
