Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 43 additions & 45 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,45 +99,6 @@ class Allocation {
using BufferIdSetT = DenseSet<BufferId>;
using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;

/// Descriptor for a single shared memory buffer tracked by the allocation.
struct BufferT {
  /// Origin of the buffer:
  ///   Explicit - triton_gpu.local_alloc
  ///   Scratch  - triton_gpu.convert_layout
  ///   Virtual  - triton.call
  enum class BufferKind { Explicit, Scratch, Virtual };

  /// Counter that hands out buffer ids. It is atomic, so drawing a fresh id
  /// is thread-safe.
  inline static std::atomic<BufferId> nextId = 0;

  BufferKind kind;
  BufferId id;
  size_t size;
  size_t alignment;
  size_t offset;

  /// Creates an Explicit buffer of size 0. Delegates to the main constructor,
  /// so a fresh id is still drawn from nextId.
  BufferT() : BufferT(BufferKind::Explicit, 0) {}

  /// Creates a buffer of the given kind and size; the id is taken from nextId
  /// (post-increment). Alignment defaults to 4 and offset to 0.
  BufferT(BufferKind kind, size_t size, size_t alignment = 4,
          size_t offset = 0)
      : kind(kind), id(nextId++), size(size), alignment(alignment),
        offset(offset) {}

  /// Equality and ordering look at the id only; the other fields are ignored.
  bool operator==(const BufferT &other) const { return other.id == id; }
  bool operator<(const BufferT &other) const { return id < other.id; }

  /// Aligns newOffset to this buffer's alignment via llvm::alignTo, stores the
  /// result in offset and returns the stored value.
  size_t setOffsetAligned(size_t newOffset) {
    offset = llvm::alignTo(newOffset, alignment);
    return offset;
  }
};

/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
/// Value -> Alias Buffer
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = std::map<BufferId, BufferT>;

static constexpr BufferId InvalidBufferId =
std::numeric_limits<BufferId>::max();

Expand All @@ -153,12 +114,6 @@ class Allocation {
/// Accessor for the stored `operation` pointer, i.e. the operation this
/// analysis was constructed from.
Operation *getOperation() const {
  return operation;
}

/// Read-only view of the `opScratch` map.
const OpScratchMapT &getOpScratch() const {
  return opScratch;
}

/// Read-only view of the `opVirtual` map (same map type as `opScratch`).
const OpScratchMapT &getOpVirtual() const {
  return opVirtual;
}

/// Read-only view of the `valueBuffer` map.
const ValueBufferMapT &getValueBuffer() const {
  return valueBuffer;
}

/// Read-only view of the `aliasBuffer` map.
const AliasBufferMapT &getAliasBuffer() const {
  return aliasBuffer;
}

/// Overwrites the recorded shared memory size with `size`.
void setSharedMemorySize(size_t size) {
  sharedMemorySize = size;
}

/// Returns the offset of the given buffer in the shared memory.
size_t getOffset(BufferId bufferId) const {
return bufferSet.at(bufferId).offset;
Expand Down Expand Up @@ -222,6 +177,47 @@ class Allocation {
/// Returns mapping from operation to list of live LDS buffers
std::map<Operation *, SmallVector<BufferId>> getLiveBuffers();

private:
/// Descriptor for a single shared memory buffer tracked by the allocation.
struct BufferT {
  /// Origin of the buffer:
  ///   Explicit - triton_gpu.local_alloc
  ///   Scratch  - triton_gpu.convert_layout
  ///   Virtual  - triton.call
  enum class BufferKind { Explicit, Scratch, Virtual };

  /// Counter that hands out buffer ids. It is atomic, so drawing a fresh id
  /// is thread-safe.
  inline static std::atomic<BufferId> nextId = 0;

  BufferKind kind;
  BufferId id;
  size_t size;
  size_t alignment;
  size_t offset;

  /// Creates an Explicit buffer of size 0. Delegates to the main constructor,
  /// so a fresh id is still drawn from nextId.
  BufferT() : BufferT(BufferKind::Explicit, 0) {}

  /// Creates a buffer of the given kind and size; the id is taken from nextId
  /// (post-increment). Alignment defaults to 4 and offset to 0.
  BufferT(BufferKind kind, size_t size, size_t alignment = 4,
          size_t offset = 0)
      : kind(kind), id(nextId++), size(size), alignment(alignment),
        offset(offset) {}

  /// Equality and ordering look at the id only; the other fields are ignored.
  bool operator==(const BufferT &other) const { return other.id == id; }
  bool operator<(const BufferT &other) const { return id < other.id; }

  /// Aligns newOffset to this buffer's alignment via llvm::alignTo, stores the
  /// result in offset and returns the stored value.
  size_t setOffsetAligned(size_t newOffset) {
    offset = llvm::alignTo(newOffset, alignment);
    return offset;
  }
};

/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
/// Value -> Alias Buffer
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = std::map<BufferId, BufferT>;

private:
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
void addBuffer(KeyType &key, Args &&...args) {
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
Expand All @@ -247,6 +243,8 @@ class Allocation {
AliasBufferMapT aliasBuffer;
BufferSetT bufferSet;
size_t sharedMemorySize = 0;

friend class triton::AllocationAnalysis;
};

/// Static analysis that computes the allocation of shared memory buffers
Expand Down
14 changes: 7 additions & 7 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class AllocationAnalysis {
/// Each buffer is allocated only once.
void resolveExplicitBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto valueBufferIter : allocation->getValueBuffer()) {
for (auto valueBufferIter : allocation->valueBuffer) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
bufferRange[buffer] = getLiveness(value);
Expand All @@ -301,7 +301,7 @@ class AllocationAnalysis {
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (const auto &aliasBufferIter : allocation->getAliasBuffer()) {
for (const auto &aliasBufferIter : allocation->aliasBuffer) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
auto range = getLiveness(value);
Expand Down Expand Up @@ -334,8 +334,8 @@ class AllocationAnalysis {
operationId.lookup(op) + 1)});
}
};
processScratchMemory(allocation->getOpScratch());
processScratchMemory(allocation->getOpVirtual());
processScratchMemory(allocation->opScratch);
processScratchMemory(allocation->opVirtual);
}

/// Resolves liveness of all values involved under the root operation.
Expand Down Expand Up @@ -499,7 +499,7 @@ class AllocationAnalysis {
void allocate(const SmallVector<BufferT *> &buffers,
const GraphT &interference) {
// Reset shared memory size
allocation->setSharedMemorySize(0);
allocation->sharedMemorySize = 0;
// First-fit graph coloring
// Neighbors are nodes that interfere with each other.
// We color a node by finding the index of the first available
Expand Down Expand Up @@ -534,8 +534,8 @@ class AllocationAnalysis {
}
if (colors.lookup(x) != 0)
x->setOffsetAligned(newOffset);
allocation->setSharedMemorySize(
std::max(allocation->getSharedMemorySize(), x->offset + x->size));
allocation->sharedMemorySize =
std::max(allocation->sharedMemorySize, x->offset + x->size);
}
}

Expand Down