Commit 82fec37
[Layouts] Reuse existing rematerializations in backwards pass (#5410)
This PR lets the backwards rematerialization pass reuse existing `convert_layout` ops that were not removed by other means. The compiler can then eliminate some tricky layout conversions by recognizing that the same computation can be reconstructed from another, already existing conversion.
1 parent 92b6e26 · commit 82fec37
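As intuition for the rule this commit adds, below is a minimal, self-contained C++ sketch. It is not the Triton implementation (the real logic is `LayoutRematerialization::getRematValue` and the `getExistingConversion` callback in the diffs below); plain strings stand in for MLIR values and encodings. The key point is the root check: an existing conversion may be reused only if it is not the very conversion being rewritten.

```cpp
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Stand-ins for MLIR Value and layout Attribute.
using Value = std::string;
using Encoding = std::string;

struct Conversion {
  Value src;    // value being converted
  Value result; // value produced in the target encoding
};

// Look up a previously materialized conversion of `value` into `encoding`.
// Refuse it when its source is the root of the conversion currently being
// rewritten, since "reusing" that one would just recreate the conversion we
// are trying to eliminate.
std::optional<Value>
getReusableConversion(const std::map<std::pair<Value, Encoding>, Conversion> &known,
                      const Value &value, const Encoding &encoding,
                      const Value &root) {
  auto it = known.find({value, encoding});
  if (it == known.end() || it->second.src == root)
    return std::nullopt;
  return it->second.result;
}

int main() {
  // Seed the map the way the pass does: %1 = convert_layout %0 to #blocked.
  std::map<std::pair<Value, Encoding>, Conversion> known;
  known[{"%0", "#blocked"}] = Conversion{"%0", "%1"};

  // Rewriting a conversion rooted at %2 may reuse %1 for (%0, #blocked)...
  std::cout << getReusableConversion(known, "%0", "#blocked", "%2")
                   .value_or("<none>")
            << "\n"; // prints: %1
  // ...but a conversion rooted at %0 itself may not.
  std::cout << getReusableConversion(known, "%0", "#blocked", "%0")
                   .value_or("<none>")
            << "\n"; // prints: <none>
}
```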

5 files changed: +87 additions, -28 deletions

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 1 deletion

```diff
@@ -163,7 +163,8 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
 LogicalResult getConvertBackwardSlice(
     Value root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr);
+    std::function<bool(Operation *)> stopPropagation = nullptr,
+    std::function<Value(Value, Attribute)> getExistingConversion = nullptr);
 
 // Populate pattern to remove dead cycles in ForOp.
 void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);
```
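The new `getExistingConversion` parameter lets a caller offer an already materialized equivalent of a `(value, encoding)` pair during the backward walk; as the `Utility.cpp` change below shows, when the callback returns a non-null `Value`, the walk continues from that existing value instead of descending through the defining op.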

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 46 additions & 20 deletions

```diff
@@ -116,17 +116,13 @@ class LayoutPropagation {
 class LayoutRematerialization {
 public:
   LayoutRematerialization(FuncOp F) : funcOp(F) {}
+
   // Map the original value to the remat'ed one.
   void addRematValue(Value old, Attribute encoding, Value newV);
-  bool hasRematValue(Value value, Attribute encoding) {
-    return rematMapping.contains({value, encoding});
-  }
-  // Return the remat'ed value in the given encoding.
-  Value getRematValue(Value value, Attribute encoding) {
-    auto it = rematMapping.find({value, encoding});
-    assert(it != rematMapping.end());
-    return it->second;
-  }
+  // Get the remat'ed value in the given encoding, if one already exists and
+  // is different than the layout conversion root.
+  Value getRematValue(Value value, Attribute encoding, Value root) const;
+
   void cleanup();
   void backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -137,6 +133,11 @@ class LayoutRematerialization {
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
                     ConvertLayoutOp convertOp);
 
+  LogicalResult getRematerializableSlice(
+      Value root, Attribute rootEncoding, SetVector<Value> &slice,
+      DenseMap<Value, Attribute> &layout,
+      std::function<bool(Operation *)> stopPropagation = nullptr);
+
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
   // Existing tuples of (value, layout) that needs to be updated when recreating
@@ -157,6 +158,21 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
   mappedValues[old] = encoding;
 }
 
+Value LayoutRematerialization::getRematValue(Value value, Attribute encoding,
+                                             Value root) const {
+  Value remat = rematMapping.lookup({value, encoding});
+  if (!remat)
+    return {};
+  // If the remat'ed value is a conversion result, make sure it is different
+  // than the root of the one we're looking at.
+  if (auto cvt = remat.getDefiningOp<ConvertLayoutOp>()) {
+    if (cvt.getSrc() == root)
+      return {};
+  }
+  // This remat'ed value can be reused.
+  return remat;
+}
+
 // Remove unneeded values now that we are done with the rematMapping.
 void LayoutRematerialization::cleanup() {
   for (Operation *op : llvm::reverse(opToDelete))
@@ -766,8 +782,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
     auto layoutIt = layout.find(v);
     assert(layoutIt != layout.end());
     // If we already have a remat value for this value, use it.
-    if (hasRematValue(v, layoutIt->second)) {
-      mapping.map(v, getRematValue(v, layoutIt->second));
+    if (Value remat = getRematValue(v, layoutIt->second, convertOp.getSrc())) {
+      mapping.map(v, remat);
       valuesWithExistingRemat.insert(v);
       continue;
     }
@@ -928,12 +944,17 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   rewriteSlice(slice, layout, convertOp, mapping);
 }
 
-LogicalResult getRematerializableSlice(
+LogicalResult LayoutRematerialization::getRematerializableSlice(
     Value root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr) {
-  LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding,
-                                                 layout, stopPropagation);
+    std::function<bool(Operation *)> stopPropagation) {
+  // Allow re-using existing conversions for a value.
+  auto getExistingConversion = [&](Value value, Attribute encoding) -> Value {
+    return getRematValue(value, encoding, root);
+  };
+  LogicalResult result =
+      getConvertBackwardSlice(root, slice, rootEncoding, layout,
+                              stopPropagation, getExistingConversion);
   if (result.failed() || slice.empty())
     return failure();
 
@@ -950,8 +971,14 @@ LogicalResult getRematerializableSlice(
 void LayoutRematerialization::backwardRematerialization() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
-  funcOp.walk(
-      [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
+  funcOp.walk([&](ConvertLayoutOp convertOp) {
+    convertOps.push_back(convertOp);
+    // Add existing layout conversions as rematerializations of themselves. This
+    // enables rematerialization of other conversions to re-use existing
+    // conversions. Importantly, don't add them to `mappedValues`.
+    rematMapping.insert(
+        {{convertOp.getSrc(), convertOp.getType().getEncoding()}, convertOp});
+  });
   for (ConvertLayoutOp convertOp : convertOps) {
     backwardRematerialization(convertOp);
   }
@@ -976,14 +1003,13 @@ void LayoutRematerialization::backwardRematerialization(
   // careful with the heuristics for both correctness and perf
   if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
     return;
-  Value oldV = convertOp->getOperand(0);
+  Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
   // Check to see if there are existing remat'ed values for the pair of oldValue
   // and encoding.
-  if (hasRematValue(oldV, targetType.getEncoding())) {
+  if (Value newV = getRematValue(oldV, targetType.getEncoding(), oldV)) {
     // Replace it with the remat'ed value.
-    Value newV = getRematValue(oldV, targetType.getEncoding());
     convertOp.replaceAllUsesWith(newV);
     opToDelete.insert(convertOp);
     LDBG("found remat'ed value" << newV);
```

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 11 additions & 5 deletions

```diff
@@ -757,11 +757,11 @@ static bool isFreeConvert(Operation *op) {
                              convertOp.getType());
 }
 
-LogicalResult
-getConvertBackwardSlice(Value root, SetVector<Value> &slice,
-                        Attribute rootEncoding,
-                        DenseMap<Value, Attribute> &layout,
-                        std::function<bool(Operation *)> stopPropagation) {
+LogicalResult getConvertBackwardSlice(
+    Value root, SetVector<Value> &slice, Attribute rootEncoding,
+    DenseMap<Value, Attribute> &layout,
+    std::function<bool(Operation *)> stopPropagation,
+    std::function<Value(Value, Attribute)> getExistingConversion) {
   DenseSet<std::pair<Value, Attribute>> seen;
   SmallVector<std::pair<Value, Attribute>> queue;
 
@@ -802,6 +802,12 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
 
       continue;
     }
+    Value existing;
+    if (getExistingConversion &&
+        (existing = getExistingConversion(currentValue, encoding))) {
+      enqueue(existing, encoding);
+      continue;
+    }
     if (auto *definingOp = currentValue.getDefiningOp()) {
       // If the op has multiple results we need to update all results layout.
       for (Value result : definingOp->getResults()) {
```
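The added block is the crux of the change: when the caller-provided callback finds an existing conversion, the walk enqueues that value under the same encoding and moves on, so the slice terminates at the reusable conversion instead of rematerializing `currentValue`'s defining op.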

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -14,6 +14,7 @@ configure_lit_site_cfg(
 set(TRITON_TEST_DEPENDS
   triton-opt
   triton-tensor-layout
+  triton-llvm-opt
 )
 
 set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck")
```

test/TritonGPU/combine.mlir

Lines changed: 27 additions & 2 deletions

```diff
@@ -2427,8 +2427,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
     %3 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
     %4 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
-    // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>)
-    // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
+    // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>)
+    // CHECK-COUNT-4: convert_layout
+    // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
     // CHECK: }
     // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
     %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 {
@@ -2728,3 +2729,27 @@ tt.func @propagate_layout_gather(%arg0: tensor<1024x256xi32, #blocked>, %arg1: t
 }
 
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+
+// CHECK-LABEL: reuse_layout_conversion
+tt.func @reuse_layout_conversion(%arg0: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>) {
+  // CHECK-NEXT: %cst = arith.constant {{.*}} tensor<64x64xf32, #blocked>
+  %cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1>
+  // CHECK-NEXT: [[TRANS:%.*]] = tt.trans %arg0 {{.*}} tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
+  %0 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
+  // CHECK-NEXT: [[CVT:%.*]] = ttg.convert_layout [[TRANS]] : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
+  %1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
+  // CHECK-NEXT: [[RESULT:%.*]] = arith.mulf [[CVT]], %cst : tensor<64x64xf32, #blocked>
+  %2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1>
+  %3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
+  // CHECK-NEXT: return [[CVT]], [[RESULT]]
+  tt.return %1, %3 : tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>
+}
+
+}
```
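Previously, `%3` could not be eliminated because rematerializing `%2` in `#blocked` appeared to require a fresh conversion of `%0`. The CHECK lines verify that the pass now reuses the existing conversion (`[[CVT]]`), recomputes the `mulf` directly in `#blocked`, and emits no second `ttg.convert_layout` for it.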
