Skip to content

Commit 6f12fd4

Browse files
authored
[RemoveLayoutConversions]: Reduce loop carried values (#4915)
This PR implements an optimization to reduce loop carried values in the RemoveLayoutConversions pass by reusing equivalent loop results with layout conversion operations instead of carrying redundant values through loops. Fixes issue #4901 --------- Signed-off-by: Tiotto, Ettore <[email protected]> Signed-off-by: Ettore Tiotto <[email protected]>
1 parent b521442 commit 6f12fd4

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,57 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
345345
tt.return
346346
}
347347
}
348+
349+
// -----
350+
351+
// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
352+
// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
353+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
354+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
355+
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
356+
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
357+
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
358+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
359+
tt.func public @reduce_loop_carried_values(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg5: i32) {
360+
%c1_i64 = arith.constant 1 : i64
361+
%c0_i32 = arith.constant 0 : i32
362+
%c0_i64 = arith.constant 0 : i64
363+
%c32_i32 = arith.constant 32 : i32
364+
%cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #dpas>
365+
%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
366+
%22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
367+
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
368+
// COM: Ensure there are only 3 loop results and no layout conversion in the loop.
369+
// CHECK: [[LOOP_RES:%.*]]:3 = scf.for
370+
// CHECK-NOT: ttg.convert_layout
371+
// CHECK: scf.yield
372+
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
373+
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
374+
%30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
375+
%31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
376+
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
377+
%33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
378+
%34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
379+
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
380+
}
381+
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
382+
%27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dpas>>
383+
tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>
384+
385+
// CHECK: [[LOAD1:%.*]] = tt.load [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
386+
// CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD1]] : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>> -> tensor<64x32xf16, #[[BLOCKED]]>
387+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #[[BLOCKED]]>>
388+
// CHECK: tt.store [[PTR]], [[CONV1]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
389+
%28 = tt.load %23#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
390+
%29 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
391+
tt.store %29, %28 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
392+
393+
// CHECK: [[LOAD2:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
394+
// CHECK: [[CONV2:%.*]] = ttg.convert_layout [[LOAD2]] : tensor<64x32xf16, #[[BLOCKED]]> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>
395+
// CHECK: tt.store [[LOOP_RES]]#1, [[CONV2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
396+
%30 = tt.load %29 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
397+
tt.store %23#1, %30 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
398+
399+
tt.return
400+
}
401+
}

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
2626
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
2727
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
28+
#include "llvm/ADT/TypeSwitch.h"
2829
#include <deque>
2930

3031
namespace mlir::triton::gpu::intel {
@@ -1267,6 +1268,77 @@ void LayoutRematerialization::backwardRematerialization() {
12671268
convertOp.getResult());
12681269
}
12691270
}
1271+
1272+
// Reduce loop carried values if the value can be removed by using another
1273+
// loop yielded value plus a convert layout operation.
1274+
for (auto [pair, val] : rematMapping) {
1275+
auto arg = dyn_cast<BlockArgument>(pair.first);
1276+
if (!arg)
1277+
continue;
1278+
1279+
if (!isTensorPointerType(arg.getType()))
1280+
continue;
1281+
1282+
if (auto loopOp =
1283+
dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp())) {
1284+
// Loop arguments that correspond to a loop result which is not used are
1285+
// not interesting.
1286+
OpResult loopRes = loopOp.getTiedLoopResult(arg);
1287+
if (loopRes.getNumUses() == 0)
1288+
continue;
1289+
1290+
// Replace the loop result corresponding to the argument with an
1291+
// equivalent loop result.
1292+
auto rematArg = cast<BlockArgument>(val);
1293+
OpResult rematRes = loopOp.getTiedLoopResult(rematArg);
1294+
1295+
for (Operation *user : loopRes.getUsers()) {
1296+
Location loc = user->getLoc();
1297+
OpBuilder rewriter(user);
1298+
1299+
TypeSwitch<Operation *>(user)
1300+
.Case<LoadOp>([&](auto loadOp) {
1301+
auto newLoadOp =
1302+
rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
1303+
auto convOp = rewriter.create<ConvertLayoutOp>(
1304+
loc, loadOp.getType(), newLoadOp.getResult());
1305+
loadOp->replaceAllUsesWith(convOp);
1306+
opToDelete.insert(loadOp);
1307+
LLVM_DEBUG({
1308+
DBGS() << "Replaced:\n\t" << *loadOp << "\n";
1309+
DBGS() << "with:\n\t" << *newLoadOp << "\n"
1310+
<< "\t" << *convOp << "\n";
1311+
});
1312+
})
1313+
.Case<StoreOp>([&](auto storeOp) {
1314+
Value data = storeOp.getOperand(1);
1315+
auto dataType = cast<RankedTensorType>(data.getType());
1316+
auto newPtrType = cast<PointerType>(rematRes.getType());
1317+
Attribute encoding =
1318+
cast<RankedTensorType>(newPtrType.getPointeeType())
1319+
.getEncoding();
1320+
RankedTensorType newDataType =
1321+
dataType.cloneWithEncoding(encoding);
1322+
auto convOp =
1323+
rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
1324+
auto newStoreOp = rewriter.create<StoreOp>(
1325+
loc, rematRes, convOp, storeOp.getBoundaryCheck(),
1326+
storeOp.getCache(), storeOp.getEvict());
1327+
opToDelete.insert(storeOp);
1328+
LLVM_DEBUG({
1329+
DBGS() << "Replaced:\n\t" << *storeOp << "\n";
1330+
DBGS() << "with:\n\t" << *convOp << "\n"
1331+
<< "\t" << *newStoreOp << "\n";
1332+
});
1333+
})
1334+
.Default([](auto op) {
1335+
llvm::report_fatal_error(llvm::Twine(
1336+
"Unsupported operation in backward rematerialization: '" +
1337+
op->getName().getStringRef() + "'"));
1338+
});
1339+
}
1340+
}
1341+
}
12701342
}
12711343

12721344
void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {

0 commit comments

Comments
 (0)