Skip to content

[Draft] [BACKEND] Enhance the remove-layout implementation to reduce duplicated values with different layouts in scf.for. #4527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ LogicalResult getConvertBackwardSlice(
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr,
std::function<Value(OpOperand &, Attribute)> getExistingConversion =
nullptr);
nullptr,
bool includeForOp = false);

LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
ArrayRef<Type> paramTypes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,22 @@ class LayoutRematerialization {
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation);
std::function<bool(Operation *)> stopPropagation,
bool includeForOp = false);

LogicalResult getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr);
std::function<bool(Operation *)> stopPropagation = nullptr,
bool includeForOp = false);

private:
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
// Existing tuples of (value, layout) that need to be updated when recreating
// scf ops. This prevents keeping track of Values that have been deleted when
// rewriting slices.
DenseMap<Value, Attribute> mappedValues;
// rewriting slices. A Value may be mapped to different attributes during
// layout removal.
DenseMap<Value, SmallVector<Attribute>> mappedValues;
// Map of the rematerialized values, keyed on (value, encoding).
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
Expand All @@ -174,7 +177,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
Value newV) {
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
rematMapping[{old, encoding}] = newV;
mappedValues[old] = encoding;
if (mappedValues.contains(old)) {
mappedValues[old].push_back(encoding);
} else {
mappedValues[old] = {encoding};
}
}

// Remove unneeded values now that we are done with the rematMapping.
Expand Down Expand Up @@ -955,22 +962,28 @@ void LayoutRematerialization::updateRematMapping(
for (auto [old, newV] : values) {
auto it = mappedValues.find(old);
if (it != mappedValues.end()) {
Attribute encoding = it->second;
auto rematIt = rematMapping.find({old, it->second});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
mappedValues.erase(it);
// Loop through the replacement value to find the new version of remat
// value. This should be okay as the number of values should be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
SmallVector<Attribute> encodings = it->second;
for (auto encoding : encodings) {
auto rematIt = rematMapping.find({old, encoding});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
// Loop through the replacement value to find the new version of remat
// value. This should be okay as the number of values should be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
}
}
rematMapping[{newV, encoding}] = replacedValue;
}
mappedValues.erase(it);
if (mappedValues.contains(newV)) {
mappedValues[newV].append(encodings);
} else {
mappedValues[newV] = std::move(encodings);
}
rematMapping[{newV, encoding}] = replacedValue;
mappedValues[newV] = encoding;
}
}
}
Expand Down Expand Up @@ -1045,6 +1058,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
deadOps.push_back(forOp.getOperation());
Block &loopBody = *newForOp.getBody();
for (auto m : argMapping) {
mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second));
mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second));
int numIndVars = newForOp.getNumInductionVars();
mapping.map(loopBody.getArgument(m.first + numIndVars),
Expand Down Expand Up @@ -1161,8 +1175,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
}

for (Operation *op : deadOps)
opToDelete.insert(op);
for (Operation *op : deadOps) {
if (!isa<scf::ForOp>(op))
opToDelete.insert(op);
else
op->erase();
}
}

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
Expand All @@ -1175,7 +1193,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
LogicalResult LayoutRematerialization::getConvertBackwardSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation) {
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
// Allow re-using existing conversions for a value. Check dominance of any
// reusable materializations against the root value. This is sufficient
// because the conversions are processed in post-order.
Expand Down Expand Up @@ -1204,15 +1222,18 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
};

return ttgi::getConvertBackwardSlice(root, slice, rootEncoding, layout,
stopPropagation, getExistingConversion);
stopPropagation, getExistingConversion,
includeForOp);
}

LogicalResult LayoutRematerialization::getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation) {
LogicalResult result = getConvertBackwardSlice(
root, rootEncoding, slice, layout, std::move(stopPropagation));
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {

LogicalResult result =
getConvertBackwardSlice(root, rootEncoding, slice, layout,
std::move(stopPropagation), includeForOp);
if (result.failed() || slice.empty())
return failure();

Expand Down Expand Up @@ -1301,8 +1322,9 @@ void LayoutRematerialization::backwardRematerialization(
// rematerialized.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result = getRematerializableSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
targetType.getEncoding(),
slice, layout, nullptr, true);
if (result.failed()) {
LDBG(" getRematerializableSlice failed");
return;
Expand Down
41 changes: 39 additions & 2 deletions third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ LogicalResult getConvertBackwardSlice(
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation,
std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
std::function<Value(OpOperand &, Attribute)> getExistingConversion,
bool includeForOp) {
DenseSet<std::pair<OpOperand *, Attribute>> seen;
SmallVector<std::pair<OpOperand *, Attribute>> queue;

Expand All @@ -197,6 +198,12 @@ LogicalResult getConvertBackwardSlice(

auto updateLayout = [&](Value value, Attribute encoding) {
assert(isTensorOrTensorPointerType(value.getType()));
auto tensorType = getRankedTensorType(value.getType());
auto originEncoding = tensorType.getEncoding();
if (originEncoding == encoding) {
return success();
}

slice.insert(value);
Attribute &existing = layout[value];
if (existing && existing != encoding)
Expand All @@ -213,8 +220,38 @@ LogicalResult getConvertBackwardSlice(
continue;
// Propagating through for op results is only supported when includeForOp is
// set; otherwise bail out. TODO: enable this by default based on needs.
if (currentValue.getDefiningOp<scf::ForOp>())
if (auto forOp = currentValue.getDefiningOp<scf::ForOp>()) {
if (!includeForOp)
return failure();
if (stopPropagation && stopPropagation(forOp))
continue;
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
int numIndVars = forOp.getNumInductionVars();
Block &loopBody = *forOp.getBody();
auto blockArg = loopBody.getArgument(argIdx + numIndVars);

Value existing;
if (getExistingConversion &&
(existing = getExistingConversion(*currentValueUse, encoding))) {
if (failed(updateLayout(currentValue, encoding)))
return failure();

continue;
}
return failure();

OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
OpOperand &yieldOperand = loopBody.getTerminator()->getOpOperand(argIdx);
llvm::outs() << "johnlu getBackward slice check scf.for initOperand: "
<< initOperand->get() << "\n";
llvm::outs() << "johnlu getBackward slice check scf.for yieldOperand: "
<< yieldOperand.get() << "\n";
if (failed(updateLayout(blockArg, encoding)))
return failure();
enqueue(*initOperand, encoding);
enqueue(yieldOperand, encoding);
continue;
}
if (failed(updateLayout(currentValue, encoding)))
return failure();

Expand Down
Loading