Commit 976c4e4

Revert "[Layouts] Reuse existing rematerializations in backwards pass" (#5418)
Reverts triton-lang/triton#5410, which causes hangs/crashes in more complex examples. Reverting until a better solution is found, since the follow-up doesn't fix the issue either.
1 parent a6c83ee commit 976c4e4
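
For context, the reverted change (visible in the removed hunks below) taught the backward slice walk to reuse conversions that already exist: getConvertBackwardSlice accepted an optional getExistingConversion callback, and when a (value, encoding) pair already had a rematerialized equivalent, the walk recorded that existing value instead of continuing through the defining op. The following is a minimal standalone sketch of that short-circuit, not the actual Triton code: collectBackwardSlice and operandsOf are hypothetical names, and plain ints stand in for mlir::Value and mlir::Attribute.

#include <functional>
#include <optional>
#include <utility>
#include <vector>

using Value = int;    // stand-in for mlir::Value
using Encoding = int; // stand-in for mlir::Attribute

// Sketch of the reverted short-circuit: before walking through a value's
// defining op, ask whether an equivalent conversion already exists and, if
// so, record it and stop propagating along that path. (The real pass also
// deduplicates (value, encoding) pairs with a `seen` set, omitted here.)
void collectBackwardSlice(
    Value root, Encoding rootEncoding, std::vector<Value> &slice,
    const std::function<std::vector<Value>(Value)> &operandsOf,
    const std::function<std::optional<Value>(Value, Encoding)>
        &getExistingConversion) {
  std::vector<std::pair<Value, Encoding>> queue{{root, rootEncoding}};
  while (!queue.empty()) {
    auto [current, encoding] = queue.back();
    queue.pop_back();
    if (getExistingConversion) {
      if (std::optional<Value> existing =
              getExistingConversion(current, encoding)) {
        slice.push_back(*existing); // Reuse the existing conversion ...
        continue;                   // ... and don't walk further back.
      }
    }
    slice.push_back(current);
    for (Value operand : operandsOf(current)) // Keep walking backwards.
      queue.emplace_back(operand, encoding);
  }
}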

File tree: 5 files changed, +28 −87 lines changed


include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 1 addition & 2 deletions
@@ -163,8 +163,7 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
 LogicalResult getConvertBackwardSlice(
     Value root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr,
-    std::function<Value(Value, Attribute)> getExistingConversion = nullptr);
+    std::function<bool(Operation *)> stopPropagation = nullptr);

 // Populate pattern to remove dead cycles in ForOp.
 void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 20 additions & 46 deletions
@@ -116,13 +116,17 @@ class LayoutPropagation {
 class LayoutRematerialization {
 public:
   LayoutRematerialization(FuncOp F) : funcOp(F) {}
-
   // Map the original value to the remat'ed one.
   void addRematValue(Value old, Attribute encoding, Value newV);
-  // Get the remat'ed value in the given encoding, if one already exists and
-  // is different then the layout conversion root.
-  Value getRematValue(Value value, Attribute encoding, Value root) const;
-
+  bool hasRematValue(Value value, Attribute encoding) {
+    return rematMapping.contains({value, encoding});
+  }
+  // Return the remat'ed value in the given encoding.
+  Value getRematValue(Value value, Attribute encoding) {
+    auto it = rematMapping.find({value, encoding});
+    assert(it != rematMapping.end());
+    return it->second;
+  }
   void cleanup();
   void backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -133,11 +137,6 @@ class LayoutRematerialization {
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
                     ConvertLayoutOp convertOp);

-  LogicalResult getRematerializableSlice(
-      Value root, Attribute rootEncoding, SetVector<Value> &slice,
-      DenseMap<Value, Attribute> &layout,
-      std::function<bool(Operation *)> stopPropagation = nullptr);
-
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
   // Existing tuples of (value, layout) that needs to be updated when recreating
@@ -158,21 +157,6 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
   mappedValues[old] = encoding;
 }

-Value LayoutRematerialization::getRematValue(Value value, Attribute encoding,
-                                             Value root) const {
-  Value remat = rematMapping.lookup({value, encoding});
-  if (!remat)
-    return {};
-  // If the remat'ed value is a conversion result, make sure it is different
-  // than the root of the one we're looking at.
-  if (auto cvt = remat.getDefiningOp<ConvertLayoutOp>()) {
-    if (cvt.getSrc() == root)
-      return {};
-  }
-  // This remat'ed value can be reused.
-  return remat;
-}
-
 // Remove unneeded values now that we are done with the rematMapping.
 void LayoutRematerialization::cleanup() {
   for (Operation *op : llvm::reverse(opToDelete))
@@ -794,8 +778,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
     auto layoutIt = layout.find(v);
     assert(layoutIt != layout.end());
     // If we already have a remat value for this value, use it.
-    if (Value remat = getRematValue(v, layoutIt->second, convertOp.getSrc())) {
-      mapping.map(v, remat);
+    if (hasRematValue(v, layoutIt->second)) {
+      mapping.map(v, getRematValue(v, layoutIt->second));
       valuesWithExistingRemat.insert(v);
       continue;
     }
@@ -956,17 +940,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   rewriteSlice(slice, layout, convertOp, mapping);
 }

-LogicalResult LayoutRematerialization::getRematerializableSlice(
+LogicalResult getRematerializableSlice(
     Value root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation) {
-  // Allow re-using existing conversions for a value.
-  auto getExistingConversion = [&](Value value, Attribute encoding) -> Value {
-    return getRematValue(value, encoding, root);
-  };
-  LogicalResult result =
-      getConvertBackwardSlice(root, slice, rootEncoding, layout,
-                              stopPropagation, getExistingConversion);
+    std::function<bool(Operation *)> stopPropagation = nullptr) {
+  LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding,
+                                                 layout, stopPropagation);
   if (result.failed() || slice.empty())
     return failure();
@@ -983,14 +962,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
 void LayoutRematerialization::backwardRematerialization() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
-  funcOp.walk([&](ConvertLayoutOp convertOp) {
-    convertOps.push_back(convertOp);
-    // Add existing layout conversions as rematerializations of themselves. This
-    // enables rematerialization of other conversions to re-use existing
-    // conversions. Importantly, don't add them to `mappedValues`.
-    rematMapping.insert(
-        {{convertOp.getSrc(), convertOp.getType().getEncoding()}, convertOp});
-  });
+  funcOp.walk(
+      [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
     backwardRematerialization(convertOp);
   }
@@ -1015,13 +988,14 @@ void LayoutRematerialization::backwardRematerialization(
   // careful with the heuristics for both correctness and perf
   if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
     return;
-  Value oldV = convertOp.getSrc();
+  Value oldV = convertOp->getOperand(0);
   LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
   // Check to see if there are existing remat'ed values for the pair of oldValue
   // and encoding.
-  if (Value newV = getRematValue(oldV, targetType.getEncoding(), oldV)) {
+  if (hasRematValue(oldV, targetType.getEncoding())) {
     // Replace it with the remat'ed value.
+    Value newV = getRematValue(oldV, targetType.getEncoding());
     convertOp.replaceAllUsesWith(newV);
     opToDelete.insert(convertOp);
     LDBG("found remat'ed value" << newV);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 5 additions & 11 deletions
@@ -770,11 +770,11 @@ static bool isFreeConvert(Operation *op) {
                            convertOp.getType());
 }

-LogicalResult getConvertBackwardSlice(
-    Value root, SetVector<Value> &slice, Attribute rootEncoding,
-    DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation,
-    std::function<Value(Value, Attribute)> getExistingConversion) {
+LogicalResult
+getConvertBackwardSlice(Value root, SetVector<Value> &slice,
+                        Attribute rootEncoding,
+                        DenseMap<Value, Attribute> &layout,
+                        std::function<bool(Operation *)> stopPropagation) {
   DenseSet<std::pair<Value, Attribute>> seen;
   SmallVector<std::pair<Value, Attribute>> queue;
@@ -814,12 +814,6 @@ LogicalResult getConvertBackwardSlice(

       continue;
     }
-    Value existing;
-    if (getExistingConversion &&
-        (existing = getExistingConversion(currentValue, encoding))) {
-      enqueue(existing, encoding);
-      continue;
-    }
     if (auto *definingOp = currentValue.getDefiningOp()) {
       // If the op has multiple results we need to update all results layout.
       for (Value result : definingOp->getResults()) {

test/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ configure_lit_site_cfg(
 set(TRITON_TEST_DEPENDS
   triton-opt
   triton-tensor-layout
-  triton-llvm-opt
 )

 set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck")

test/TritonGPU/combine.mlir

Lines changed: 2 additions & 27 deletions
@@ -2427,9 +2427,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
   %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
   %3 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
   %4 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
-  // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>)
-  // CHECK-COUNT-4: convert_layout
-  // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>)
+  // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
   // CHECK: }
   // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
   %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 {
@@ -2773,27 +2772,3 @@ tt.func @do_not_remat(%arg0: tensor<64x64xf32, #blocked1>) -> tensor<1x64xf32, #
 }

 }
-
-// -----
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
-
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-
-// CHECK-LABEL: reuse_layout_conversion
-tt.func @reuse_layout_conversion(%arg0: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>) {
-  // CHECK-NEXT: %cst = arith.constant {{.*}} tensor<64x64xf32, #blocked>
-  %cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1>
-  // CHECK-NEXT: [[TRANS:%.*]] = tt.trans %arg0 {{.*}} tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
-  %0 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
-  // CHECK-NEXT: [[CVT:%.*]] = ttg.convert_layout [[TRANS]] : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
-  %1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
-  // CHECK-NEXT: [[RESULT:%.*]] = arith.mulf [[CVT]], %cst : tensor<64x64xf32, #blocked>
-  %2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1>
-  %3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
-  // CHECK-NEXT: return [[CVT]], [[RESULT]]
-  tt.return %1, %3 : tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>
-}
-
-}
