[RemoveLayoutConversions]: Reduce loop carried values - part 2 #4921
base: main
Changes from all commits
cd237ff
11fff55
cad2f23
155d5fc
ed8cf48
@@ -25,6 +25,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
#include <deque>

namespace mlir::triton::gpu::intel {
@@ -166,6 +167,7 @@ class LayoutRematerialization {

private:
  void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
  void reduceLoopCarriedValues();
  // Existing tuples of (value, layout) that need to be updated when recreating
  // scf ops. This prevents keeping track of Values that have been deleted when
  // rewriting slices.

@@ -1008,6 +1010,93 @@ void LayoutRematerialization::updateRematMapping(
  }
}

/// Reduce the number of loop carried values: a loop-carried value is redundant
/// if its uses after the loop can be served by another loop yielded value plus
/// a convert layout operation.
void LayoutRematerialization::reduceLoopCarriedValues() {
  for (auto [pair, val] : rematMapping) {
    if (!isa<BlockArgument>(pair.first))
      continue;

    auto arg = cast<BlockArgument>(pair.first);
    if (!isTensorPointerType(arg.getType()))
      continue;

    auto loopOp = dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp());
    if (!loopOp)
      continue;

    // Loop arguments that correspond to an unused loop result are not
    // interesting.
    OpResult loopRes = loopOp.getTiedLoopResult(arg);
    if (loopRes.getNumUses() == 0)
      continue;

    std::function<void(Operation *, Value)> processUser = [&](Operation *user,
                                                              Value rematRes) {
      Location loc = user->getLoc();
      OpBuilder rewriter(user);

      TypeSwitch<Operation *>(user)
          .Case<LoadOp>([&](auto loadOp) {
            auto newLoadOp =
                rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
            auto convOp = rewriter.create<ConvertLayoutOp>(
                loc, loadOp.getType(), newLoadOp.getResult());
            loadOp->replaceAllUsesWith(convOp);
            opToDelete.insert(loadOp);
            LLVM_DEBUG({
              DBGS() << "Replaced:\n\t" << *loadOp << "\n"
                     << "with:\n\t" << *newLoadOp << "\n"
                     << "\t" << *convOp << "\n";
            });
          })
          .Case<StoreOp>([&](auto storeOp) {
            Value data = storeOp.getOperand(1);
            auto dataType = cast<RankedTensorType>(data.getType());
            auto newPtrType = cast<PointerType>(rematRes.getType());
            Attribute encoding =
                cast<RankedTensorType>(newPtrType.getPointeeType())
                    .getEncoding();
            RankedTensorType newDataType = dataType.cloneWithEncoding(encoding);
            auto convOp =
                rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
            auto newStoreOp = rewriter.create<StoreOp>(
                loc, rematRes, convOp, storeOp.getBoundaryCheck(),
                storeOp.getCache(), storeOp.getEvict());
            opToDelete.insert(storeOp);
            LLVM_DEBUG({
              DBGS() << "Replaced:\n\t" << *storeOp << "\n"
                     << "with:\n\t" << *convOp << "\n"
                     << "\t" << *newStoreOp << "\n";
            });
          })
          .Case<AdvanceOp>([&](auto advanceOp) {
            auto newAdvanceOp = rewriter.create<AdvanceOp>(
                loc, rematRes.getType(), rematRes, advanceOp.getOffsets());
            opToDelete.insert(advanceOp);
            LLVM_DEBUG({
              DBGS() << "Replaced:\n\t" << *advanceOp << "\n"
                     << "with:\n\t" << *newAdvanceOp << "\n";
            });

            for (Operation *user : advanceOp->getUsers())
              processUser(user, newAdvanceOp.getResult());
          })
          .Default([](auto op) {
            llvm::report_fatal_error(llvm::Twine(
                "Unsupported operation in backward rematerialization: '" +
                op->getName().getStringRef() + "'"));
          });
    };

    // Replace the loop result corresponding to the argument with an
    // equivalent loop result.
    OpResult rematRes = loopOp.getTiedLoopResult(cast<BlockArgument>(val));
Copilot review comment:
The cast(val) is redundant and potentially unsafe. The variable
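A minimal sketch of the checked-cast variant this comment appears to suggest; the guard and the name valArg are illustrative, not from the patch:

// Hedged sketch: use a guarded dyn_cast so an entry whose 'val' is not a
// BlockArgument is skipped instead of tripping the assertion inside cast<>.
auto valArg = dyn_cast<BlockArgument>(val);
if (!valArg)
  continue;
OpResult rematRes = loopOp.getTiedLoopResult(valArg);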

    for (Operation *user : loopRes.getUsers())
      processUser(user, rematRes);
  }
}
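
// A hand-written sketch (an illustration, not verbatim IR from the patch or
// its tests) of the intended effect, assuming a loop that yields the same
// tensor pointer in two layouts, e.g. #blocked and #dpas:
//
//   %res:2 = scf.for ... -> (!tt.ptr<tensor<128x64xf16, #blocked>>,
//                            !tt.ptr<tensor<128x64xf16, #dpas>>)
//   %v = tt.load %res#0 ...          // post-loop use of the #blocked pointer
//
// becomes
//
//   %v2 = tt.load %res#1 ...         // load through the rematerialized pointer
//   %v = ttg.convert_layout %v2 ...  // convert back to the #blocked layout
//
// leaving %res#0 dead so the loop carried value can later be removed.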

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
                                           DenseMap<Value, Attribute> &layout,
                                           ConvertLayoutOp convertOp,
@@ -1267,6 +1356,8 @@ void LayoutRematerialization::backwardRematerialization() {
          convertOp.getResult());
    }
  }

  reduceLoopCarriedValues();
}

void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {

Copilot review comment:
The error message in the Default case uses string concatenation with Twine, which may not work as expected. Consider using llvm::formatv or constructing the error message differently:
llvm::formatv("Unsupported operation in backward rematerialization: '{0}'", op->getName().getStringRef())
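A sketch of how that suggestion could look in the .Default handler; formatting the message eagerly into a std::string (via llvm/Support/FormatVariadic.h) also sidesteps any Twine lifetime concern. This is an illustration, not part of the patch:

.Default([](auto op) {
  // Render the message up front with llvm::formatv, then hand the owned
  // string to report_fatal_error.
  std::string msg =
      llvm::formatv(
          "Unsupported operation in backward rematerialization: '{0}'",
          op->getName().getStringRef())
          .str();
  llvm::report_fatal_error(msg.c_str());
});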