diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
index 3a488e65ed..5d5f1a5709 100644
--- a/include/triton/Analysis/Allocation.h
+++ b/include/triton/Analysis/Allocation.h
@@ -93,6 +93,45 @@ class Allocation {
   using BufferIdSetT = DenseSet<BufferId>;
   using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;
 
+  /// A class that represents a shared memory buffer
+  struct BufferT {
+    /// Explicit: triton_gpu.local_alloc
+    /// Scratch: triton_gpu.convert_layout
+    /// Virtual: triton.call
+    enum class BufferKind { Explicit, Scratch, Virtual };
+
+    /// MT: thread-safe
+    inline static std::atomic<BufferId> nextId = 0;
+
+    BufferKind kind;
+    BufferId id;
+    size_t size;
+    size_t alignment;
+    size_t offset;
+
+    bool operator==(const BufferT &other) const { return id == other.id; }
+    bool operator<(const BufferT &other) const { return id < other.id; }
+
+    BufferT() : BufferT(BufferKind::Explicit, 0) {}
+    BufferT(BufferKind kind, size_t size, size_t alignment = 4,
+            size_t offset = 0)
+        : kind(kind), id(nextId++), size(size), alignment(alignment),
+          offset(offset) {}
+
+    size_t setOffsetAligned(size_t newOffset) {
+      return offset = llvm::alignTo(newOffset, alignment);
+    }
+  };
+
+  /// Op -> Scratch Buffer
+  using OpScratchMapT = DenseMap<Operation *, BufferT *>;
+  /// Value -> Explicit Buffer
+  using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
+  /// Value -> Alias Buffer
+  using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
+  /// BufferId -> Buffer
+  using BufferSetT = std::map<BufferId, BufferT>;
+
   static constexpr BufferId InvalidBufferId =
       std::numeric_limits<BufferId>::max();
@@ -102,11 +141,17 @@ class Allocation {
   explicit Allocation(Operation *operation) : operation(operation) {}
 
   /// Runs allocation analysis on the given top-level operation.
-  void run(FuncAllocMapT &funcAllocMap);
+  template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);
 
   /// Returns the operation this analysis was constructed from.
   Operation *getOperation() const { return operation; }
 
+  const OpScratchMapT &getOpScratch() const { return opScratch; }
+  const OpScratchMapT &getOpVirtual() const { return opVirtual; }
+  const ValueBufferMapT &getValueBuffer() const { return valueBuffer; }
+  const AliasBufferMapT &getAliasBuffer() const { return aliasBuffer; }
+  void setSharedMemorySize(size_t size) { sharedMemorySize = size; }
+
   /// Returns the offset of the given buffer in the shared memory.
   size_t getOffset(BufferId bufferId) const {
     return bufferSet.at(bufferId).offset;
@@ -170,47 +215,6 @@ class Allocation {
   /// Returns mapping from operation to list of live LDS buffers
   std::map<Operation *, SmallVector<BufferId>> getLiveBuffers();
 
-private:
-  /// A class that represents a shared memory buffer
-  struct BufferT {
-    /// Explicit: triton_gpu.local_alloc
-    /// Scratch: triton_gpu.convert_layout
-    /// Virtual: triton.call
-    enum class BufferKind { Explicit, Scratch, Virtual };
-
-    /// MT: thread-safe
-    inline static std::atomic<BufferId> nextId = 0;
-
-    BufferKind kind;
-    BufferId id;
-    size_t size;
-    size_t alignment;
-    size_t offset;
-
-    bool operator==(const BufferT &other) const { return id == other.id; }
-    bool operator<(const BufferT &other) const { return id < other.id; }
-
-    BufferT() : BufferT(BufferKind::Explicit, 0) {}
-    BufferT(BufferKind kind, size_t size, size_t alignment = 4,
-            size_t offset = 0)
-        : kind(kind), id(nextId++), size(size), alignment(alignment),
-          offset(offset) {}
-
-    size_t setOffsetAligned(size_t newOffset) {
-      return offset = llvm::alignTo(newOffset, alignment);
-    }
-  };
-
-  /// Op -> Scratch Buffer
-  using OpScratchMapT = DenseMap<Operation *, BufferT *>;
-  /// Value -> Explicit Buffer
-  using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
-  /// Value -> Alias Buffer
-  using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
-  /// BufferId -> Buffer
-  using BufferSetT = std::map<BufferId, BufferT>;
-
-private:
   template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
   void addBuffer(KeyType &key, Args &&...args) {
     auto buffer = BufferT(Kind, std::forward<Args>(args)...);
@@ -236,10 +240,11 @@ class Allocation {
   AliasBufferMapT aliasBuffer;
   BufferSetT bufferSet;
   size_t sharedMemorySize = 0;
-
-  friend class triton::AllocationAnalysis;
 };
 
+template <>
+void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);
+
 /// Static analysis that computes the allocation of shared memory buffers
 /// of the entire call graph.
 /// The allocation is performed in a post-order walk of the call graph.
@@ -250,17 +255,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
 public:
   using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
 
-  explicit ModuleAllocation(ModuleOp moduleOp)
-      : CallGraph<Allocation>(moduleOp) {
-    walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
+  template <typename AllocationAnalysis = triton::AllocationAnalysis>
+  static ModuleAllocation get(ModuleOp moduleOp) {
+    ModuleAllocation res(moduleOp);
+    res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
         // Pre-order edge walk callback
        [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
        // Post-order node walk callback
        [&](FunctionOpInterface funcOp) {
-          auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
+          auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
          if (inserted)
-            iter->second.run(funcMap);
+            iter->second.template run<AllocationAnalysis>(res.funcMap);
        });
+    return res;
  }
 
  size_t getSharedMemorySize() {
@@ -285,6 +292,9 @@
   }
 
 private:
+  explicit ModuleAllocation(ModuleOp moduleOp)
+      : CallGraph<Allocation>(moduleOp) {}
+
   FuncOffsetMapT sharedMemoryValue;
 };
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 665b97aeeb..82b8704b98 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -334,7 +334,7 @@ class AllocationAnalysis {
   /// Each buffer is allocated only once.
   void resolveExplicitBufferLiveness(
       function_ref<Interval<size_t>(Value value)> getLiveness) {
-    for (auto valueBufferIter : allocation->valueBuffer) {
+    for (auto valueBufferIter : allocation->getValueBuffer()) {
       auto value = valueBufferIter.first;
       auto *buffer = valueBufferIter.second;
       bufferRange[buffer] = getLiveness(value);
@@ -346,7 +346,7 @@ class AllocationAnalysis {
   /// arguments are involved.
   void resolveAliasBufferLiveness(
       function_ref<Interval<size_t>(Value value)> getLiveness) {
-    for (auto aliasBufferIter : allocation->aliasBuffer) {
+    for (auto aliasBufferIter : allocation->getAliasBuffer()) {
       auto value = aliasBufferIter.first;
       auto buffers = aliasBufferIter.second;
       auto range = getLiveness(value);
@@ -379,8 +379,8 @@
                                              operationId.lookup(op) + 1)});
       }
     };
-    processScratchMemory(allocation->opScratch);
-    processScratchMemory(allocation->opVirtual);
+    processScratchMemory(allocation->getOpScratch());
+    processScratchMemory(allocation->getOpVirtual());
   }
 
   /// Resolves liveness of all values involved under the root operation.
@@ -544,7 +544,7 @@ class AllocationAnalysis {
   void allocate(const SmallVector<BufferT *> &buffers,
                 const GraphT &interference) {
     // Reset shared memory size
-    allocation->sharedMemorySize = 0;
+    allocation->setSharedMemorySize(0);
     // First-fit graph coloring
     // Neighbors are nodes that interfere with each other.
     // We color a node by finding the index of the first available
@@ -579,8 +579,8 @@
       }
       if (colors.lookup(x) != 0)
         x->setOffsetAligned(newOffset);
-      allocation->sharedMemorySize =
-          std::max(allocation->sharedMemorySize, x->offset + x->size);
+      allocation->setSharedMemorySize(
+          std::max(allocation->getSharedMemorySize(), x->offset + x->size));
     }
   }
 
@@ -593,7 +593,8 @@
 } // namespace triton
 
-void Allocation::run(FuncAllocMapT &funcAllocMap) {
+template <>
+void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
   triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
 }
diff --git a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
index aae9faf0ee..a85abe7c7f 100644
--- a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
@@ -23,7 +23,7 @@ struct AllocateSharedMemory
   void runOnOperation() override {
     ModuleOp mod = getOperation();
     MLIRContext *ctx = &getContext();
-    ModuleAllocation allocation(mod);
+    ModuleAllocation allocation = ModuleAllocation::get(mod);
 
     mod.walk([&](FunctionOpInterface funcOp) {
       funcOp.walk([&](Operation *op) {
diff --git a/python/src/passes.cc b/python/src/passes.cc
index 98d8369d40..bceb88eb7b 100644
--- a/python/src/passes.cc
+++ b/python/src/passes.cc
@@ -17,7 +17,8 @@ namespace py = pybind11;
 
 void init_triton_analysis(py::module &&m) {
   py::class_<mlir::ModuleAllocation>(m, "allocation", py::module_local())
-      .def(py::init<mlir::ModuleOp>());
+      .def(py::init(
+          &mlir::ModuleAllocation::get<mlir::triton::AllocationAnalysis>));
   py::class_<mlir::ModuleMembarAnalysis>(m, "membar", py::module_local())
       .def(py::init<mlir::ModuleAllocation *>())
       .def("run", &mlir::ModuleMembarAnalysis::run);
diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp
index 772e0258bf..97adcfdf96 100644
--- a/test/lib/Analysis/TestAllocation.cpp
+++ b/test/lib/Analysis/TestAllocation.cpp
@@ -19,7 +19,7 @@ struct TestAllocationPass
     auto &os = llvm::errs();
     ModuleOp moduleOp = getOperation();
     // Convert to std::string can remove quotes from opName
-    ModuleAllocation moduleAllocation(moduleOp);
+    ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp);
     moduleOp.walk([&](triton::FuncOp funcOp) {
       auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
       os << opName << "\n";
diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp
index 25e8e2d198..32546808bb 100644
--- a/test/lib/Analysis/TestMembar.cpp
+++ b/test/lib/Analysis/TestMembar.cpp
@@ -25,7 +25,7 @@ struct TestMembarPass
     Operation *operation = getOperation();
     ModuleOp moduleOp = cast<ModuleOp>(operation);
     // Print all ops after membar pass
-    ModuleAllocation allocation(moduleOp);
+    ModuleAllocation allocation = ModuleAllocation::get(moduleOp);
     ModuleMembarAnalysis membarPass(&allocation,
                                     mlir::triton::NVIDIA::canSkipBarSync);
     membarPass.run();
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp
index db3223f119..6b6804b252 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp
@@ -221,7 +221,7 @@ class OptimizeAMDLDSUsage
       LDSLimit = targetInfo.getSharedMemorySize();
     }
 
-    ModuleAllocation allocAnalysis(mod);
+    ModuleAllocation allocAnalysis = ModuleAllocation::get(mod);
 
     if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
       return;
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
index d227bb6c6a..d411696d23 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -104,7 +104,7 @@ struct ConvertTritonAMDGPUToLLVM
     }
 
     // Allocate shared memory and set barrier
-    ModuleAllocation allocation(mod);
+    ModuleAllocation allocation = ModuleAllocation::get(mod);
     ModuleMembarAnalysis membarPass(&allocation);
     membarPass.run();
diff --git a/third_party/intel/include/Analysis/Allocation.h b/third_party/intel/include/Analysis/Allocation.h
new file mode 100644
index 0000000000..afdef179a1
--- /dev/null
+++ b/third_party/intel/include/Analysis/Allocation.h
@@ -0,0 +1,15 @@
+#ifndef TRITON_INTEL_ANALYSIS_ALLOCATION_H
+#define TRITON_INTEL_ANALYSIS_ALLOCATION_H
+
+#include "triton/Analysis/Allocation.h"
+
+namespace mlir {
+namespace triton::intel {
+class AllocationAnalysis;
+} // namespace triton::intel
+template <>
+void Allocation::run<triton::intel::AllocationAnalysis>(
+    FuncAllocMapT &funcAllocMap);
+} // namespace mlir
+
+#endif
diff --git a/third_party/intel/lib/Analysis/Allocation.cpp b/third_party/intel/lib/Analysis/Allocation.cpp
new file mode 100644
index 0000000000..1fba62b609
--- /dev/null
+++ b/third_party/intel/lib/Analysis/Allocation.cpp
@@ -0,0 +1,602 @@
+#include "intel/include/Analysis/Allocation.h"
+
+#include <algorithm>
+#include <limits>
+#include <numeric>
+
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Support/LLVM.h"
+#include "triton/Analysis/Alias.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "llvm/ADT/SmallVector.h"
+
+using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
+using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getContigPerThread;
+using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getShapePerCTATile;
+using ::mlir::triton::gpu::getSizePerThread;
+using ::mlir::triton::gpu::getUniqueContigPerThread;
+using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
+using ::mlir::triton::gpu::SharedEncodingAttr;
+using ::mlir::triton::gpu::SliceEncodingAttr;
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// Shared Memory Allocation Analysis
+//===----------------------------------------------------------------------===//
+namespace triton::intel {
+
+// Bitwidth of pointers
+constexpr int kPtrBitWidth = 64;
+
+static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
+getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
+  auto srcMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout);
+  auto srcDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(srcLayout);
+  auto dstMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout);
+  auto dstDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(dstLayout);
+
+  assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() &&
+           !srcMmaLayout.isHopper()) &&
+         "mma -> mma layout conversion is only supported on Ampere");
+
+  // mma or dot layout does not have an order, so the order depends on the
+  // layout of the other operand.
+  auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
+                                              : getOrder(srcLayout);
+  auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
+                                               : getOrder(dstLayout);
+
+  return {inOrd, outOrd};
+}
+
+static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
+                                               RankedTensorType dstTy) {
+  Attribute srcLayout = srcTy.getEncoding();
+  Attribute dstLayout = dstTy.getEncoding();
+
+  if (!cvtNeedsSharedMemory(srcTy, dstTy)) {
+    return {};
+  }
+
+  if (shouldUseDistSmem(srcLayout, dstLayout)) {
+    // TODO: padding to avoid bank conflicts
+    return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
+  }
+
+  assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()");
+
+  auto srcShapePerCTA = getShapePerCTA(srcTy);
+  auto dstShapePerCTA = getShapePerCTA(dstTy);
+  auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
+  auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
+
+  unsigned rank = dstTy.getRank();
+  SmallVector<unsigned> repShape(rank);
+  for (unsigned d = 0; d < rank; ++d) {
+    repShape[d] =
+        std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
+                 std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
+  }
+  return repShape;
+}
+
+// Both `atomic_cas` and `atomic_rmw` need a single scratch element if
+// returning a scalar value because Triton's block-based programming model
+// ensures that all threads in each block see the same return value, even those
+// threads that do not participate in the atomic operation
+static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
+  SmallVector<unsigned> smemShape;
+  if (atomicNeedsSharedMemory(result)) {
+    smemShape.push_back(1);
+  }
+  return smemShape;
+}
+
+ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
+                                     RankedTensorType dstTy) {
+  // Initialize vector sizes and stride
+  auto repShape = getRepShapeForCvt(srcTy, dstTy);
+  if (repShape.empty())
+    return ScratchConfig({}, {});
+  ScratchConfig scratchConfig(repShape, repShape);
+  auto rank = repShape.size();
+  Attribute srcLayout = srcTy.getEncoding();
+  Attribute dstLayout = dstTy.getEncoding();
+
+  assert(cvtNeedsSharedMemory(srcTy, dstTy));
+
+  // FIXME This is NOT entirely correct
+  // This should be getElemOrder, but we don't have such a method
+  // TODO Implement getElemOrder and make sure it's consistent with
+  // getContigPerThread
+  auto inOrd = gpu::getThreadOrder(srcLayout);
+  auto outOrd = gpu::getThreadOrder(dstLayout);
+  scratchConfig.order = outOrd;
+
+  unsigned srcContigPerThread =
+      getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
+  unsigned dstContigPerThread =
+      getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
+  // TODO: Fix the legacy issue that outOrd[0] == 0 always means
+  // that we cannot do vectorization.
+  unsigned innerDim = rank - 1;
+  scratchConfig.inVec = outOrd[0] != innerDim  ? 1
+                        : inOrd[0] != innerDim ? 1
+                                               : srcContigPerThread;
+  scratchConfig.outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread;
+
+  if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout)) {
+    if (mma.getVersionMajor() == 1) {
+      // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the
+      // codegen.
+      scratchConfig.inVec = srcContigPerThread;
+    } else if (mlir::isa<BlockedEncodingAttr>(dstLayout)) {
+      // when storing from mma layout and loading in blocked layout vectorizing
+      // the load back gives better performance even if there is a
+      // transposition.
+      scratchConfig.outVec = dstContigPerThread;
+    }
+  }
+
+  // No padding is required if the tensor is 1-D, or if all dimensions except
+  // the first accessed dimension have a size of 1.
+  if (rank <= 1 || product(repShape) == repShape[outOrd[0]])
+    return scratchConfig;
+
+  auto paddedSize = std::max(scratchConfig.inVec, scratchConfig.outVec);
+  scratchConfig.paddedRepShape[outOrd[0]] += paddedSize;
+  return scratchConfig;
+}
+
+class AllocationAnalysis {
+public:
+  AllocationAnalysis(Operation *operation,
+                     Allocation::FuncAllocMapT *funcAllocMap,
+                     Allocation *allocation)
+      : operation(operation), funcAllocMap(funcAllocMap),
+        allocation(allocation) {
+    run();
+  }
+
+private:
+  using BufferT = Allocation::BufferT;
+
+  /// Value -> Liveness Range
+  /// Use MapVector to ensure determinism.
+  using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
+  /// Nodes -> Nodes
+  using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
+
+  void run() {
+    getValuesAndSizes();
+    resolveLiveness();
+    computeOffsets();
+  }
+
+  /// Initializes explicitly defined shared memory values for a given
+  /// operation.
+  void getExplicitValueSize(Operation *op) {
+    for (Value result : op->getResults()) {
+      auto alloc = result.getDefiningOp<gpu::LocalAllocOp>();
+      if (alloc && alloc.isSharedMemoryAlloc()) {
+        // Bytes could be a different value once we support padding or other
+        // allocation policies.
+        auto allocType = alloc.getType();
+        auto shapePerCTA = triton::gpu::getShapePerCTA(allocType);
+        auto bytes = product<int64_t>(shapePerCTA) *
+                     allocType.getElementTypeBitWidth() / 8;
+
+        auto alignment = alloc.getAlignmentOrDefault();
+        allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes,
+                                                             alignment);
+      }
+    }
+  }
+
+  template <BufferT::BufferKind T>
+  void maybeAddScratchBuffer(Operation *op, unsigned bytes,
+                             unsigned alignment) {
+    if (bytes > 0)
+      allocation->addBuffer<T>(op, bytes, alignment);
+  }
+
+  template <BufferT::BufferKind T>
+  void maybeAddScratchBuffer(Operation *op, unsigned bytes) {
+    if (bytes > 0)
+      allocation->addBuffer<T>(op, bytes);
+  }
+
+  /// Initializes temporary shared memory for a given operation.
+  void getScratchValueSize(Operation *op) {
+    const size_t scratchAlignment = 128;
+    if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
+      ReduceOpHelper helper(reduceOp);
+      unsigned bytes = helper.getScratchSizeInBytes();
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                          scratchAlignment);
+    } else if (auto scanOp = dyn_cast<ScanOp>(op)) {
+      ScanLoweringHelper helper(scanOp);
+      unsigned bytes = helper.getScratchSizeInBytes();
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                          scratchAlignment);
+    } else if (auto histogram = dyn_cast<HistogramOp>(op)) {
+      auto dstTy = histogram.getType();
+      int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
+          op->getParentOfType<ModuleOp>());
+      auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
+                   std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                          scratchAlignment);
+    } else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
+      auto srcTy = cvtLayout.getSrc().getType();
+      auto dstTy = cvtLayout.getType();
+      auto srcEncoding = srcTy.getEncoding();
+      auto dstEncoding = dstTy.getEncoding();
+      if (mlir::isa<SharedEncodingAttr>(srcEncoding) ||
+          mlir::isa<SharedEncodingAttr>(dstEncoding)) {
+        // Conversions from/to shared memory do not need scratch memory.
+        return;
+      }
+      // ConvertLayoutOp with both input/output non-shared_layout
+      // TODO: Besides implementing ConvertLayoutOp via shared memory, it's
+      //       also possible to realize it with other approaches in restricted
+      //       conditions, such as warp-shuffle
+      auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
+      auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
+      auto bytes =
+          isa<PointerType>(srcTy.getElementType())
+              ? elems * kPtrBitWidth / 8
+              : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                          scratchAlignment);
+    } else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
+      auto value = op->getOperand(0);
+      // only scalar requires scratch memory
+      // make it explicit for readability
+      if (dyn_cast<RankedTensorType>(value.getType())) {
+        // nothing to do
+      } else {
+        auto smemShape = getRepShapeForAtomic(op->getResult(0));
+        auto elems = getNumScratchElements(smemShape);
+        auto elemTy =
+            cast<PointerType>(value.getType()).getPointeeType();
+        auto bytes =
+            isa<PointerType>(elemTy)
+                ? elems * kPtrBitWidth / 8
+                : elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
+        maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                            scratchAlignment);
+      }
+    } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
+      auto callable = callOp.resolveCallable();
+      auto funcOp = dyn_cast<FunctionOpInterface>(callable);
+      auto *funcAlloc = &(*funcAllocMap)[funcOp];
+      auto bytes = funcAlloc->getSharedMemorySize();
+      maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
+                                                          scratchAlignment);
+    } else if (auto createTensormap =
+                   dyn_cast<ExperimentalTensormapCreateOp>(op)) {
+      constexpr int32_t kTMASize = 128;
+      constexpr int32_t kTMAAlign = 128;
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
+                                                          kTMAAlign);
+    }
+  }
+
+  void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
+    dataflow::Lattice<AliasInfo> *latticeElement =
+        analysis.getLatticeElement(value);
+    if (latticeElement) {
+      AliasInfo &info = latticeElement->getValue();
+      if (!info.getAllocs().empty()) {
+        for (auto alloc : info.getAllocs()) {
+          allocation->addAlias(value, alloc);
+        }
+      }
+    }
+  }
+
+  /// Extract all shared memory values and their sizes
+  void getValuesAndSizes() {
+    // Get the alloc values
+    operation->walk([&](Operation *op) {
+      getExplicitValueSize(op);
+      getScratchValueSize(op);
+    });
+    // Get the alias values
+    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    SharedMemoryAliasAnalysis *aliasAnalysis =
+        solver->load<SharedMemoryAliasAnalysis>();
+    if (failed(solver->initializeAndRun(operation))) {
+      // TODO: return error instead of bailing out.
+      llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
+    }
+    operation->walk([&](Operation *op) {
+      for (auto operand : op->getOperands()) {
+        getValueAlias(operand, *aliasAnalysis);
+      }
+      for (auto value : op->getResults()) {
+        getValueAlias(value, *aliasAnalysis);
+      }
+    });
+  }
+
+  /// Computes the liveness range of the allocated value.
+  /// Each buffer is allocated only once.
+  void resolveExplicitBufferLiveness(
+      function_ref<Interval<size_t>(Value value)> getLiveness) {
+    for (auto valueBufferIter : allocation->getValueBuffer()) {
+      auto value = valueBufferIter.first;
+      auto *buffer = valueBufferIter.second;
+      bufferRange[buffer] = getLiveness(value);
+    }
+  }
+
+  /// Extends the liveness range by unionizing the liveness range of the aliased
+  /// values because each allocated buffer could be an alias of others, if block
+  /// arguments are involved.
+  void resolveAliasBufferLiveness(
+      function_ref<Interval<size_t>(Value value)> getLiveness) {
+    for (auto aliasBufferIter : allocation->getAliasBuffer()) {
+      auto value = aliasBufferIter.first;
+      auto buffers = aliasBufferIter.second;
+      auto range = getLiveness(value);
+      for (auto *buffer : buffers) {
+        auto minId = range.start();
+        auto maxId = range.end();
+        if (bufferRange.count(buffer)) {
+          // Extend the allocated buffer's range
+          minId = std::min(minId, bufferRange[buffer].start());
+          maxId = std::max(maxId, bufferRange[buffer].end());
+        }
+        bufferRange[buffer] = Interval(minId, maxId);
+      }
+    }
+  }
+
+  /// Computes the liveness range of scratched buffers.
+  /// Some operations may have a temporary buffer that is not explicitly
+  /// allocated, but is used to store intermediate results.
+  void resolveScratchBufferLiveness(
+      const DenseMap<Operation *, size_t> &operationId) {
+    // Analyze liveness of scratch buffers and virtual buffers.
+    auto processScratchMemory = [&](const auto &container) {
+      for (auto opScratchIter : container) {
+        // Any scratch memory's live range is the current operation's live
+        // range.
+        auto *op = opScratchIter.first;
+        auto *buffer = opScratchIter.second;
+        bufferRange.insert({buffer, Interval(operationId.lookup(op),
+                                             operationId.lookup(op) + 1)});
+      }
+    };
+    processScratchMemory(allocation->getOpScratch());
+    processScratchMemory(allocation->getOpVirtual());
+  }
+
+  /// Resolves liveness of all values involved under the root operation.
+  void resolveLiveness() {
+    // Assign an ID to each operation using post-order traversal.
+    // To achieve the correct liveness range, the parent operation's ID
+    // should be greater than each of its child operation's ID.
+    // Example:
+    //     ...
+    //     %5 = triton.convert_layout %4
+    //     %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) {
+    //       %2 = triton.convert_layout %5
+    //       ...
+    //       scf.yield %arg0
+    //     }
+    // For example, %5 is defined in the parent region and used in
+    // the child region, and is not passed as a block argument.
+    // %6 should have an ID greater than its child operations,
+    // otherwise %5 liveness range ends before the child operation's liveness
+    // range ends.
+    DenseMap<Operation *, size_t> operationId;
+    operation->walk(
+        [&](Operation *op) { operationId[op] = operationId.size(); });
+
+    // Analyze liveness of explicit buffers
+    Liveness liveness(operation);
+    auto getValueLivenessRange = [&](Value value) {
+      auto liveOperations = liveness.resolveLiveness(value);
+      auto minId = std::numeric_limits<size_t>::max();
+      auto maxId = std::numeric_limits<size_t>::min();
+      std::for_each(liveOperations.begin(), liveOperations.end(),
+                    [&](Operation *liveOp) {
+                      if (operationId[liveOp] < minId) {
+                        minId = operationId[liveOp];
+                      }
+                      if ((operationId[liveOp] + 1) > maxId) {
+                        maxId = operationId[liveOp] + 1;
+                      }
+                    });
+      return Interval(minId, maxId);
+    };
+
+    resolveExplicitBufferLiveness(getValueLivenessRange);
+    resolveAliasBufferLiveness(getValueLivenessRange);
+    resolveScratchBufferLiveness(operationId);
+  }
+
+  /// Computes the shared memory offsets for all related values.
+  /// Paper: Algorithms for Compile-Time Memory Optimization
+  /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082)
+  void computeOffsets() {
+    SmallVector<BufferT *> buffers;
+    for (auto bufferIter : bufferRange) {
+      buffers.emplace_back(bufferIter.first);
+    }
+
+    calculateStarts(buffers);
+
+    // NOTE: The original paper doesn't consider interference between
+    // the bumped ranges. Buffers that previously did not interfere
+    // could interfere after offset bumping if their liveness ranges overlap.
+    // Therefore, we rerun the interference graph algorithm after bumping so
+    // that we regroup the buffers and color them again. Since we always
+    // increase the buffer offset and keep reducing conflicts, we will
+    // eventually reach a fixed point.
+    GraphT interference;
+    buildInterferenceGraph(buffers, interference);
+    do {
+      allocate(buffers, interference);
+      buildInterferenceGraph(buffers, interference);
+    } while (!interference.empty());
+  }
+
+  /// Computes the initial shared memory offsets.
+  void calculateStarts(const SmallVector<BufferT *> &buffers) {
+    //  v = values in shared memory
+    //  t = triplet of (size, start, end)
+    //  shared memory space
+    //  -
+    //  |         *******t4
+    //  | /|\ v2 inserts t4, t5, and t6
+    //  |  |
+    //  | ******t5         ************t6
+    //  | ^^^^^v2^^^^^^
+    //  |  |      *********************t2
+    //  | \|/ v2 erases t1
+    //  | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3
+    //  |---------------------------------------------| liveness range
+    //    1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
+    // If the available triple's range is less than a given buffer range,
+    // we won't know if there has been an overlap without using graph coloring.
+    // Start -> Liveness Range
+    using TripleMapT = std::multimap<size_t, Interval<size_t>>;
+    TripleMapT tripleMap;
+    tripleMap.insert(std::make_pair(0, Interval<size_t>()));
+    SmallVector<BufferT *> xBuffers = buffers;
+    while (!xBuffers.empty()) {
+      auto tripleIt = tripleMap.begin();
+      auto offset = tripleIt->first;
+      auto range = tripleIt->second;
+      tripleMap.erase(tripleIt);
+      auto bufferIt =
+          std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) {
+            auto xRange = bufferRange[buffer];
+            bool res = xRange.intersects(range);
+            for (auto val : tripleMap)
+              res = res &&
+                    !val.second.intersects(xRange); // only one buffer intersect
+            return res;
+          });
+      if (bufferIt != xBuffers.end()) {
+        auto buffer = *bufferIt;
+        auto xSize = buffer->size;
+        auto xRange = bufferRange.lookup(buffer);
+        // TODO(Keren): A buffer's size shouldn't be determined here, have to
+        // clean it up
+        size_t alignOffset = buffer->setOffsetAligned(offset);
+        tripleMap.insert({alignOffset + xSize,
+                          Interval{std::max(range.start(), xRange.start()),
+                                   std::min(range.end(), xRange.end())}});
+        // We could either insert (range.start, xRange.start) or (range.start,
+        // xRange.end), both are correct and determine the potential buffer
+        // offset, and the graph coloring algorithm will solve the
+        // interference, if any
+        if (range.start() < xRange.start())
+          tripleMap.insert({offset, Interval{range.start(), xRange.end()}});
+        if (xRange.end() < range.end())
+          tripleMap.insert({offset, Interval{xRange.start(), range.end()}});
+        xBuffers.erase(bufferIt);
+      }
+    }
+  }
+
+  /// Builds a graph of all shared memory values. Edges are created between
+  /// shared memory values that are overlapping.
+  void buildInterferenceGraph(const SmallVector<BufferT *> &buffers,
+                              GraphT &interference) {
+    // Reset interference graph
+    interference.clear();
+    for (auto x : buffers) {
+      for (auto y : buffers) {
+        if (x == y)
+          continue;
+        auto xStart = x->offset;
+        auto yStart = y->offset;
+        auto xSize = x->size;
+        auto ySize = y->size;
+        Interval xSizeRange = {xStart, xStart + xSize};
+        Interval ySizeRange = {yStart, yStart + ySize};
+        auto xOpRange = bufferRange.lookup(x);
+        auto yOpRange = bufferRange.lookup(y);
+        if (xOpRange.intersects(yOpRange) &&
+            xSizeRange.intersects(ySizeRange)) {
+          interference[x].insert(y);
+        }
+      }
+    }
+  }
+
+  /// Finalizes shared memory offsets considering interference.
+  void allocate(const SmallVector<BufferT *> &buffers,
+                const GraphT &interference) {
+    // Reset shared memory size
+    allocation->setSharedMemorySize(0);
+    // First-fit graph coloring
+    // Neighbors are nodes that interfere with each other.
+    // We color a node by finding the index of the first available
+    // non-neighboring node or the first neighboring node without any color.
+    // Nodes with the same color do not interfere with each other.
+    DenseMap<BufferT *, int> colors;
+    for (auto value : buffers) {
+      colors[value] = (value == buffers[0]) ? 0 : -1;
+    }
+    SmallVector<bool> available(buffers.size());
+    for (auto x : buffers) {
+      std::fill(available.begin(), available.end(), true);
+      for (auto y : interference.lookup(x)) {
+        int color = colors[y];
+        if (color >= 0) {
+          available[color] = false;
+        }
+      }
+      auto it = std::find(available.begin(), available.end(), true);
+      colors[x] = std::distance(available.begin(), it);
+    }
+    // Finalize allocation
+    // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15)
+    // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24)
+    // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42)
+    // TODO(Keren): We are wasting memory here.
+    // Nodes with color2 can actually start with 24.
+    for (auto x : buffers) {
+      size_t newOffset = 0;
+      for (auto y : interference.lookup(x)) {
+        newOffset = std::max(newOffset, y->offset + y->size);
+      }
+      if (colors.lookup(x) != 0)
+        x->setOffsetAligned(newOffset);
+      allocation->setSharedMemorySize(
+          std::max(allocation->getSharedMemorySize(), x->offset + x->size));
+    }
+  }
+
+private:
+  Operation *operation;
+  Allocation::FuncAllocMapT *funcAllocMap;
+  Allocation *allocation;
+  BufferRangeMapT bufferRange;
+};
+
+} // namespace triton::intel
+
+template <>
+void Allocation::run<triton::intel::AllocationAnalysis>(
+    FuncAllocMapT &funcAllocMap) {
+  triton::intel::AllocationAnalysis(getOperation(), &funcAllocMap, this);
+}
+
+} // namespace mlir
diff --git a/third_party/intel/lib/Analysis/CMakeLists.txt b/third_party/intel/lib/Analysis/CMakeLists.txt
index e51b359137..cf10374a69 100644
--- a/third_party/intel/lib/Analysis/CMakeLists.txt
+++ b/third_party/intel/lib/Analysis/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_triton_library(TritonIntelAnalysis
+  Allocation.cpp
   AxisInfo.cpp
   DPAS.cpp
   Liveness.cpp
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp
index 61932d1066..1a9e44e92e 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp
@@ -1,8 +1,8 @@
+#include "intel/include/Analysis/Allocation.h"
 #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
 #include "intel/include/TritonIntelGPUToLLVM/Passes.h"
 
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "triton/Analysis/Allocation.h"
 
 using namespace mlir;
 
@@ -22,7 +22,8 @@ struct AllocateSharedMemory
   void runOnOperation() override {
     ModuleOp mod = getOperation();
     MLIRContext *ctx = &getContext();
-    ModuleAllocation allocation(mod);
+    ModuleAllocation allocation =
+        ModuleAllocation::get<triton::intel::AllocationAnalysis>(mod);
 
     mod.walk([&](FunctionOpInterface funcOp) {
       if (allocation.isRoot(funcOp) && allocation.getSharedMemorySize()) {
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp
index 3d3bbb3015..c10a2e8aff 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp
@@ -14,7 +14,8 @@
 #include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
 #include "intel/include/TritonIntelGPUToLLVM/Passes.h"
 
-#include "triton/Analysis/Allocation.h"
+#include "intel/include/Analysis/Allocation.h"
+#include "triton/Analysis/AxisInfo.h"
 #include "triton/Analysis/Membar.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -94,7 +95,8 @@ struct ConvertTritonGPUToLLVM
 
     // Allocate shared memory and set barrier
     if (!pipelineManager.skipSharedMemoryAllocation()) {
-      ModuleAllocation allocation(mod);
+      ModuleAllocation allocation =
+          ModuleAllocation::get<triton::intel::AllocationAnalysis>(mod);
       ModuleMembarAnalysis membarPass(&allocation);
       membarPass.run();
     }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
index 6674c9a810..9c7cfc044d 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
@@ -96,7 +96,7 @@ struct ConvertTritonGPUToLLVM
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
 
     // Allocate shared memory and set barrier
-    ModuleAllocation allocation(mod);
+    ModuleAllocation allocation = ModuleAllocation::get(mod);
     ModuleMembarAnalysis membarPass(&allocation, NVIDIA::canSkipBarSync);
     membarPass.run();
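
Usage note (not part of the patch): the sketch below shows how a conversion pass is expected to consume the refactored entry point after this change. It is a minimal illustration only, assuming the default template argument of `ModuleAllocation::get` resolves to the upstream `triton::AllocationAnalysis`; the function name `allocateSharedMemoryAndInsertBarriers` is hypothetical and exists only for this example.

```cpp
// Illustrative sketch only -- not part of this change.
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"

using namespace mlir;

// Runs shared-memory allocation followed by barrier insertion for a module,
// selecting the analysis implementation via the template argument of
// ModuleAllocation::get.
static void allocateSharedMemoryAndInsertBarriers(ModuleOp mod) {
  // Upstream analysis (replaces the old `ModuleAllocation allocation(mod);`).
  ModuleAllocation allocation =
      ModuleAllocation::get<triton::AllocationAnalysis>(mod);

  // A backend shipping its own analysis plus an Allocation::run<...>
  // specialization (as the Intel files above do) would instead write:
  //   ModuleAllocation allocation =
  //       ModuleAllocation::get<triton::intel::AllocationAnalysis>(mod);

  ModuleMembarAnalysis membarPass(&allocation);
  membarPass.run();
}
```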