Commit cfb23d7

[RemoveLayoutConversions]: Reduce loop carried values - part 2 (#4921)
Implements functionality to reduce loop-carried values in the `RemoveLayoutConversions` pass by eliminating unnecessary loop-carried tensor pointer values when they can be reconstructed from other loop-yielded values plus layout conversions.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 287802f commit cfb23d7
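
In essence, when a loop yields two tensor pointers that address the same data but carry different layout encodings, post-loop uses of the redundant result are rewired to the equivalent result, with a `ttg.convert_layout` inserted where the original layout is still needed; the redundant carried value then becomes dead and can be cleaned up. A minimal hand-written sketch of the idea (the names %pA, %pB and the #blockedA/#dotB encodings are illustrative only, and the loop body is elided; this is not taken from the test suite):

    // Before: the loop carries both a #blockedA-typed and a #dotB-typed pointer
    // to the same tensor, and the #blockedA result is used after the loop.
    %res:2 = scf.for %i = %lb to %ub step %st
        iter_args(%pA = %a, %pB = %b)
        -> (!tt.ptr<tensor<32x256xf16, #blockedA>>, !tt.ptr<tensor<32x256xf16, #dotB>>) { ... }
    %v = tt.load %res#0 : !tt.ptr<tensor<32x256xf16, #blockedA>>

    // After: uses of %res#0 are rewired through %res#1 plus a layout conversion,
    // so %res#0 becomes dead and can be removed by later cleanup.
    %r = tt.load %res#1 : !tt.ptr<tensor<32x256xf16, #dotB>>
    %v = ttg.convert_layout %r : tensor<32x256xf16, #dotB> -> tensor<32x256xf16, #blockedA>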

File tree

2 files changed: +98 -70 lines


test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir

Lines changed: 9 additions & 0 deletions
@@ -349,6 +349,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
 // -----
 
 // CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+// CHECK: #[[BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 // CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -396,6 +397,14 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
     %30 = tt.load %29 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
     tt.store %23#1, %30 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
 
+    // CHECK: [[ADV:%.*]] = tt.advance [[LOOP_RES]]#2, {{.*}} : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+    // CHECK: [[LOAD3:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+    // CHECK: [[CONV3:%.*]] = ttg.convert_layout [[LOAD3]] : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<32x256xf16, #[[BLOCKED1]]>
+    // CHECK: tt.store {{.*}}, [[CONV3]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #[[BLOCKED1]]>
+    %31 = tt.advance %23#2, [%c0_i32, %c32_i32] : <tensor<32x256xf16, #blocked1>>
+    %32 = tt.load %31 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<32x256xf16, #blocked1>>
+    %33 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
+    tt.store %33, %32 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
     tt.return
   }
 }

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 89 additions & 70 deletions
@@ -167,6 +167,7 @@ class LayoutRematerialization {
 
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
+  void reduceLoopCarriedValues();
   // Existing tuples of (value, layout) that needs to be updated when recreating
   // scf ops. This prevents keeping track of Values that have been delete when
   // rewriting slices.
@@ -1009,6 +1010,93 @@ void LayoutRematerialization::updateRematMapping(
   }
 }
 
+/// Reduce loop carried values if the value is used after the loop and can be
+/// removed by using another loop yielded value plus a convert layout operation.
+void LayoutRematerialization::reduceLoopCarriedValues() {
+  for (auto [pair, val] : rematMapping) {
+    auto arg = dyn_cast<BlockArgument>(pair.first);
+    if (!arg)
+      continue;
+
+    if (!isTensorPointerType(arg.getType()))
+      continue;
+
+    auto loopOp = dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp());
+    if (!loopOp)
+      continue;
+
+    // Loop arguments that corresponds to a loop result which is not used are
+    // not interesting.
+    OpResult loopRes = loopOp.getTiedLoopResult(arg);
+    if (loopRes.getNumUses() == 0)
+      continue;
+
+    std::function<void(Operation *, Value)> processUser = [&](Operation *user,
+                                                              Value rematRes) {
+      Location loc = user->getLoc();
+      OpBuilder rewriter(user);
+
+      TypeSwitch<Operation *>(user)
+          .Case<LoadOp>([&](auto loadOp) {
+            auto newLoadOp =
+                rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
+            auto convOp = rewriter.create<ConvertLayoutOp>(
+                loc, loadOp.getType(), newLoadOp.getResult());
+            loadOp->replaceAllUsesWith(convOp);
+            opToDelete.insert(loadOp);
+            LLVM_DEBUG({
+              DBGS() << "Replaced:\n\t" << *loadOp << "\n"
+                     << "with:\n\t" << *newLoadOp << "\n"
+                     << "\t" << *convOp << "\n";
+            });
+          })
+          .Case<StoreOp>([&](auto storeOp) {
+            Value data = storeOp.getOperand(1);
+            auto dataType = cast<RankedTensorType>(data.getType());
+            auto newPtrType = cast<PointerType>(rematRes.getType());
+            Attribute encoding =
+                cast<RankedTensorType>(newPtrType.getPointeeType())
+                    .getEncoding();
+            RankedTensorType newDataType = dataType.cloneWithEncoding(encoding);
+            auto convOp =
+                rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
+            auto newStoreOp = rewriter.create<StoreOp>(
+                loc, rematRes, convOp, storeOp.getBoundaryCheck(),
+                storeOp.getCache(), storeOp.getEvict());
+            opToDelete.insert(storeOp);
+            LLVM_DEBUG({
+              DBGS() << "Replaced:\n\t" << *storeOp << "\n"
+                     << "with:\n\t" << *convOp << "\n"
+                     << "\t" << *newStoreOp << "\n";
+            });
+          })
+          .Case<AdvanceOp>([&](auto advanceOp) {
+            auto newAdvanceOp = rewriter.create<AdvanceOp>(
+                loc, rematRes.getType(), rematRes, advanceOp.getOffsets());
+            opToDelete.insert(advanceOp);
+            LLVM_DEBUG({
+              DBGS() << "Replaced:\n\t" << *advanceOp << "\n"
+                     << "with:\n\t" << *newAdvanceOp << "\n";
+            });
+
+            for (Operation *user : advanceOp->getUsers())
+              processUser(user, newAdvanceOp.getResult());
+          })
+          .Default([](auto op) {
+            llvm::report_fatal_error(llvm::Twine(
+                "Unsupported operation in backward rematerialization: '" +
+                op->getName().getStringRef() + "'"));
+          });
+    };
+
+    // Replace the loop result corresponding to the argument with an
+    // equivalent loop result.
+    OpResult rematRes = loopOp.getTiedLoopResult(cast<BlockArgument>(val));
+    for (Operation *user : loopRes.getUsers())
+      processUser(user, rematRes);
+  }
+}
+
 void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
                                            DenseMap<Value, Attribute> &layout,
                                            ConvertLayoutOp convertOp,
@@ -1269,76 +1357,7 @@ void LayoutRematerialization::backwardRematerialization() {
    }
  }
 
-  // Reduce loop carried values if the value can be removed by using another
-  // loop yielded value plus a convert layout operation.
-  for (auto [pair, val] : rematMapping) {
-    auto arg = dyn_cast<BlockArgument>(pair.first);
-    if (!arg)
-      continue;
-
-    if (!isTensorPointerType(arg.getType()))
-      continue;
-
-    if (auto loopOp =
-            dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp())) {
-      // Loop arguments that corresponds to a loop result which is not used are
-      // not interesting.
-      OpResult loopRes = loopOp.getTiedLoopResult(arg);
-      if (loopRes.getNumUses() == 0)
-        continue;
-
-      // Replace the loop result corresponding to the argument with an
-      // equivalent loop result.
-      auto rematArg = cast<BlockArgument>(val);
-      OpResult rematRes = loopOp.getTiedLoopResult(rematArg);
-
-      for (Operation *user : loopRes.getUsers()) {
-        Location loc = user->getLoc();
-        OpBuilder rewriter(user);
-
-        TypeSwitch<Operation *>(user)
-            .Case<LoadOp>([&](auto loadOp) {
-              auto newLoadOp =
-                  rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
-              auto convOp = rewriter.create<ConvertLayoutOp>(
-                  loc, loadOp.getType(), newLoadOp.getResult());
-              loadOp->replaceAllUsesWith(convOp);
-              opToDelete.insert(loadOp);
-              LLVM_DEBUG({
-                DBGS() << "Replaced:\n\t" << *loadOp << "\n";
-                DBGS() << "with:\n\t" << *newLoadOp << "\n"
-                       << "\t" << *convOp << "\n";
-              });
-            })
-            .Case<StoreOp>([&](auto storeOp) {
-              Value data = storeOp.getOperand(1);
-              auto dataType = cast<RankedTensorType>(data.getType());
-              auto newPtrType = cast<PointerType>(rematRes.getType());
-              Attribute encoding =
-                  cast<RankedTensorType>(newPtrType.getPointeeType())
-                      .getEncoding();
-              RankedTensorType newDataType =
-                  dataType.cloneWithEncoding(encoding);
-              auto convOp =
-                  rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
-              auto newStoreOp = rewriter.create<StoreOp>(
-                  loc, rematRes, convOp, storeOp.getBoundaryCheck(),
-                  storeOp.getCache(), storeOp.getEvict());
-              opToDelete.insert(storeOp);
-              LLVM_DEBUG({
-                DBGS() << "Replaced:\n\t" << *storeOp << "\n";
-                DBGS() << "with:\n\t" << *convOp << "\n"
-                       << "\t" << *newStoreOp << "\n";
-              });
-            })
-            .Default([](auto op) {
-              llvm::report_fatal_error(llvm::Twine(
-                  "Unsupported operation in backward rematerialization: '" +
-                  op->getName().getStringRef() + "'"));
-            });
-      }
-    }
-  }
+  reduceLoopCarriedValues();
 }
 
 void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
