Skip to content

Commit f42bd66

Browse files
committed
Temp enhance the remove layout implementation to reduce the duplicated values with different layout in scf.for.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 8b73e5a commit f42bd66

File tree

3 files changed

+91
-31
lines changed

3 files changed

+91
-31
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ LogicalResult getConvertBackwardSlice(
5050
DenseMap<Value, Attribute> &layout,
5151
std::function<bool(Operation *)> stopPropagation = nullptr,
5252
std::function<Value(OpOperand &, Attribute)> getExistingConversion =
53-
nullptr);
53+
nullptr,
54+
bool includeForOp = false);
5455

5556
LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
5657
ArrayRef<Type> paramTypes,

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -149,19 +149,22 @@ class LayoutRematerialization {
149149
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
150150
SetVector<Value> &slice,
151151
DenseMap<Value, Attribute> &layout,
152-
std::function<bool(Operation *)> stopPropagation);
152+
std::function<bool(Operation *)> stopPropagation,
153+
bool includeForOp = false);
153154

154155
LogicalResult getRematerializableSlice(
155156
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
156157
DenseMap<Value, Attribute> &layout,
157-
std::function<bool(Operation *)> stopPropagation = nullptr);
158+
std::function<bool(Operation *)> stopPropagation = nullptr,
159+
bool includeForOp = false);
158160

159161
private:
160162
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
161163
// Existing tuples of (value, layout) that needs to be updated when recreating
162164
// scf ops. This prevents keeping track of Values that have been delete when
163-
// rewriting slices.
164-
DenseMap<Value, Attribute> mappedValues;
165+
// rewriting slices. The Value maybe mapped to different attributes in remove
166+
// layout.
167+
DenseMap<Value, SmallVector<Attribute>> mappedValues;
165168
// map of the values remat based on encoding.
166169
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
167170
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -174,7 +177,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
174177
Value newV) {
175178
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
176179
rematMapping[{old, encoding}] = newV;
177-
mappedValues[old] = encoding;
180+
if (mappedValues.contains(old)) {
181+
mappedValues[old].push_back(encoding);
182+
} else {
183+
mappedValues[old] = {encoding};
184+
}
178185
}
179186

180187
// Remove unneeded values now that we are done with the rematMapping.
@@ -955,22 +962,28 @@ void LayoutRematerialization::updateRematMapping(
955962
for (auto [old, newV] : values) {
956963
auto it = mappedValues.find(old);
957964
if (it != mappedValues.end()) {
958-
Attribute encoding = it->second;
959-
auto rematIt = rematMapping.find({old, it->second});
960-
assert(rematIt != rematMapping.end());
961-
Value replacedValue = rematIt->second;
962-
rematMapping.erase(rematIt);
963-
mappedValues.erase(it);
964-
// Loop through the replacement value to find the new version of remat
965-
// value. This should be okay as the number of values should be small.
966-
for (auto [before, after] : values) {
967-
if (before == replacedValue) {
968-
replacedValue = after;
969-
break;
965+
SmallVector<Attribute> encodings = it->second;
966+
for (auto encoding : encodings) {
967+
auto rematIt = rematMapping.find({old, encoding});
968+
assert(rematIt != rematMapping.end());
969+
Value replacedValue = rematIt->second;
970+
rematMapping.erase(rematIt);
971+
// Loop through the replacement value to find the new version of remat
972+
// value. This should be okay as the number of values should be small.
973+
for (auto [before, after] : values) {
974+
if (before == replacedValue) {
975+
replacedValue = after;
976+
break;
977+
}
970978
}
979+
rematMapping[{newV, encoding}] = replacedValue;
980+
}
981+
mappedValues.erase(it);
982+
if (mappedValues.contains(newV)) {
983+
mappedValues[newV].append(encodings);
984+
} else {
985+
mappedValues[newV] = std::move(encodings);
971986
}
972-
rematMapping[{newV, encoding}] = replacedValue;
973-
mappedValues[newV] = encoding;
974987
}
975988
}
976989
}
@@ -1045,6 +1058,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10451058
deadOps.push_back(forOp.getOperation());
10461059
Block &loopBody = *newForOp.getBody();
10471060
for (auto m : argMapping) {
1061+
mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second));
10481062
mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second));
10491063
int numIndVars = newForOp.getNumInductionVars();
10501064
mapping.map(loopBody.getArgument(m.first + numIndVars),
@@ -1161,8 +1175,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11611175
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
11621176
}
11631177

1164-
for (Operation *op : deadOps)
1165-
opToDelete.insert(op);
1178+
for (Operation *op : deadOps) {
1179+
if (!isa<scf::ForOp>(op))
1180+
opToDelete.insert(op);
1181+
else
1182+
op->erase();
1183+
}
11661184
}
11671185

11681186
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
@@ -1175,7 +1193,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11751193
LogicalResult LayoutRematerialization::getConvertBackwardSlice(
11761194
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
11771195
DenseMap<Value, Attribute> &layout,
1178-
std::function<bool(Operation *)> stopPropagation) {
1196+
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
11791197
// Allow re-using existing conversions for a value. Check dominance of any
11801198
// reusable materializations against the root value. This is sufficient
11811199
// because the conversions are processed in post-order.
@@ -1204,15 +1222,18 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12041222
};
12051223

12061224
return ttgi::getConvertBackwardSlice(root, slice, rootEncoding, layout,
1207-
stopPropagation, getExistingConversion);
1225+
stopPropagation, getExistingConversion,
1226+
includeForOp);
12081227
}
12091228

12101229
LogicalResult LayoutRematerialization::getRematerializableSlice(
12111230
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12121231
DenseMap<Value, Attribute> &layout,
1213-
std::function<bool(Operation *)> stopPropagation) {
1214-
LogicalResult result = getConvertBackwardSlice(
1215-
root, rootEncoding, slice, layout, std::move(stopPropagation));
1232+
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
1233+
1234+
LogicalResult result =
1235+
getConvertBackwardSlice(root, rootEncoding, slice, layout,
1236+
std::move(stopPropagation), includeForOp);
12161237
if (result.failed() || slice.empty())
12171238
return failure();
12181239

@@ -1301,8 +1322,9 @@ void LayoutRematerialization::backwardRematerialization(
13011322
// rematerialized.
13021323
SetVector<Value> slice;
13031324
DenseMap<Value, Attribute> layout;
1304-
LogicalResult result = getRematerializableSlice(
1305-
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
1325+
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
1326+
targetType.getEncoding(),
1327+
slice, layout, nullptr, true);
13061328
if (result.failed()) {
13071329
LDBG(" getRematerializableSlice failed");
13081330
return;

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ LogicalResult getConvertBackwardSlice(
182182
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
183183
DenseMap<Value, Attribute> &layout,
184184
std::function<bool(Operation *)> stopPropagation,
185-
std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
185+
std::function<Value(OpOperand &, Attribute)> getExistingConversion,
186+
bool includeForOp) {
186187
DenseSet<std::pair<OpOperand *, Attribute>> seen;
187188
SmallVector<std::pair<OpOperand *, Attribute>> queue;
188189

@@ -197,6 +198,12 @@ LogicalResult getConvertBackwardSlice(
197198

198199
auto updateLayout = [&](Value value, Attribute encoding) {
199200
assert(isTensorOrTensorPointerType(value.getType()));
201+
auto tensorType = getRankedTensorType(value.getType());
202+
auto originEncoding = tensorType.getEncoding();
203+
if (originEncoding == encoding) {
204+
return success();
205+
}
206+
200207
slice.insert(value);
201208
Attribute &existing = layout[value];
202209
if (existing && existing != encoding)
@@ -213,8 +220,38 @@ LogicalResult getConvertBackwardSlice(
213220
continue;
214221
// Skip propagating through for op results for now.
215222
// TODO: enable this based on needs.
216-
if (currentValue.getDefiningOp<scf::ForOp>())
223+
if (auto forOp = currentValue.getDefiningOp<scf::ForOp>()) {
224+
if (!includeForOp)
225+
return failure();
226+
if (stopPropagation && stopPropagation(forOp))
227+
continue;
228+
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
229+
int numIndVars = forOp.getNumInductionVars();
230+
Block &loopBody = *forOp.getBody();
231+
auto blockArg = loopBody.getArgument(argIdx + numIndVars);
232+
233+
Value existing;
234+
if (getExistingConversion &&
235+
(existing = getExistingConversion(*currentValueUse, encoding))) {
236+
if (failed(updateLayout(currentValue, encoding)))
237+
return failure();
238+
239+
continue;
240+
}
217241
return failure();
242+
243+
OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
244+
OpOperand &yieldOperand = loopBody.getTerminator()->getOpOperand(argIdx);
245+
llvm::outs() << "johnlu getBackward slice check scf.for initOperand: "
246+
<< initOperand->get() << "\n";
247+
llvm::outs() << "johnlu getBackward slice check scf.for yieldOperand: "
248+
<< yieldOperand.get() << "\n";
249+
if (failed(updateLayout(blockArg, encoding)))
250+
return failure();
251+
enqueue(*initOperand, encoding);
252+
enqueue(yieldOperand, encoding);
253+
continue;
254+
}
218255
if (failed(updateLayout(currentValue, encoding)))
219256
return failure();
220257

0 commit comments

Comments
 (0)