108 changes: 59 additions & 49 deletions include/triton/Analysis/Allocation.h
@@ -93,6 +93,45 @@ class Allocation {
using BufferIdSetT = DenseSet<BufferId>;
using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;

/// A class that represents a shared memory buffer
struct BufferT {
/// Explicit: triton_gpu.local_alloc
/// Scratch: triton_gpu.convert_layout
/// Virtual: triton.call
enum class BufferKind { Explicit, Scratch, Virtual };

/// MT: thread-safe
inline static std::atomic<BufferId> nextId = 0;

BufferKind kind;
BufferId id;
size_t size;
size_t alignment;
size_t offset;

bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }

BufferT() : BufferT(BufferKind::Explicit, 0) {}
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
size_t offset = 0)
: kind(kind), id(nextId++), size(size), alignment(alignment),
offset(offset) {}

size_t setOffsetAligned(size_t newOffset) {
return offset = llvm::alignTo(newOffset, alignment);
}
};

/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
/// Value -> Alias Buffer
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = std::map<BufferId, BufferT>;

static constexpr BufferId InvalidBufferId =
std::numeric_limits<BufferId>::max();
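
For readers skimming the newly public BufferT: setOffsetAligned simply rounds the requested offset up to the buffer's alignment. A minimal standalone sketch of that behavior (toy types, not part of this diff; llvm::alignTo is emulated with plain arithmetic):

#include <cassert>
#include <cstddef>

// Emulates llvm::alignTo: round value up to the next multiple of align (align > 0).
static size_t alignTo(size_t value, size_t align) {
  return (value + align - 1) / align * align;
}

struct ToyBuffer {
  size_t size;
  size_t alignment;
  size_t offset = 0;

  // Same shape as BufferT::setOffsetAligned above.
  size_t setOffsetAligned(size_t newOffset) {
    return offset = alignTo(newOffset, alignment);
  }
};

int main() {
  ToyBuffer b{/*size=*/96, /*alignment=*/16};
  assert(b.setOffsetAligned(100) == 112); // 100 rounded up to a multiple of 16
  return 0;
}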

@@ -102,11 +141,17 @@ class Allocation {
explicit Allocation(Operation *operation) : operation(operation) {}

/// Runs allocation analysis on the given top-level operation.
void run(FuncAllocMapT &funcAllocMap);
template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);

/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }

const OpScratchMapT &getOpScratch() const { return opScratch; }
const OpScratchMapT &getOpVirtual() const { return opVirtual; }
const ValueBufferMapT &getValueBuffer() const { return valueBuffer; }
const AliasBufferMapT &getAliasBuffer() const { return aliasBuffer; }
void setSharedMemorySize(size_t size) { sharedMemorySize = size; }

/// Returns the offset of the given buffer in the shared memory.
size_t getOffset(BufferId bufferId) const {
return bufferSet.at(bufferId).offset;
@@ -170,47 +215,6 @@ class Allocation {
/// Returns mapping from operation to list of live LDS buffers
std::map<Operation *, SmallVector<BufferId>> getLiveBuffers();

private:
/// A class that represents a shared memory buffer
struct BufferT {
/// Explicit: triton_gpu.local_alloc
/// Scratch: triton_gpu.convert_layout
/// Virtual: triton.call
enum class BufferKind { Explicit, Scratch, Virtual };

/// MT: thread-safe
inline static std::atomic<BufferId> nextId = 0;

BufferKind kind;
BufferId id;
size_t size;
size_t alignment;
size_t offset;

bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }

BufferT() : BufferT(BufferKind::Explicit, 0) {}
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
size_t offset = 0)
: kind(kind), id(nextId++), size(size), alignment(alignment),
offset(offset) {}

size_t setOffsetAligned(size_t newOffset) {
return offset = llvm::alignTo(newOffset, alignment);
}
};

/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
/// Value -> Alias Buffer
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = std::map<BufferId, BufferT>;

private:
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
void addBuffer(KeyType &key, Args &&...args) {
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
@@ -236,10 +240,11 @@ class Allocation {
AliasBufferMapT aliasBuffer;
BufferSetT bufferSet;
size_t sharedMemorySize = 0;

friend class triton::AllocationAnalysis;
};

template <>
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);

/// Static analysis that computes the allocation of shared memory buffers
/// of the entire call graph.
/// The allocation is performed in a post-order walk of the call graph.
@@ -250,17 +255,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
public:
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
template <typename AllocationAnalysis = triton::AllocationAnalysis>
static ModuleAllocation get(ModuleOp moduleOp) {
ModuleAllocation res(moduleOp);
res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
Comment on lines -253 to +261 (Contributor Author):

Cannot have templated constructors like this, so this is a static member function now.

// Pre-order edge walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.run(funcMap);
iter->second.template run<AllocationAnalysis>(res.funcMap);
});
return res;
}

size_t getSharedMemorySize() {
@@ -285,6 +292,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
}

private:
explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {}

FuncOffsetMapT sharedMemoryValue;
};

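The review comment above points out that a constructor template cannot be given explicit template arguments at the call site, which is why ModuleAllocation now hides its constructor and exposes a templated static get() instead. A minimal, self-contained sketch of that pattern (Widget, Analysis1, and Analysis2 are made-up names used only for illustration, not Triton types):

#include <iostream>

struct Analysis1 { static constexpr const char *name = "Analysis1"; };
struct Analysis2 { static constexpr const char *name = "Analysis2"; };

class Widget {
public:
  // A templated static factory can be called as Widget::get<Analysis2>(),
  // whereas there is no syntax to select a constructor's template argument.
  template <typename AnalysisT = Analysis1>
  static Widget get() {
    Widget w;
    w.run<AnalysisT>(); // dispatch to the chosen analysis
    return w;
  }

private:
  Widget() = default; // mirrors ModuleAllocation's now-private constructor

  template <typename AnalysisT> void run() {
    std::cout << "running " << AnalysisT::name << "\n";
  }
};

int main() {
  Widget a = Widget::get();            // default analysis
  Widget b = Widget::get<Analysis2>(); // explicitly selected analysis
  return 0;
}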
17 changes: 9 additions & 8 deletions lib/Analysis/Allocation.cpp
@@ -334,7 +334,7 @@ class AllocationAnalysis {
/// Each buffer is allocated only once.
void resolveExplicitBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto valueBufferIter : allocation->valueBuffer) {
for (auto valueBufferIter : allocation->getValueBuffer()) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
bufferRange[buffer] = getLiveness(value);
@@ -346,7 +346,7 @@
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto aliasBufferIter : allocation->aliasBuffer) {
for (auto aliasBufferIter : allocation->getAliasBuffer()) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
auto range = getLiveness(value);
@@ -379,8 +379,8 @@
operationId.lookup(op) + 1)});
}
};
processScratchMemory(allocation->opScratch);
processScratchMemory(allocation->opVirtual);
processScratchMemory(allocation->getOpScratch());
processScratchMemory(allocation->getOpVirtual());
}

/// Resolves liveness of all values involved under the root operation.
@@ -544,7 +544,7 @@ class AllocationAnalysis {
void allocate(const SmallVector<BufferT *> &buffers,
const GraphT &interference) {
// Reset shared memory size
allocation->sharedMemorySize = 0;
allocation->setSharedMemorySize(0);
// First-fit graph coloring
// Neighbors are nodes that interfere with each other.
// We color a node by finding the index of the first available
@@ -579,8 +579,8 @@
}
if (colors.lookup(x) != 0)
x->setOffsetAligned(newOffset);
allocation->sharedMemorySize =
std::max(allocation->sharedMemorySize, x->offset + x->size);
allocation->setSharedMemorySize(
std::max(allocation->getSharedMemorySize(), x->offset + x->size));
}
}
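
The comments in allocate() describe the strategy: first-fit graph coloring over the interference graph, with the final shared-memory size taken as the maximum offset + size over all buffers (now updated through the setSharedMemorySize accessor). A rough, self-contained sketch of that idea, greatly simplified relative to the real pass (no alignment, naive offset stacking):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <vector>

struct Buf { size_t size; size_t offset = 0; };

int main() {
  std::vector<Buf> bufs = {{64}, {32}, {16}};
  // interference[i] lists buffers whose liveness overlaps buffer i.
  std::map<size_t, std::set<size_t>> interference = {{0, {1}}, {1, {0, 2}}, {2, {1}}};

  // First-fit coloring: each buffer gets the smallest color not used by a neighbor.
  std::map<size_t, int> color;
  for (size_t i = 0; i < bufs.size(); ++i) {
    std::set<int> taken;
    for (size_t n : interference[i])
      if (color.count(n))
        taken.insert(color[n]);
    int c = 0;
    while (taken.count(c)) ++c;
    color[i] = c;
  }

  // Buffers with the same color may share an offset; distinct colors are stacked.
  // The real pass additionally aligns offsets and packs more tightly.
  size_t sharedMemorySize = 0;
  for (size_t i = 0; i < bufs.size(); ++i) {
    size_t offset = 0;
    for (size_t j = 0; j < i; ++j)
      if (color[j] < color[i])
        offset = std::max(offset, bufs[j].offset + bufs[j].size);
    bufs[i].offset = offset;
    sharedMemorySize = std::max(sharedMemorySize, offset + bufs[i].size);
  }
  std::cout << "shared memory size: " << sharedMemorySize << "\n"; // 96 here
  return 0;
}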

@@ -593,7 +593,8 @@ class AllocationAnalysis {

} // namespace triton

void Allocation::run(FuncAllocMapT &funcAllocMap) {
template <>
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
}

2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
@@ -23,7 +23,7 @@ struct AllocateSharedMemory
void runOnOperation() override {
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
ModuleAllocation allocation(mod);
ModuleAllocation allocation = ModuleAllocation::get(mod);

mod.walk([&](FunctionOpInterface funcOp) {
funcOp.walk([&](Operation *op) {
3 changes: 2 additions & 1 deletion python/src/passes.cc
@@ -17,7 +17,8 @@ namespace py = pybind11;

void init_triton_analysis(py::module &&m) {
py::class_<mlir::ModuleAllocation>(m, "allocation", py::module_local())
.def(py::init<mlir::ModuleOp>());
.def(py::init(
&mlir::ModuleAllocation::get<mlir::triton::AllocationAnalysis>));
py::class_<mlir::ModuleMembarAnalysis>(m, "membar", py::module_local())
.def(py::init<mlir::ModuleAllocation *>())
.def("run", &mlir::ModuleMembarAnalysis::run);
2 changes: 1 addition & 1 deletion test/lib/Analysis/TestAllocation.cpp
@@ -19,7 +19,7 @@ struct TestAllocationPass
auto &os = llvm::errs();
ModuleOp moduleOp = getOperation();
// Convert to std::string can remove quotes from opName
ModuleAllocation moduleAllocation(moduleOp);
ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp);
moduleOp.walk([&](triton::FuncOp funcOp) {
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
os << opName << "\n";
2 changes: 1 addition & 1 deletion test/lib/Analysis/TestMembar.cpp
@@ -25,7 +25,7 @@ struct TestMembarPass
Operation *operation = getOperation();
ModuleOp moduleOp = cast<ModuleOp>(operation);
// Print all ops after membar pass
ModuleAllocation allocation(moduleOp);
ModuleAllocation allocation = ModuleAllocation::get(moduleOp);
ModuleMembarAnalysis membarPass(&allocation,
mlir::triton::NVIDIA::canSkipBarSync);
membarPass.run();
@@ -221,7 +221,7 @@ class OptimizeAMDLDSUsage
LDSLimit = targetInfo.getSharedMemorySize();
}

ModuleAllocation allocAnalysis(mod);
ModuleAllocation allocAnalysis = ModuleAllocation::get(mod);
if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
return;

2 changes: 1 addition & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -104,7 +104,7 @@ struct ConvertTritonAMDGPUToLLVM
}

// Allocate shared memory and set barrier
ModuleAllocation allocation(mod);
ModuleAllocation allocation = ModuleAllocation::get(mod);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();

15 changes: 15 additions & 0 deletions third_party/intel/include/Analysis/Allocation.h
@@ -0,0 +1,15 @@
#ifndef TRITON_INTEL_ANALYSIS_ALLOCATION_H
#define TRITON_INTEL_ANALYSIS_ALLOCATION_H

#include "triton/Analysis/Allocation.h"

namespace mlir {
namespace triton::intel {
class AllocationAnalysis;
} // namespace triton::intel
template <>
void Allocation::run<triton::intel::AllocationAnalysis>(
FuncAllocMapT &funcAllocMap);
} // namespace mlir

#endif
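
This new header only forward-declares the Intel specialization. A hedged sketch of how a downstream backend would be expected to use the extension point; the helper function and the analysis body below are hypothetical and not part of this PR, while Allocation::run<...> and ModuleAllocation::get<...> are the APIs introduced by the diff above:

#include "triton/Analysis/Allocation.h"

namespace mlir::triton::intel {
class AllocationAnalysis {
  // Backend-specific scratch-buffer sizing would live here (hypothetical).
};
} // namespace mlir::triton::intel

// Definition of the specialization declared in the new Intel header.
template <>
void mlir::Allocation::run<mlir::triton::intel::AllocationAnalysis>(
    FuncAllocMapT &funcAllocMap) {
  // Populate explicit/scratch/virtual buffers for one function (hypothetical body).
}

// Somewhere in the backend's lowering pipeline (hypothetical helper):
void runIntelAllocation(mlir::ModuleOp mod) {
  mlir::ModuleAllocation allocation =
      mlir::ModuleAllocation::get<mlir::triton::intel::AllocationAnalysis>(mod);
  (void)allocation.getSharedMemorySize();
}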