Skip to content

Commit 3f682b1

Browse files
committed
[XPU][Alloc] Use upstream interface to specialize Allocation analysis
Defie custom scratch memory size getter to specialize Allocation analysis. Signed-off-by: victor-eds <[email protected]>
1 parent 31bfb67 commit 3f682b1

File tree

11 files changed

+44
-639
lines changed

11 files changed

+44
-639
lines changed

include/triton/Analysis/Allocation.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,6 @@ class Allocation {
249249
size_t sharedMemorySize = 0;
250250
};
251251

252-
template <>
253-
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);
254-
255252
/// Static analysis that computes the allocation of shared memory buffers
256253
/// of the entire call graph.
257254
/// The allocation is performed in a post-order walk of the call graph.
@@ -271,11 +268,10 @@ class ModuleAllocation : public CallGraph<Allocation> {
271268
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
272269
// Post-order node walk callback
273270
[&](FunctionOpInterface funcOp) {
274-
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
271+
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
275272
if (inserted)
276273
iter->second.run(funcMap, scratchSizeGetter);
277274
});
278-
return res;
279275
}
280276

281277
size_t getSharedMemorySize() {
@@ -300,9 +296,6 @@ class ModuleAllocation : public CallGraph<Allocation> {
300296
}
301297

302298
private:
303-
explicit ModuleAllocation(ModuleOp moduleOp)
304-
: CallGraph<Allocation>(moduleOp) {}
305-
306299
FuncOffsetMapT sharedMemoryValue;
307300
};
308301

lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct AllocateSharedMemory
2323
void runOnOperation() override {
2424
ModuleOp mod = getOperation();
2525
MLIRContext *ctx = &getContext();
26-
ModuleAllocation allocation = ModuleAllocation::get(mod);
26+
ModuleAllocation allocation(mod);
2727

2828
mod.walk([&](FunctionOpInterface funcOp) {
2929
funcOp.walk([&](Operation *op) {

python/src/passes.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ namespace py = pybind11;
1717

1818
void init_triton_analysis(py::module &&m) {
1919
py::class_<mlir::ModuleAllocation>(m, "allocation", py::module_local())
20-
.def(py::init(
21-
&mlir::ModuleAllocation::get<mlir::triton::AllocationAnalysis>));
20+
.def(py::init<mlir::ModuleOp>());
2221
py::class_<mlir::ModuleMembarAnalysis>(m, "membar", py::module_local())
2322
.def(py::init<mlir::ModuleAllocation *>())
2423
.def("run", &mlir::ModuleMembarAnalysis::run);

test/lib/Analysis/TestMembar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct TestMembarPass
2525
Operation *operation = getOperation();
2626
ModuleOp moduleOp = cast<ModuleOp>(operation);
2727
// Print all ops after membar pass
28-
ModuleAllocation allocation = ModuleAllocation::get(moduleOp);
28+
ModuleAllocation allocation(moduleOp);
2929
ModuleMembarAnalysis membarPass(&allocation,
3030
mlir::triton::NVIDIA::canSkipBarSync);
3131
membarPass.run();

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class OptimizeAMDLDSUsage
231231
LDSLimit = targetInfo.getSharedMemorySize();
232232
}
233233

234-
ModuleAllocation allocAnalysis = ModuleAllocation::get(mod);
234+
ModuleAllocation allocAnalysis(mod);
235235
if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
236236
return;
237237

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ struct ConvertTritonAMDGPUToLLVM
107107
}
108108

109109
// Allocate shared memory and set barrier
110-
ModuleAllocation allocation = ModuleAllocation::get(mod);
110+
ModuleAllocation allocation(mod);
111111
ModuleMembarAnalysis membarPass(&allocation);
112112
membarPass.run();
113113

third_party/intel/include/Analysis/Allocation.h

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,8 @@
33

44
#include "triton/Analysis/Allocation.h"
55

6-
namespace mlir {
7-
namespace triton::intel {
8-
class AllocationAnalysis;
9-
} // namespace triton::intel
10-
template <>
11-
void Allocation::run<triton::intel::AllocationAnalysis>(
12-
FuncAllocMapT &funcAllocMap);
13-
} // namespace mlir
6+
namespace mlir::triton::intel {
7+
unsigned allocationAnalysisScratchSizeFn(Operation *op);
8+
} // namespace mlir::triton::intel
149

1510
#endif

0 commit comments

Comments
 (0)