Skip to content

Commit 08786ac

Browse files
committed
[Triton][Analysis] Support running different allocation analyses
Parametrize `Allocation::run` to support more runnning more allocation analyses, thus making the allocation analysis framework more easily extensible. `ModuleAllocation` instances are now obtained from a templated `get` static member function to propagate the analysis kind to use. Signed-off-by: victor-eds <[email protected]>
1 parent 04a7b65 commit 08786ac

File tree

11 files changed

+620
-15
lines changed

11 files changed

+620
-15
lines changed

include/triton/Analysis/Allocation.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class Allocation {
102102
explicit Allocation(Operation *operation) : operation(operation) {}
103103

104104
/// Runs allocation analysis on the given top-level operation.
105-
void run(FuncAllocMapT &funcAllocMap);
105+
template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);
106106

107107
/// Returns the operation this analysis was constructed from.
108108
Operation *getOperation() const { return operation; }
@@ -240,6 +240,9 @@ class Allocation {
240240
friend class triton::AllocationAnalysis;
241241
};
242242

243+
template <>
244+
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);
245+
243246
/// Static analysis that computes the allocation of shared memory buffers
244247
/// of the entire call graph.
245248
/// The allocation is performed in a post-order walk of the call graph.
@@ -250,17 +253,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
250253
public:
251254
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
252255

253-
explicit ModuleAllocation(ModuleOp moduleOp)
254-
: CallGraph<Allocation>(moduleOp) {
255-
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
256+
template <typename AllocationAnalysis = triton::AllocationAnalysis>
257+
static ModuleAllocation get(ModuleOp moduleOp) {
258+
ModuleAllocation res(moduleOp);
259+
res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
256260
// Pre-order edge walk callback
257261
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
258262
// Post-order node walk callback
259263
[&](FunctionOpInterface funcOp) {
260-
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
264+
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
261265
if (inserted)
262-
iter->second.run(funcMap);
266+
iter->second.template run<AllocationAnalysis>(res.funcMap);
263267
});
268+
return res;
264269
}
265270

266271
size_t getSharedMemorySize() {
@@ -285,6 +290,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
285290
}
286291

287292
private:
293+
explicit ModuleAllocation(ModuleOp moduleOp)
294+
: CallGraph<Allocation>(moduleOp) {}
295+
288296
FuncOffsetMapT sharedMemoryValue;
289297
};
290298

lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,8 @@ class AllocationAnalysis {
588588

589589
} // namespace triton
590590

591-
void Allocation::run(FuncAllocMapT &funcAllocMap) {
591+
template <>
592+
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
592593
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
593594
}
594595

lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct AllocateSharedMemory
2323
void runOnOperation() override {
2424
ModuleOp mod = getOperation();
2525
MLIRContext *ctx = &getContext();
26-
ModuleAllocation allocation(mod);
26+
ModuleAllocation allocation = ModuleAllocation::get(mod);
2727

2828
mod.walk([&](FunctionOpInterface funcOp) {
2929
funcOp.walk([&](Operation *op) {

test/lib/Analysis/TestAllocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct TestAllocationPass
1919
auto &os = llvm::errs();
2020
ModuleOp moduleOp = getOperation();
2121
// Convert to std::string can remove quotes from opName
22-
ModuleAllocation moduleAllocation(moduleOp);
22+
ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp);
2323
moduleOp.walk([&](triton::FuncOp funcOp) {
2424
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
2525
os << opName << "\n";

test/lib/Analysis/TestMembar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct TestMembarPass
2525
Operation *operation = getOperation();
2626
ModuleOp moduleOp = cast<ModuleOp>(operation);
2727
// Print all ops after membar pass
28-
ModuleAllocation allocation(moduleOp);
28+
ModuleAllocation allocation = ModuleAllocation::get(moduleOp);
2929
ModuleMembarAnalysis membarPass(&allocation,
3030
mlir::triton::NVIDIA::canSkipBarSync);
3131
membarPass.run();

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class OptimizeAMDLDSUsage
221221
LDSLimit = targetInfo.getSharedMemorySize();
222222
}
223223

224-
ModuleAllocation allocAnalysis(mod);
224+
ModuleAllocation allocAnalysis = ModuleAllocation::get(mod);
225225
if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
226226
return;
227227

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ struct ConvertTritonAMDGPUToLLVM
104104
}
105105

106106
// Allocate shared memory and set barrier
107-
ModuleAllocation allocation(mod);
107+
ModuleAllocation allocation = ModuleAllocation::get(mod);
108108
ModuleMembarAnalysis membarPass(&allocation);
109109
membarPass.run();
110110

0 commit comments

Comments
 (0)