Skip to content

Commit 6e6cff9

Browse files
authored
Support running different allocation analyses (#2605)
Support running different allocation analyses while reusing the allocation infrastructure. **Note:** This is downstream work and should be replaced with whatever approach is taken upstream in the next few weeks. --------- Signed-off-by: victor-eds <[email protected]>
1 parent 99778f4 commit 6e6cff9

File tree

14 files changed

+701
-68
lines changed

14 files changed

+701
-68
lines changed

include/triton/Analysis/Allocation.h

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,45 @@ class Allocation {
9393
using BufferIdSetT = DenseSet<BufferId>;
9494
using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;
9595

96+
/// A class that represents a shared memory buffer
97+
struct BufferT {
98+
/// Explicit: triton_gpu.local_alloc
99+
/// Scratch: triton_gpu.convert_layout
100+
/// Virtual: triton.call
101+
enum class BufferKind { Explicit, Scratch, Virtual };
102+
103+
/// MT: thread-safe
104+
inline static std::atomic<BufferId> nextId = 0;
105+
106+
BufferKind kind;
107+
BufferId id;
108+
size_t size;
109+
size_t alignment;
110+
size_t offset;
111+
112+
bool operator==(const BufferT &other) const { return id == other.id; }
113+
bool operator<(const BufferT &other) const { return id < other.id; }
114+
115+
BufferT() : BufferT(BufferKind::Explicit, 0) {}
116+
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
117+
size_t offset = 0)
118+
: kind(kind), id(nextId++), size(size), alignment(alignment),
119+
offset(offset) {}
120+
121+
size_t setOffsetAligned(size_t newOffset) {
122+
return offset = llvm::alignTo(newOffset, alignment);
123+
}
124+
};
125+
126+
/// Op -> Scratch Buffer
127+
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
128+
/// Value -> Explicit Buffer
129+
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
130+
/// Value -> Alias Buffer
131+
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
132+
/// BufferId -> Buffer
133+
using BufferSetT = std::map<BufferId, BufferT>;
134+
96135
static constexpr BufferId InvalidBufferId =
97136
std::numeric_limits<BufferId>::max();
98137

@@ -102,11 +141,17 @@ class Allocation {
102141
explicit Allocation(Operation *operation) : operation(operation) {}
103142

104143
/// Runs allocation analysis on the given top-level operation.
105-
void run(FuncAllocMapT &funcAllocMap);
144+
template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);
106145

107146
/// Returns the operation this analysis was constructed from.
108147
Operation *getOperation() const { return operation; }
109148

149+
const OpScratchMapT &getOpScratch() const { return opScratch; }
150+
const OpScratchMapT &getOpVirtual() const { return opVirtual; }
151+
const ValueBufferMapT &getValueBuffer() const { return valueBuffer; }
152+
const AliasBufferMapT &getAliasBuffer() const { return aliasBuffer; }
153+
void setSharedMemorySize(size_t size) { sharedMemorySize = size; }
154+
110155
/// Returns the offset of the given buffer in the shared memory.
111156
size_t getOffset(BufferId bufferId) const {
112157
return bufferSet.at(bufferId).offset;
@@ -170,47 +215,6 @@ class Allocation {
170215
/// Returns mapping from operation to list of live LDS buffers
171216
std::map<Operation *, SmallVector<BufferId>> getLiveBuffers();
172217

173-
private:
174-
/// A class that represents a shared memory buffer
175-
struct BufferT {
176-
/// Explicit: triton_gpu.local_alloc
177-
/// Scratch: triton_gpu.convert_layout
178-
/// Virtual: triton.call
179-
enum class BufferKind { Explicit, Scratch, Virtual };
180-
181-
/// MT: thread-safe
182-
inline static std::atomic<BufferId> nextId = 0;
183-
184-
BufferKind kind;
185-
BufferId id;
186-
size_t size;
187-
size_t alignment;
188-
size_t offset;
189-
190-
bool operator==(const BufferT &other) const { return id == other.id; }
191-
bool operator<(const BufferT &other) const { return id < other.id; }
192-
193-
BufferT() : BufferT(BufferKind::Explicit, 0) {}
194-
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
195-
size_t offset = 0)
196-
: kind(kind), id(nextId++), size(size), alignment(alignment),
197-
offset(offset) {}
198-
199-
size_t setOffsetAligned(size_t newOffset) {
200-
return offset = llvm::alignTo(newOffset, alignment);
201-
}
202-
};
203-
204-
/// Op -> Scratch Buffer
205-
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
206-
/// Value -> Explicit Buffer
207-
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
208-
/// Value -> Alias Buffer
209-
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
210-
/// BufferId -> Buffer
211-
using BufferSetT = std::map<BufferId, BufferT>;
212-
213-
private:
214218
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
215219
void addBuffer(KeyType &key, Args &&...args) {
216220
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
@@ -236,10 +240,11 @@ class Allocation {
236240
AliasBufferMapT aliasBuffer;
237241
BufferSetT bufferSet;
238242
size_t sharedMemorySize = 0;
239-
240-
friend class triton::AllocationAnalysis;
241243
};
242244

245+
template <>
246+
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);
247+
243248
/// Static analysis that computes the allocation of shared memory buffers
244249
/// of the entire call graph.
245250
/// The allocation is performed in a post-order walk of the call graph.
@@ -250,17 +255,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
250255
public:
251256
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
252257

253-
explicit ModuleAllocation(ModuleOp moduleOp)
254-
: CallGraph<Allocation>(moduleOp) {
255-
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
258+
template <typename AllocationAnalysis = triton::AllocationAnalysis>
259+
static ModuleAllocation get(ModuleOp moduleOp) {
260+
ModuleAllocation res(moduleOp);
261+
res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
256262
// Pre-order edge walk callback
257263
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
258264
// Post-order node walk callback
259265
[&](FunctionOpInterface funcOp) {
260-
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
266+
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
261267
if (inserted)
262-
iter->second.run(funcMap);
268+
iter->second.template run<AllocationAnalysis>(res.funcMap);
263269
});
270+
return res;
264271
}
265272

266273
size_t getSharedMemorySize() {
@@ -285,6 +292,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
285292
}
286293

287294
private:
295+
explicit ModuleAllocation(ModuleOp moduleOp)
296+
: CallGraph<Allocation>(moduleOp) {}
297+
288298
FuncOffsetMapT sharedMemoryValue;
289299
};
290300

lib/Analysis/Allocation.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ class AllocationAnalysis {
334334
/// Each buffer is allocated only once.
335335
void resolveExplicitBufferLiveness(
336336
function_ref<Interval<size_t>(Value value)> getLiveness) {
337-
for (auto valueBufferIter : allocation->valueBuffer) {
337+
for (auto valueBufferIter : allocation->getValueBuffer()) {
338338
auto value = valueBufferIter.first;
339339
auto *buffer = valueBufferIter.second;
340340
bufferRange[buffer] = getLiveness(value);
@@ -346,7 +346,7 @@ class AllocationAnalysis {
346346
/// arguments are involved.
347347
void resolveAliasBufferLiveness(
348348
function_ref<Interval<size_t>(Value value)> getLiveness) {
349-
for (auto aliasBufferIter : allocation->aliasBuffer) {
349+
for (auto aliasBufferIter : allocation->getAliasBuffer()) {
350350
auto value = aliasBufferIter.first;
351351
auto buffers = aliasBufferIter.second;
352352
auto range = getLiveness(value);
@@ -379,8 +379,8 @@ class AllocationAnalysis {
379379
operationId.lookup(op) + 1)});
380380
}
381381
};
382-
processScratchMemory(allocation->opScratch);
383-
processScratchMemory(allocation->opVirtual);
382+
processScratchMemory(allocation->getOpScratch());
383+
processScratchMemory(allocation->getOpVirtual());
384384
}
385385

386386
/// Resolves liveness of all values involved under the root operation.
@@ -544,7 +544,7 @@ class AllocationAnalysis {
544544
void allocate(const SmallVector<BufferT *> &buffers,
545545
const GraphT &interference) {
546546
// Reset shared memory size
547-
allocation->sharedMemorySize = 0;
547+
allocation->setSharedMemorySize(0);
548548
// First-fit graph coloring
549549
// Neighbors are nodes that interfere with each other.
550550
// We color a node by finding the index of the first available
@@ -579,8 +579,8 @@ class AllocationAnalysis {
579579
}
580580
if (colors.lookup(x) != 0)
581581
x->setOffsetAligned(newOffset);
582-
allocation->sharedMemorySize =
583-
std::max(allocation->sharedMemorySize, x->offset + x->size);
582+
allocation->setSharedMemorySize(
583+
std::max(allocation->getSharedMemorySize(), x->offset + x->size));
584584
}
585585
}
586586

@@ -593,7 +593,8 @@ class AllocationAnalysis {
593593

594594
} // namespace triton
595595

596-
void Allocation::run(FuncAllocMapT &funcAllocMap) {
596+
template <>
597+
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
597598
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
598599
}
599600

lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct AllocateSharedMemory
2323
void runOnOperation() override {
2424
ModuleOp mod = getOperation();
2525
MLIRContext *ctx = &getContext();
26-
ModuleAllocation allocation(mod);
26+
ModuleAllocation allocation = ModuleAllocation::get(mod);
2727

2828
mod.walk([&](FunctionOpInterface funcOp) {
2929
funcOp.walk([&](Operation *op) {

python/src/passes.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ namespace py = pybind11;
1717

1818
void init_triton_analysis(py::module &&m) {
1919
py::class_<mlir::ModuleAllocation>(m, "allocation", py::module_local())
20-
.def(py::init<mlir::ModuleOp>());
20+
.def(py::init(
21+
&mlir::ModuleAllocation::get<mlir::triton::AllocationAnalysis>));
2122
py::class_<mlir::ModuleMembarAnalysis>(m, "membar", py::module_local())
2223
.def(py::init<mlir::ModuleAllocation *>())
2324
.def("run", &mlir::ModuleMembarAnalysis::run);

test/lib/Analysis/TestAllocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct TestAllocationPass
1919
auto &os = llvm::errs();
2020
ModuleOp moduleOp = getOperation();
2121
// Convert to std::string can remove quotes from opName
22-
ModuleAllocation moduleAllocation(moduleOp);
22+
ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp);
2323
moduleOp.walk([&](triton::FuncOp funcOp) {
2424
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
2525
os << opName << "\n";

test/lib/Analysis/TestMembar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct TestMembarPass
2525
Operation *operation = getOperation();
2626
ModuleOp moduleOp = cast<ModuleOp>(operation);
2727
// Print all ops after membar pass
28-
ModuleAllocation allocation(moduleOp);
28+
ModuleAllocation allocation = ModuleAllocation::get(moduleOp);
2929
ModuleMembarAnalysis membarPass(&allocation,
3030
mlir::triton::NVIDIA::canSkipBarSync);
3131
membarPass.run();

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class OptimizeAMDLDSUsage
221221
LDSLimit = targetInfo.getSharedMemorySize();
222222
}
223223

224-
ModuleAllocation allocAnalysis(mod);
224+
ModuleAllocation allocAnalysis = ModuleAllocation::get(mod);
225225
if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
226226
return;
227227

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ struct ConvertTritonAMDGPUToLLVM
104104
}
105105

106106
// Allocate shared memory and set barrier
107-
ModuleAllocation allocation(mod);
107+
ModuleAllocation allocation = ModuleAllocation::get(mod);
108108
ModuleMembarAnalysis membarPass(&allocation);
109109
membarPass.run();
110110

(new file — filename lost in extraction; the include guard `TRITON_INTEL_ANALYSIS_ALLOCATION_H` below suggests an Intel-specific `Analysis/Allocation.h` header — verify against the commit)

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TRITON_INTEL_ANALYSIS_ALLOCATION_H
2+
#define TRITON_INTEL_ANALYSIS_ALLOCATION_H
3+
4+
#include "triton/Analysis/Allocation.h"
5+
6+
namespace mlir {
7+
namespace triton::intel {
8+
class AllocationAnalysis;
9+
} // namespace triton::intel
10+
template <>
11+
void Allocation::run<triton::intel::AllocationAnalysis>(
12+
FuncAllocMapT &funcAllocMap);
13+
} // namespace mlir
14+
15+
#endif

0 commit comments

Comments (0)