
Commit 94efbc1

Sync getConvertBackwardSlice from upstream (#3329)
Changes come from upstream commits 24b8d43 and a6b15ef.
Signed-off-by: Whitney Tsang <[email protected]>
1 parent b9ba137 · commit 94efbc1

3 files changed: 100 additions & 39 deletions


third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h

Lines changed: 4 additions & 2 deletions
@@ -45,9 +45,11 @@ getDotEncoding(RankedTensorType tensorType);
 // Get backward slice of tensor values starting from the root node along with
 // encoding propagation.
 LogicalResult getConvertBackwardSlice(
-    Value root, SetVector<Value> &slice, Attribute rootEncoding,
+    OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr);
+    std::function<bool(Operation *)> stopPropagation = nullptr,
+    std::function<Value(OpOperand &, Attribute)> getExistingConversion =
+        nullptr);
 
 LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
                                        ArrayRef<Type> paramTypes,
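
For reference, a minimal caller-side sketch of the updated declaration (not part of this commit): it passes both optional callbacks, mirroring how RemoveLayoutConversions.cpp uses them below. It assumes the TritonIntelGPU headers and the ttgi namespace alias used in that file; the stop filter and the rematerialization cache shown here are hypothetical.

// Collect the backward slice of `use` under `targetEncoding`, reusing a
// previously materialized value when the (hypothetical) cache already has one.
LogicalResult collectSlice(OpOperand &use, Attribute targetEncoding,
                           SetVector<Value> &slice,
                           DenseMap<Value, Attribute> &layout,
                           const DenseMap<std::pair<Value, Attribute>, Value>
                               &rematCache /*hypothetical*/) {
  auto stopAtReduce = [](Operation *op) {
    return isa<triton::ReduceOp>(op); // hypothetical stopping condition
  };
  auto reuseExisting = [&](OpOperand &operand, Attribute encoding) -> Value {
    // Returning an empty Value() tells the walk to keep propagating normally.
    return rematCache.lookup({operand.get(), encoding});
  };
  return ttgi::getConvertBackwardSlice(use, slice, targetEncoding, layout,
                                       stopAtReduce, reuseExisting);
}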

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 44 additions & 14 deletions
@@ -154,17 +154,18 @@ class LayoutPropagation {
 class LayoutRematerialization {
 public:
   LayoutRematerialization(FuncOp F) : funcOp(F) {}
+
   // Map the original value to the remat'ed one.
   void addRematValue(Value old, Attribute encoding, Value newV);
+  // Get the remat'ed value in the given encoding, if one already exists and
+  // is different then the layout conversion root.
+  Value getRematValue(Value value, Attribute encoding) const {
+    return rematMapping.lookup({value, encoding});
+  }
+
   bool hasRematValue(Value value, Attribute encoding) {
     return rematMapping.contains({value, encoding});
   }
-  // Return the remat'ed value in the given encoding.
-  Value getRematValue(Value value, Attribute encoding) {
-    auto it = rematMapping.find({value, encoding});
-    assert(it != rematMapping.end());
-    return it->second;
-  }
   void cleanup();
   void backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -175,6 +176,11 @@ class LayoutRematerialization {
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
                     ConvertLayoutOp convertOp);
 
+  LogicalResult getRematerializableSlice(
+      OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
+      DenseMap<Value, Attribute> &layout,
+      std::function<bool(Operation *)> stopPropagation = nullptr);
+
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
   // Existing tuples of (value, layout) that needs to be updated when recreating
@@ -186,6 +192,7 @@ class LayoutRematerialization {
   // DenseMap<std::pair<Operation*, Attribute>, Operation*>
   SetVector<Operation *> opToDelete;
   FuncOp funcOp;
+  DominanceInfo domInfo;
 };
 
 void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
@@ -1188,10 +1195,33 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   rewriteSlice(slice, layout, convertOp, mapping);
 }
 
-LogicalResult getRematerializableSlice(
-    Value root, Attribute rootEncoding, SetVector<Value> &slice,
+LogicalResult LayoutRematerialization::getRematerializableSlice(
+    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr) {
+    std::function<bool(Operation *)> stopPropagation) {
+  // Allow re-using existing conversions for a value. Check dominance of any
+  // reusable materializations against the root value. This is sufficient
+  // because the conversions are processed in post-order.
+  auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
+    Value remat = getRematValue(value.get(), encoding);
+    if (!remat)
+      return Value();
+    // `value` can be replaced with an existing rematerialization if it
+    // dominates the current use of value.
+    Operation *user = value.getOwner();
+    if (domInfo.properlyDominates(remat, user)) {
+      return remat;
+    }
+    // Alternatively, if the current use can be sunk below the existing
+    // rematerialization, then it is okay to use as well. E.g. the current use
+    // is a conversion that will be folded away when its result is
+    // rematerialized.
+    if (isa<ConvertLayoutOp>(user) && remat.getDefiningOp() &&
+        domInfo.properlyDominates(user, remat.getDefiningOp())) {
+      return remat;
+    }
+    return Value();
+  };
   LogicalResult result = ttgi::getConvertBackwardSlice(
       root, slice, rootEncoding, layout, std::move(stopPropagation));
   if (result.failed() || slice.empty())
@@ -1255,7 +1285,7 @@ void LayoutRematerialization::backwardRematerialization(
   SetVector<Value> slice;
   DenseMap<Value, Attribute> layout;
   LogicalResult result = getRematerializableSlice(
-      convertOp.getSrc(), targetType.getEncoding(), slice, layout);
+      convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
   if (result.failed()) {
     LDBG(" getRematerializableSlice failed");
     return;
@@ -1287,9 +1317,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
   // 1. Take a backward slice of all the tensor dependencies.
   SetVector<Value> slice;
   DenseMap<Value, Attribute> layout;
-  LogicalResult result =
-      getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(),
-                               slice, layout, isExtOrBroadcastOp);
+  LogicalResult result = getRematerializableSlice(
+      convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
+      isExtOrBroadcastOp);
   if (result.failed())
     return;
 
@@ -1307,7 +1337,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
       if (!srcEncoding)
         return;
       LogicalResult result = getRematerializableSlice(
-          op->getOperand(0), srcEncoding, tempSlice, tempLayout);
+          op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
       // If we can rematerialize the rest of the ext slice we can ignore this
       // ext as it won't need a convert.
       if (result.succeeded()) {
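
The reuse rule implemented by the getExistingConversion lambda above has two legal cases: the cached rematerialization properly dominates the use, or the use is itself a layout conversion that can be sunk below the cached definition. Below is a toy, MLIR-free sketch of that decision (not part of this commit), where integer positions in a single straight-line block stand in for dominance, so "a dominates b" simply means a comes earlier.

#include <optional>

struct Cached {
  int defPos; // where the cached rematerialization is defined
};

// Mirrors the two accepted cases of the lambda; anything else means
// "keep propagating instead of reusing".
std::optional<Cached> reuseIfLegal(std::optional<Cached> cached, int usePos,
                                   bool useIsConvert) {
  if (!cached)
    return std::nullopt; // nothing cached for this (value, encoding)
  if (cached->defPos < usePos)
    return cached; // the rematerialization dominates the use
  if (useIsConvert && usePos < cached->defPos)
    return cached; // the convert can be sunk below the rematerialization
  return std::nullopt;
}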

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 52 additions & 23 deletions
@@ -149,41 +149,60 @@ static bool isFreeConvert(Operation *op) {
                             convertOp.getType());
 }
 
-LogicalResult
-getConvertBackwardSlice(Value root, SetVector<Value> &slice,
-                        Attribute rootEncoding,
-                        DenseMap<Value, Attribute> &layout,
-                        std::function<bool(Operation *)> stopPropagation) {
-  DenseSet<std::pair<Value, Attribute>> seen;
-  SmallVector<std::pair<Value, Attribute>> queue;
-
-  auto enqueue = [&](Value operand, Attribute encoding) {
-    auto x = std::make_pair(operand, encoding);
+LogicalResult getConvertBackwardSlice(
+    OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
+    DenseMap<Value, Attribute> &layout,
+    std::function<bool(Operation *)> stopPropagation,
+    std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
+  DenseSet<std::pair<OpOperand *, Attribute>> seen;
+  SmallVector<std::pair<OpOperand *, Attribute>> queue;
+
+  auto enqueue = [&](OpOperand &operand, Attribute encoding) {
+    auto x = std::make_pair(&operand, encoding);
     if (!seen.insert(x).second) {
       return; // Already enqueued, skip
     }
     queue.push_back(x);
   };
   enqueue(root, rootEncoding);
 
+  auto updateLayout = [&](Value value, Attribute encoding) {
+    assert(isTensorOrTensorPointerType(value.getType()));
+    slice.insert(value);
+    if (layout.find(value) != layout.end()) {
+      if (layout[value] != encoding)
+        return failure();
+    }
+    layout[value] = encoding;
+    return success();
+  };
+
   while (!queue.empty()) {
-    auto [currentValue, encoding] = queue.back();
+    auto [currentValueUse, encoding] = queue.back();
+    Value currentValue = currentValueUse->get();
     queue.pop_back();
     if (!isTensorOrTensorPointerType(currentValue.getType()))
       continue;
-    slice.insert(currentValue);
-    if (layout.find(currentValue) != layout.end()) {
-      if (layout[currentValue] != encoding)
+    // Skip propagating through for op results for now.
+    // TODO: enable this based on needs.
+    if (currentValue.getDefiningOp<scf::ForOp>())
+      return failure();
+    if (failed(updateLayout(currentValue, encoding)))
+      return failure();
+
+    Value existing;
+    if (getExistingConversion &&
+        (existing = getExistingConversion(*currentValueUse, encoding))) {
+      if (failed(updateLayout(existing, encoding)))
         return failure();
+      currentValue = existing;
     }
-    layout[currentValue] = encoding;
 
     if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
-      auto results = ifOp.getResults();
       unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
 
-      auto thenValue = ifOp.thenYield().getOperand(argIdx);
-      auto elseValue = ifOp.elseYield().getOperand(argIdx);
+      OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx);
+      OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx);
 
       enqueue(thenValue, encoding);
       enqueue(elseValue, encoding);
@@ -196,10 +215,11 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
         if (result == currentValue ||
             !isTensorOrTensorPointerType(result.getType()))
           continue;
-        enqueue(result, encoding);
+        if (failed(updateLayout(result, encoding)))
+          return failure();
       }
       if (isFreeConvert(definingOp)) {
-        enqueue(definingOp->getOperand(0), encoding);
+        enqueue(definingOp->getOpOperand(0), encoding);
         continue;
       }
       if (canFoldIntoConversion(definingOp, encoding))
@@ -208,7 +228,16 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
         continue;
       if (isa<triton::CatOp>(definingOp))
         return failure();
-      for (Value operand : definingOp->getOperands()) {
+      if (auto gather = dyn_cast<GatherOp>(definingOp)) {
+        // Specially handle gather since its transfer function only applies
+        // between its index operand and result.
+        auto srcEncoding = ttgi::inferSrcEncoding(gather, encoding);
+        if (!srcEncoding)
+          return failure();
+        enqueue(gather.getIndicesMutable(), srcEncoding);
+        continue;
+      }
+      for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) {
        auto srcEncoding = ttgi::inferSrcEncoding(definingOp, encoding);
        if (!srcEncoding)
          return failure();
@@ -221,9 +250,9 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
      Operation *parentOp = block->getParentOp();
      if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
        OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
-        Value yieldOperand = forOp.getBody()->getTerminator()->getOperand(
+        OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand(
            blockArg.getArgNumber() - forOp.getNumInductionVars());
-        enqueue(initOperand->get(), encoding);
+        enqueue(*initOperand, encoding);
        enqueue(yieldOperand, encoding);
        continue;
      }
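
The rewritten walk in Utility.cpp follows a standard backward-slice worklist: a seen set keyed on the (use, encoding) pair, a queue, and an updateLayout step that fails as soon as the same value is reached with two different encodings. A compact, MLIR-free sketch of that pattern follows (not part of this commit); Node is a hypothetical stand-in for an SSA value, and, unlike the real code, the encoding is propagated to operands unchanged rather than being mapped through inferSrcEncoding.

#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node *> operands; // stand-in for the defining op's operands
};

// Returns false when a node would need two different encodings at once,
// mirroring the failure() path of updateLayout.
bool backwardSlice(Node *root, const std::string &rootEncoding,
                   std::map<Node *, std::string> &layout) {
  std::set<std::pair<Node *, std::string>> seen;
  std::vector<std::pair<Node *, std::string>> queue{{root, rootEncoding}};
  seen.insert({root, rootEncoding});
  while (!queue.empty()) {
    auto [node, encoding] = queue.back();
    queue.pop_back();
    auto it = layout.find(node);
    if (it != layout.end() && it->second != encoding)
      return false; // conflicting encodings: give up on this slice
    layout[node] = encoding;
    for (Node *operand : node->operands)
      if (seen.insert({operand, encoding}).second)
        queue.emplace_back(operand, encoding);
  }
  return true;
}

int main() {
  Node a{"a", {}}, b{"b", {&a}}, c{"c", {&a, &b}};
  std::map<Node *, std::string> layout;
  std::printf("slice ok = %d\n", backwardSlice(&c, "blocked", layout));
  return 0;
}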
