Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace mlir {

namespace triton {
class AllocationAnalysis;
namespace intel {
class AllocationAnalysis;
}

// To convert a tensor from one layout to another, we need to allocate a
// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
Expand Down Expand Up @@ -102,7 +105,7 @@ class Allocation {
explicit Allocation(Operation *operation) : operation(operation) {}

/// Runs allocation analysis on the given top-level operation.
void run(FuncAllocMapT &funcAllocMap);
template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);

/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
Expand Down Expand Up @@ -238,8 +241,12 @@ class Allocation {
size_t sharedMemorySize = 0;

friend class triton::AllocationAnalysis;
friend class triton::intel::AllocationAnalysis;
};

template <>
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);

/// Static analysis that computes the allocation of shared memory buffers
/// of the entire call graph.
/// The allocation is performed in a post-order walk of the call graph.
Expand All @@ -250,17 +257,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
public:
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
template <typename AllocationAnalysis = triton::AllocationAnalysis>
static ModuleAllocation get(ModuleOp moduleOp) {
ModuleAllocation res(moduleOp);
res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
Comment on lines -253 to +261
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cannot have templated constructors like this, so this is a static member function now.

// Pre-order edge walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.run(funcMap);
iter->second.template run<AllocationAnalysis>(res.funcMap);
});
return res;
}

size_t getSharedMemorySize() {
Expand All @@ -285,6 +294,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
}

private:
explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {}

FuncOffsetMapT sharedMemoryValue;
};

Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,8 @@ class AllocationAnalysis {

} // namespace triton

void Allocation::run(FuncAllocMapT &funcAllocMap) {
template <>
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct AllocateSharedMemory
void runOnOperation() override {
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
ModuleAllocation allocation(mod);
ModuleAllocation allocation = ModuleAllocation::get(mod);

mod.walk([&](FunctionOpInterface funcOp) {
funcOp.walk([&](Operation *op) {
Expand Down
2 changes: 1 addition & 1 deletion test/lib/Analysis/TestAllocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct TestAllocationPass
auto &os = llvm::errs();
ModuleOp moduleOp = getOperation();
// Convert to std::string can remove quotes from opName
ModuleAllocation moduleAllocation(moduleOp);
ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp);
moduleOp.walk([&](triton::FuncOp funcOp) {
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
os << opName << "\n";
Expand Down
2 changes: 1 addition & 1 deletion test/lib/Analysis/TestMembar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct TestMembarPass
Operation *operation = getOperation();
ModuleOp moduleOp = cast<ModuleOp>(operation);
// Print all ops after membar pass
ModuleAllocation allocation(moduleOp);
ModuleAllocation allocation = ModuleAllocation::get(moduleOp);
ModuleMembarAnalysis membarPass(&allocation,
mlir::triton::NVIDIA::canSkipBarSync);
membarPass.run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class OptimizeAMDLDSUsage
LDSLimit = targetInfo.getSharedMemorySize();
}

ModuleAllocation allocAnalysis(mod);
ModuleAllocation allocAnalysis = ModuleAllocation::get(mod);
if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
return;

Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct ConvertTritonAMDGPUToLLVM
}

// Allocate shared memory and set barrier
ModuleAllocation allocation(mod);
ModuleAllocation allocation = ModuleAllocation::get(mod);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();

Expand Down
12 changes: 12 additions & 0 deletions third_party/intel/include/Analysis/Allocation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef TRITON_INTEL_ANALYSIS_ALLOCATION_H
#define TRITON_INTEL_ANALYSIS_ALLOCATION_H

#include "triton/Analysis/Allocation.h"

namespace mlir {
template <>
void Allocation::run<triton::intel::AllocationAnalysis>(
FuncAllocMapT &funcAllocMap);
} // namespace mlir

#endif
Loading
Loading