Skip to content

Commit f0bcd4c

Browse files
authored
[intel] Sync RemoveLayoutConversions.cpp with Triton using 24b8d43 commit (#3360)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 2eff7d8 commit f0bcd4c

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "mlir/Analysis/SliceAnalysis.h"
22
#include "mlir/Dialect/SCF/IR/SCF.h"
33
#include "mlir/IR/BuiltinAttributes.h"
4+
#include "mlir/IR/Dominance.h"
45
#include "mlir/IR/IRMapping.h"
56
#include "mlir/IR/PatternMatch.h"
67
#include "mlir/IR/Verifier.h"
@@ -125,9 +126,6 @@ class LayoutRematerialization {
125126
return rematMapping.lookup({value, encoding});
126127
}
127128

128-
bool hasRematValue(Value value, Attribute encoding) {
129-
return rematMapping.contains({value, encoding});
130-
}
131129
void cleanup();
132130
void backwardRematerialization();
133131
void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -987,8 +985,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
987985
auto layoutIt = layout.find(v);
988986
assert(layoutIt != layout.end());
989987
// If we already have a remat value for this value, use it.
990-
if (hasRematValue(v, layoutIt->second)) {
991-
mapping.map(v, getRematValue(v, layoutIt->second));
988+
if (Value remat = getRematValue(v, layoutIt->second)) {
989+
mapping.map(v, remat);
992990
valuesWithExistingRemat.insert(v);
993991
continue;
994992
}
@@ -1212,6 +1210,12 @@ void LayoutRematerialization::backwardRematerialization() {
12121210
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
12131211
for (ConvertLayoutOp convertOp : convertOps) {
12141212
backwardRematerialization(convertOp);
1213+
if (!opToDelete.contains(convertOp)) {
1214+
// If the conversion didn't get removed, consider it for reuse in future
1215+
// backward slices.
1216+
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
1217+
convertOp.getResult());
1218+
}
12151219
}
12161220
}
12171221

@@ -1222,6 +1226,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
12221226
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
12231227
for (ConvertLayoutOp convertOp : convertOps) {
12241228
hoistConvertOnTopOfExtOrBroadcast(convertOp);
1229+
if (!opToDelete.contains(convertOp)) {
1230+
// If the conversion didn't get removed, consider it for reuse in future
1231+
// backward slices.
1232+
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
1233+
convertOp.getResult());
1234+
}
12251235
}
12261236
}
12271237

@@ -1234,14 +1244,14 @@ void LayoutRematerialization::backwardRematerialization(
12341244
dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding()))
12351245
if (isa<BlockedEncodingAttr>(dotLayout.getParent()))
12361246
return;
1237-
Value oldV = convertOp->getOperand(0);
1247+
Value oldV = convertOp.getSrc();
12381248
LDBG("check backward remat with source " << oldV << " encoding "
12391249
<< targetType.getEncoding());
12401250
// Check to see if there are existing remat'ed values for the pair of oldValue
1241-
// and encoding.
1242-
if (hasRematValue(oldV, targetType.getEncoding())) {
1251+
// and encoding. Make sure it dominates the current conversion.
1252+
Value newV = getRematValue(oldV, targetType.getEncoding());
1253+
if (newV && domInfo.properlyDominates(newV, convertOp)) {
12431254
// Replace it with the remat'ed value.
1244-
Value newV = getRematValue(oldV, targetType.getEncoding());
12451255
convertOp.replaceAllUsesWith(newV);
12461256
opToDelete.insert(convertOp);
12471257
LDBG("found remat'ed value" << newV);

0 commit comments

Comments
 (0)