Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace mlir {

namespace triton {
class AllocationAnalysis;
namespace intel {
class AllocationAnalysis;
}

// To convert a tensor from one layout to another, we need to allocate a
// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
Expand Down Expand Up @@ -238,6 +241,7 @@ class Allocation {
size_t sharedMemorySize = 0;

friend class triton::AllocationAnalysis;
friend class triton::intel::AllocationAnalysis;
};

template <>
Expand Down
12 changes: 12 additions & 0 deletions third_party/intel/include/Analysis/Allocation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef TRITON_INTEL_ANALYSIS_ALLOCATION_H
#define TRITON_INTEL_ANALYSIS_ALLOCATION_H

#include "triton/Analysis/Allocation.h"

namespace mlir {
template <>
void Allocation::run<triton::intel::AllocationAnalysis>(
FuncAllocMapT &funcAllocMap);
} // namespace mlir

#endif
1 change: 1 addition & 0 deletions third_party/intel/lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_triton_library(TritonIntelAnalysis
Allocation.cpp
AxisInfo.cpp
DPAS.cpp
Liveness.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ struct AllocateSharedMemory
void runOnOperation() override {
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
ModuleAllocation allocation = ModuleAllocation::get(mod);
ModuleAllocation allocation =
ModuleAllocation::get<triton::intel::AllocationAnalysis>(mod);

mod.walk([&](FunctionOpInterface funcOp) {
if (allocation.isRoot(funcOp) && allocation.getSharedMemorySize()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ struct ConvertTritonGPUToLLVM

// Allocate shared memory and set barrier
if (!pipelineManager.skipSharedMemoryAllocation()) {
ModuleAllocation allocation = ModuleAllocation::get(mod);
ModuleAllocation allocation =
ModuleAllocation::get<triton::intel::AllocationAnalysis>(mod);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();
}
Expand Down
Loading