
Commit 98b40d5

Revert "[Layouts] Propagate layouts into conditionals (#5610)" (#5710)
Reverting due to regressions in internal tests
1 parent 19fe7cb commit 98b40d5

3 files changed (+5, -238 lines)

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 0 additions & 116 deletions
@@ -131,8 +131,6 @@ class LayoutRematerialization {
   void backwardRematerialization(ConvertLayoutOp convertOp);
   void hoistConvertOnTopOfExtOrBroadcast();
   void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
-  void hoistConvertIntoConditionals();
-  void hoistConvertIntoConditionals(ConvertLayoutOp convertOp);
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
                     ConvertLayoutOp convertOp, IRMapping &mapping);
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
@@ -1022,22 +1020,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   }
 }
 
-void LayoutRematerialization::hoistConvertIntoConditionals() {
-  // Go through each ConvertLayoutOp.
-  SmallVector<ConvertLayoutOp> convertOps;
-  funcOp.walk(
-      [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
-  for (ConvertLayoutOp convertOp : convertOps) {
-    hoistConvertIntoConditionals(convertOp);
-    if (!opToDelete.contains(convertOp)) {
-      // If the conversion didn't get removed, consider it for reuse in future
-      // backward slices.
-      addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
-                    convertOp.getResult());
-    }
-  }
-}
-
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
   // we don't handle conversions to DotOperandEncodingAttr
@@ -1169,100 +1151,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
   rewriteSlice(slice, layout, convertOp, mapping);
 }
 
-void LayoutRematerialization::hoistConvertIntoConditionals(
-    ConvertLayoutOp convertOp) {
-  // Take the backward slice of tensor dependencies, stopping at conditionals.
-  SetVector<Value> slice;
-  DenseMap<Value, Attribute> layout;
-  auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
-  if (failed(getRematerializableSlice(convertOp.getSrcMutable(),
-                                      convertOp.getType().getEncoding(), slice,
-                                      layout, isIfOp)))
-    return;
-
-  // Find conditional edges above which the conversion can be hoisted.
-  SmallVector<std::pair<Value, OpOperand *>> hoistAbove;
-  unsigned sliceSize = slice.size();
-  // The routine will recurse through backward slices, e.g. to handle loops and
-  // conditional chains. Thus, we re-query the size of `slice`.
-  for (unsigned i = 0; i < slice.size(); i++) {
-    Value v = slice[i];
-    auto ifOp = v.getDefiningOp<scf::IfOp>();
-    if (!ifOp)
-      continue;
-
-    Attribute rootLayout = layout.at(v);
-    unsigned resIdx = cast<OpResult>(v).getResultNumber();
-
-    // Take the backward slice along each branch.
-    auto thenYield =
-        cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
-    auto elseYield =
-        cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
-
-    OpOperand &thenRes = thenYield.getResultsMutable()[resIdx];
-    OpOperand &elseRes = elseYield.getResultsMutable()[resIdx];
-
-    SetVector<Value> thenSlice, elseSlice;
-    DenseMap<Value, Attribute> thenLayout, elseLayout;
-
-    LogicalResult thenResult = getRematerializableSlice(
-        thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
-    LogicalResult elseResult = getRematerializableSlice(
-        elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
-
-    // If propagation across both edges of this conditional succeeded, then we
-    // don't need to hoist across it.
-    if (succeeded(thenResult) && succeeded(elseResult)) {
-      slice.insert(thenSlice.begin(), thenSlice.end());
-      slice.insert(elseSlice.begin(), elseSlice.end());
-      layout.insert(thenLayout.begin(), thenLayout.end());
-      layout.insert(elseLayout.begin(), elseLayout.end());
-      continue;
-    }
-
-    // If propagation across both edges failed, then there is nothing to do
-    // for this one.
-    if (failed(thenResult) && failed(elseResult))
-      continue;
-
-    // The layout conversion can be rematerialized along one edge but not the
-    // other. We can hoist the conversion into the other branch.
-    if (succeeded(elseResult)) {
-      std::swap(thenSlice, elseSlice);
-      std::swap(thenLayout, elseLayout);
-      hoistAbove.push_back({v, &thenRes});
-    } else {
-      hoistAbove.push_back({v, &elseRes});
-    }
-    slice.insert(thenSlice.begin(), thenSlice.end());
-    layout.insert(thenLayout.begin(), thenLayout.end());
-  }
-
-  // It's hard to know if duplicating the conversion into separate branches is
-  // profitable without more analysis. For now, hoist at most one.
-  if (hoistAbove.size() != 1)
-    return;
-
-  IRMapping mapping;
-  for (auto [result, edge] : hoistAbove) {
-    // Hoist the convert into the conditional and rewrite the slice.
-    OpBuilder b(edge->getOwner());
-    Value v = edge->get();
-    Attribute encoding = layout.at(result);
-
-    auto tensorType = cast<RankedTensorType>(v.getType());
-    auto newType = RankedTensorType::get(tensorType.getShape(),
-                                         tensorType.getElementType(), encoding);
-
-    Value newCvt = b.create<ConvertLayoutOp>(convertOp.getLoc(), newType, v);
-
-    mapping.map(v, newCvt);
-    slice.remove(v);
-  }
-  rewriteSlice(slice, layout, convertOp, mapping);
-}
-
 void backwardRematerialization(ModuleOp module) {
   module.walk([](FuncOp funcOp) {
     LayoutRematerialization layoutRemat(funcOp);
@@ -1277,10 +1165,6 @@ void hoistConvert(ModuleOp module) {
     LayoutRematerialization layoutRemat(funcOp);
     layoutRemat.hoistConvertOnTopOfExtOrBroadcast();
     layoutRemat.cleanup();
-
-    layoutRemat = LayoutRematerialization(funcOp);
-    layoutRemat.hoistConvertIntoConditionals();
-    layoutRemat.cleanup();
   });
 }
 } // namespace
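For reference, the decision logic of the reverted hoistConvertIntoConditionals can be summarized in isolation: each scf.if result in the backward slice lets the target layout rematerialize along both yield edges, along neither, or along exactly one, and the pass only hoisted a convert when exactly one blocked edge was found overall. Below is a minimal standalone C++ sketch of that heuristic; the IfResult/pickHoistPoint names are hypothetical, plain bools stand in for getRematerializableSlice, and this is an illustrative model rather than the Triton code.

#include <cstdio>
#include <optional>
#include <vector>

struct IfResult {
  bool thenRematerializable; // layout propagates through the then-yield edge
  bool elseRematerializable; // layout propagates through the else-yield edge
};

// Returns the index of the single if-result whose failing edge would receive
// the hoisted convert, or std::nullopt if hoisting is skipped.
std::optional<size_t> pickHoistPoint(const std::vector<IfResult> &ifs) {
  std::vector<size_t> hoistAbove;
  for (size_t i = 0; i < ifs.size(); ++i) {
    const IfResult &r = ifs[i];
    if (r.thenRematerializable && r.elseRematerializable)
      continue; // both edges fold away; keep propagating instead of hoisting
    if (!r.thenRematerializable && !r.elseRematerializable)
      continue; // nothing is gained by hoisting across this conditional
    hoistAbove.push_back(i); // exactly one edge blocks propagation
  }
  // Duplicating converts into many branches may not be profitable, so the
  // reverted pass hoisted across at most one conditional.
  if (hoistAbove.size() != 1)
    return std::nullopt;
  return hoistAbove.front();
}

int main() {
  // One if whose else-branch yields a constant (rematerializable) and whose
  // then-branch yields a load (not rematerializable): hoist into the then.
  std::vector<IfResult> single = {{false, true}};
  if (auto idx = pickHoistPoint(single))
    std::printf("hoist into if #%zu\n", *idx);

  // Two blocked conditionals: the heuristic declines to hoist.
  std::vector<IfResult> two = {{false, true}, {true, false}};
  std::printf("two blocked ifs -> hoist? %s\n",
              pickHoistPoint(two) ? "yes" : "no");
}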

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 5 additions & 6 deletions
@@ -793,10 +793,11 @@ LogicalResult getConvertBackwardSlice(
   auto updateLayout = [&](Value value, Attribute encoding) {
     assert((isa<RankedTensorType>(value.getType())));
     slice.insert(value);
-    Attribute &existing = layout[value];
-    if (existing && existing != encoding)
-      return failure();
-    existing = encoding;
+    if (layout.find(value) != layout.end()) {
+      if (layout[value] != encoding)
+        return failure();
+    }
+    layout[value] = encoding;
     return success();
   };
 
@@ -822,8 +823,6 @@ LogicalResult getConvertBackwardSlice(
     }
 
     if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
-      if (stopPropagation && stopPropagation(ifOp))
-        continue;
       unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
 
       OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx);
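Both forms of updateLayout in the first hunk above implement the same contract: record the first encoding seen for a value and fail when a different encoding is later requested for it. A minimal standalone sketch of that contract follows; std::unordered_map<int, int> stands in for DenseMap<Value, Attribute> and a bool return stands in for LogicalResult, so this is an illustrative model, not the Triton code.

#include <cassert>
#include <unordered_map>

// Stand-in for the layout map: value id -> encoding id.
using LayoutMap = std::unordered_map<int, int>;

// Mirrors the restored updateLayout: succeed if the value has no recorded
// encoding or the same one, fail (false) on a conflicting encoding.
bool updateLayout(LayoutMap &layout, int value, int encoding) {
  auto it = layout.find(value);
  if (it != layout.end() && it->second != encoding)
    return false; // conflicting layout request -> propagation fails
  layout[value] = encoding;
  return true;
}

int main() {
  LayoutMap layout;
  assert(updateLayout(layout, /*value=*/42, /*encoding=*/1)); // first record
  assert(updateLayout(layout, 42, 1));  // the same encoding is accepted again
  assert(!updateLayout(layout, 42, 2)); // a conflicting encoding is rejected
}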

test/TritonGPU/combine.mlir

Lines changed: 0 additions & 116 deletions
@@ -2828,122 +2828,6 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) {
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16]}>
-
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
-
-// CHECK-LABEL: @hoist_one_conditional
-tt.func @hoist_one_conditional(
-  %arg0: i1,
-  %arg1: tensor<128x32x!tt.ptr<f32>, #blocked>,
-  %arg2: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>,
-  %arg3: tensor<128x128xf32, #mma>
-) -> tensor<128x128xf32, #mma> {
-
-  // CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op
-  %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked>
-  // CHECK: scf.if
-  %0 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) {
-    // CHECK-NEXT: [[RES:%.*]] = tt.load
-    %3 = tt.load %arg1 : tensor<128x32x!tt.ptr<f32>, #blocked>
-    // CHECK-NEXT: ttg.convert_layout [[RES]]
-    // CHECK-NEXT: yield
-    scf.yield %3 : tensor<128x32xf32, #blocked>
-  } else {
-    scf.yield %cst : tensor<128x32xf32, #blocked>
-  }
-  // CHECK-NOT: ttg.convert_layout
-  %1 = ttg.convert_layout %0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-  %2 = tt.dot %1, %arg2, %arg3 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
-  tt.return %2 : tensor<128x128xf32, #mma>
-}
-
-// CHECK-LABEL: @hoist_multiple_conditional
-tt.func @hoist_multiple_conditional(
-  %arg0: i1,
-  %arg1: i1,
-  %arg2: tensor<128x32x!tt.ptr<f32>, #blocked>,
-  %arg3: tensor<128x32x!tt.ptr<f32>, #blocked>,
-  %arg4: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>,
-  %arg5: tensor<128x128xf32, #mma>
-) -> tensor<128x128xf32, #mma> {
-  // CHECK-COUNT-1: ttg.convert_layout
-  %cst0 = arith.constant dense<1.0> : tensor<128x32xf32, #blocked>
-  %cst1 = arith.constant dense<2.0> : tensor<128x32xf32, #blocked>
-  %0 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) {
-    %3 = tt.load %arg2 : tensor<128x32x!tt.ptr<f32>, #blocked>
-    scf.yield %3 : tensor<128x32xf32, #blocked>
-  } else {
-    scf.yield %cst0 : tensor<128x32xf32, #blocked>
-  }
-  %1 = scf.if %arg1 -> (tensor<128x32xf32, #blocked>) {
-    %4 = tt.load %arg3 : tensor<128x32x!tt.ptr<f32>, #blocked>
-    scf.yield %4 : tensor<128x32xf32, #blocked>
-  } else {
-    scf.yield %cst1 : tensor<128x32xf32, #blocked>
-  }
-  %2 = arith.addf %0, %1 : tensor<128x32xf32, #blocked>
-  %3 = ttg.convert_layout %2 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-  %4 = tt.dot %3, %arg4, %arg5 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
-  tt.return %4 : tensor<128x128xf32, #mma>
-}
-
-// CHECK-LABEL: @hoist_across_loop
-tt.func @hoist_across_loop(
-  %arg0: i1,
-  %arg1: tensor<128x32x!tt.ptr<f32>, #blocked>,
-  %arg2: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>,
-  %arg3: tensor<128x128xf32, #mma>
-) -> tensor<128x128xf32, #mma> {
-  // CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op
-  %cst = arith.constant dense<1.0> : tensor<128x32xf32, #blocked>
-  %c0_i32 = arith.constant 0 : i32
-  %c1_i32 = arith.constant 1 : i32
-  %c32_i32 = arith.constant 32 : i32
-  // CHECK: scf.for
-  %0:2 = scf.for %i = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg4 = %cst, %acc = %arg3) -> (tensor<128x32xf32, #blocked>, tensor<128x128xf32, #mma>) : i32 {
-    // CHECK-NEXT: scf.if
-    %1 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) {
-      // CHECK-NEXT: [[RES:%.*]] = tt.load
-      // CHECK-NEXT: ttg.convert_layout [[RES]]
-      %3 = tt.load %arg1 : tensor<128x32x!tt.ptr<f32>, #blocked>
-      scf.yield %3 : tensor<128x32xf32, #blocked>
-    } else {
-      scf.yield %arg4 : tensor<128x32xf32, #blocked>
-    }
-    // CHECK-NOT: ttg.convert_layout
-    %2 = ttg.convert_layout %1 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-    %3 = tt.dot %2, %arg2, %acc : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
-    scf.yield %1, %3 : tensor<128x32xf32, #blocked>, tensor<128x128xf32, #mma>
-  }
-  tt.return %0#1 : tensor<128x128xf32, #mma>
-}
-
-// CHECK-LABEL: @chained_if
-tt.func @chained_if(%arg0: i1, %arg1: i1, %arg2: tensor<32x32x!tt.ptr<f32>, #blocked>, %arg3: tensor<32x32x!tt.ptr<f32>, #blocked>) -> tensor<32x32xf32, #mma> {
-  // CHECK-COUNT-1: ttg.convert_layout
-  %cst = arith.constant dense<1.0> : tensor<32x32xf32, #blocked>
-  %0 = scf.if %arg0 -> tensor<32x32xf32, #blocked> {
-    %anchor = tt.load %arg2 : tensor<32x32x!tt.ptr<f32>, #blocked>
-    scf.yield %anchor : tensor<32x32xf32, #blocked>
-  } else {
-    scf.yield %cst : tensor<32x32xf32, #blocked>
-  }
-  %1 = scf.if %arg1 -> tensor<32x32xf32, #blocked> {
-    %anchor = tt.load %arg3 : tensor<32x32x!tt.ptr<f32>, #blocked>
-    scf.yield %anchor : tensor<32x32xf32, #blocked>
-  } else {
-    scf.yield %0 : tensor<32x32xf32, #blocked>
-  }
-  %2 = ttg.convert_layout %1 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #mma>
-  tt.return %2 : tensor<32x32xf32, #mma>
-}
-
-}
-
-// -----
-
 #linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}>
 #blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}>
 