@@ -93,6 +93,45 @@ class Allocation {
9393 using BufferIdSetT = DenseSet<BufferId>;
9494 using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;
9595
96+ // / A class that represents a shared memory buffer
97+ struct BufferT {
98+ // / Explicit: triton_gpu.local_alloc
99+ // / Scratch: triton_gpu.convert_layout
100+ // / Virtual: triton.call
101+ enum class BufferKind { Explicit, Scratch, Virtual };
102+
103+ // / MT: thread-safe
104+ inline static std::atomic<BufferId> nextId = 0 ;
105+
106+ BufferKind kind;
107+ BufferId id;
108+ size_t size;
109+ size_t alignment;
110+ size_t offset;
111+
112+ bool operator ==(const BufferT &other) const { return id == other.id ; }
113+ bool operator <(const BufferT &other) const { return id < other.id ; }
114+
115+ BufferT () : BufferT(BufferKind::Explicit, 0 ) {}
116+ BufferT (BufferKind kind, size_t size, size_t alignment = 4 ,
117+ size_t offset = 0 )
118+ : kind(kind), id(nextId++), size(size), alignment(alignment),
119+ offset (offset) {}
120+
121+ size_t setOffsetAligned (size_t newOffset) {
122+ return offset = llvm::alignTo (newOffset, alignment);
123+ }
124+ };
125+
126+ // / Op -> Scratch Buffer
127+ using OpScratchMapT = DenseMap<Operation *, BufferT *>;
128+ // / Value -> Explicit Buffer
129+ using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
130+ // / Value -> Alias Buffer
131+ using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
132+ // / BufferId -> Buffer
133+ using BufferSetT = std::map<BufferId, BufferT>;
134+
96135 static constexpr BufferId InvalidBufferId =
97136 std::numeric_limits<BufferId>::max();
98137
@@ -102,11 +141,17 @@ class Allocation {
102141 explicit Allocation (Operation *operation) : operation(operation) {}
103142
104143 // / Runs allocation analysis on the given top-level operation.
105- void run (FuncAllocMapT &funcAllocMap);
144+ template < typename AllocationAnalysis> void run (FuncAllocMapT &funcAllocMap);
106145
107146 // / Returns the operation this analysis was constructed from.
108147 Operation *getOperation () const { return operation; }
109148
149+ const OpScratchMapT &getOpScratch () const { return opScratch; }
150+ const OpScratchMapT &getOpVirtual () const { return opVirtual; }
151+ const ValueBufferMapT &getValueBuffer () const { return valueBuffer; }
152+ const AliasBufferMapT &getAliasBuffer () const { return aliasBuffer; }
153+ void setSharedMemorySize (size_t size) { sharedMemorySize = size; }
154+
110155 // / Returns the offset of the given buffer in the shared memory.
111156 size_t getOffset (BufferId bufferId) const {
112157 return bufferSet.at (bufferId).offset ;
@@ -170,47 +215,6 @@ class Allocation {
170215 // / Returns mapping from operation to list of live LDS buffers
171216 std::map<Operation *, SmallVector<BufferId>> getLiveBuffers ();
172217
173- private:
174- // / A class that represents a shared memory buffer
175- struct BufferT {
176- // / Explicit: triton_gpu.local_alloc
177- // / Scratch: triton_gpu.convert_layout
178- // / Virtual: triton.call
179- enum class BufferKind { Explicit, Scratch, Virtual };
180-
181- // / MT: thread-safe
182- inline static std::atomic<BufferId> nextId = 0 ;
183-
184- BufferKind kind;
185- BufferId id;
186- size_t size;
187- size_t alignment;
188- size_t offset;
189-
190- bool operator ==(const BufferT &other) const { return id == other.id ; }
191- bool operator <(const BufferT &other) const { return id < other.id ; }
192-
193- BufferT () : BufferT(BufferKind::Explicit, 0 ) {}
194- BufferT (BufferKind kind, size_t size, size_t alignment = 4 ,
195- size_t offset = 0 )
196- : kind(kind), id(nextId++), size(size), alignment(alignment),
197- offset (offset) {}
198-
199- size_t setOffsetAligned (size_t newOffset) {
200- return offset = llvm::alignTo (newOffset, alignment);
201- }
202- };
203-
204- // / Op -> Scratch Buffer
205- using OpScratchMapT = DenseMap<Operation *, BufferT *>;
206- // / Value -> Explicit Buffer
207- using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
208- // / Value -> Alias Buffer
209- using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
210- // / BufferId -> Buffer
211- using BufferSetT = std::map<BufferId, BufferT>;
212-
213- private:
214218 template <BufferT::BufferKind Kind, typename KeyType, typename ... Args>
215219 void addBuffer (KeyType &key, Args &&...args) {
216220 auto buffer = BufferT (Kind, std::forward<Args>(args)...);
@@ -236,10 +240,11 @@ class Allocation {
236240 AliasBufferMapT aliasBuffer;
237241 BufferSetT bufferSet;
238242 size_t sharedMemorySize = 0 ;
239-
240- friend class triton ::AllocationAnalysis;
241243};
242244
245+ template <>
246+ void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);
247+
243248// / Static analysis that computes the allocation of shared memory buffers
244249// / of the entire call graph.
245250// / The allocation is performed in a post-order walk of the call graph.
@@ -250,17 +255,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
250255public:
251256 using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
252257
253- explicit ModuleAllocation (ModuleOp moduleOp)
254- : CallGraph<Allocation>(moduleOp) {
255- walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
258+ template <typename AllocationAnalysis = triton::AllocationAnalysis>
259+ static ModuleAllocation get (ModuleOp moduleOp) {
260+ ModuleAllocation res (moduleOp);
261+ res.walk <WalkOrder::PreOrder, WalkOrder::PostOrder>(
256262 // Pre-order edge walk callback
257263 [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
258264 // Post-order node walk callback
259265 [&](FunctionOpInterface funcOp) {
260- auto [iter, inserted] = funcMap.try_emplace (funcOp, funcOp);
266+ auto [iter, inserted] = res. funcMap .try_emplace (funcOp, funcOp);
261267 if (inserted)
262- iter->second .run ( funcMap);
268+ iter->second .template run <AllocationAnalysis>(res. funcMap );
263269 });
270+ return res;
264271 }
265272
266273 size_t getSharedMemorySize () {
@@ -285,6 +292,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
285292 }
286293
287294private:
295+ explicit ModuleAllocation (ModuleOp moduleOp)
296+ : CallGraph<Allocation>(moduleOp) {}
297+
288298 FuncOffsetMapT sharedMemoryValue;
289299};
290300
0 commit comments