From 38b538448c13e8fe83a8275cd2fb3a692e07881c Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 19 Apr 2025 12:37:23 +0200 Subject: [PATCH] tmp --- .../MemRefToLLVM/AllocLikeConversion.h | 153 -------- .../MemRefToLLVM/AllocLikeConversion.cpp | 195 ---------- .../Conversion/MemRefToLLVM/CMakeLists.txt | 1 - .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 357 +++++++++++++++--- 4 files changed, 301 insertions(+), 405 deletions(-) delete mode 100644 mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h delete mode 100644 mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h deleted file mode 100644 index 8bf04219c759a..0000000000000 --- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h +++ /dev/null @@ -1,153 +0,0 @@ -//===- AllocLikeConversion.h - Convert allocation ops to LLVM ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H -#define MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H - -#include "mlir/Conversion/LLVMCommon/Pattern.h" - -namespace mlir { - -/// Lowering for memory allocation ops. 
-struct AllocationOpLLVMLowering : public ConvertToLLVMPattern { - using ConvertToLLVMPattern::createIndexAttrConstant; - using ConvertToLLVMPattern::getIndexType; - using ConvertToLLVMPattern::getVoidPtrType; - - explicit AllocationOpLLVMLowering(StringRef opName, - const LLVMTypeConverter &converter, - PatternBenefit benefit = 1) - : ConvertToLLVMPattern(opName, &converter.getContext(), converter, - benefit) {} - -protected: - /// Computes the aligned value for 'input' as follows: - /// bumped = input + alignement - 1 - /// aligned = bumped - bumped % alignment - static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, - Value input, Value alignment); - - static MemRefType getMemRefResultType(Operation *op) { - return cast(op->getResult(0).getType()); - } - - /// Computes the alignment for the given memory allocation op. - template - Value getAlignment(ConversionPatternRewriter &rewriter, Location loc, - OpType op) const { - MemRefType memRefType = op.getType(); - Value alignment; - if (auto alignmentAttr = op.getAlignment()) { - Type indexType = getIndexType(); - alignment = - createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr); - } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { - // In the case where no alignment is specified, we may want to override - // `malloc's` behavior. `malloc` typically aligns at the size of the - // biggest scalar on a target HW. For non-scalars, use the natural - // alignment of the LLVM type given by the LLVM DataLayout. - alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); - } - return alignment; - } - - /// Computes the alignment for aligned_alloc used to allocate the buffer for - /// the memory allocation op. - /// - /// Aligned_alloc requires the allocation size to be a power of two, and the - /// allocation size to be a multiple of the alignment. 
- template - int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter, - Location loc, OpType op, - const DataLayout *defaultLayout) const { - if (std::optional alignment = op.getAlignment()) - return *alignment; - - // Whenever we don't have alignment set, we will use an alignment - // consistent with the element type; since the allocation size has to be a - // power of two, we will bump to the next power of two if it isn't. - unsigned eltSizeBytes = - getMemRefEltSizeInBytes(op.getType(), op, defaultLayout); - return std::max(kMinAlignedAllocAlignment, - llvm::PowerOf2Ceil(eltSizeBytes)); - } - - /// Allocates a memory buffer using an allocation method that doesn't - /// guarantee alignment. Returns the pointer and its aligned value. - std::tuple - allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc, - Value sizeBytes, Operation *op, - Value alignment) const; - - /// Allocates a memory buffer using an aligned allocation method. - Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter, - Location loc, Value sizeBytes, Operation *op, - const DataLayout *defaultLayout, - int64_t alignment) const; - -private: - /// Computes the byte size for the MemRef element type. - unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op, - const DataLayout *defaultLayout) const; - - /// Returns true if the memref size in bytes is known to be a multiple of - /// factor. - bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op, - const DataLayout *defaultLayout) const; - - /// The minimum alignment to use with aligned_alloc (has to be a power of 2). - static constexpr uint64_t kMinAlignedAllocAlignment = 16UL; -}; - -/// Lowering for AllocOp and AllocaOp. 
-struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering { - explicit AllocLikeOpLLVMLowering(StringRef opName, - const LLVMTypeConverter &converter, - PatternBenefit benefit = 1) - : AllocationOpLLVMLowering(opName, converter, benefit) {} - -protected: - /// Allocates the underlying buffer. Returns the allocated pointer and the - /// aligned pointer. - virtual std::tuple - allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size, - Operation *op) const = 0; - - /// Sets the flag 'requiresNumElements', specifying the Op requires the number - /// of elements instead of the size in bytes. - void setRequiresNumElements(); - -private: - // An `alloc` is converted into a definition of a memref descriptor value and - // a call to `malloc` to allocate the underlying data buffer. The memref - // descriptor is of the LLVM structure type where: - // 1. the first element is a pointer to the allocated (typed) data buffer, - // 2. the second element is a pointer to the (typed) payload, aligned to the - // specified alignment, - // 3. the remaining elements serve to store all the sizes and strides of the - // memref using LLVM-converted `index` type. - // - // Alignment is performed by allocating `alignment` more bytes than - // requested and shifting the aligned pointer relative to the allocated - // memory. Note: `alignment - ` would actually be - // sufficient. If alignment is unspecified, the two pointers are equal. - - // An `alloca` is converted into a definition of a memref descriptor value and - // an llvm.alloca to allocate the underlying data buffer. - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; - - // Flag for specifying the Op requires the number of elements instead of the - // size in bytes. 
- bool requiresNumElements = false; -}; - -} // namespace mlir - -#endif // MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp deleted file mode 100644 index e9b79983696aa..0000000000000 --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ /dev/null @@ -1,195 +0,0 @@ -//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" -#include "mlir/Analysis/DataLayoutAnalysis.h" -#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/SymbolTable.h" - -using namespace mlir; - -static FailureOr -getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, - Operation *module, Type indexType) { - bool useGenericFn = typeConverter->getOptions().useGenericFunctions; - if (useGenericFn) - return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType); - - return LLVM::lookupOrCreateMallocFn(b, module, indexType); -} - -static FailureOr -getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, - Operation *module, Type indexType) { - bool useGenericFn = typeConverter->getOptions().useGenericFunctions; - - if (useGenericFn) - return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType); - - return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType); -} - -Value AllocationOpLLVMLowering::createAligned( - ConversionPatternRewriter &rewriter, Location loc, Value input, - Value alignment) { - Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 
1); - Value bump = rewriter.create(loc, alignment, one); - Value bumped = rewriter.create(loc, input, bump); - Value mod = rewriter.create(loc, bumped, alignment); - return rewriter.create(loc, bumped, mod); -} - -static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, - Location loc, Value allocatedPtr, - MemRefType memRefType, Type elementPtrType, - const LLVMTypeConverter &typeConverter) { - auto allocatedPtrTy = cast(allocatedPtr.getType()); - FailureOr maybeMemrefAddrSpace = - typeConverter.getMemRefAddressSpace(memRefType); - if (failed(maybeMemrefAddrSpace)) - return Value(); - unsigned memrefAddrSpace = *maybeMemrefAddrSpace; - if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) - allocatedPtr = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), - allocatedPtr); - return allocatedPtr; -} - -std::tuple AllocationOpLLVMLowering::allocateBufferManuallyAlign( - ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, - Operation *op, Value alignment) const { - if (alignment) { - // Adjust the allocation size to consider alignment. - sizeBytes = rewriter.create(loc, sizeBytes, alignment); - } - - MemRefType memRefType = getMemRefResultType(op); - // Allocate the underlying buffer. - Type elementPtrType = this->getElementPtrType(memRefType); - assert(elementPtrType && "could not compute element ptr type"); - FailureOr allocFuncOp = getNotalignedAllocFn( - rewriter, getTypeConverter(), - op->getParentWithTrait(), getIndexType()); - if (failed(allocFuncOp)) - return std::make_tuple(Value(), Value()); - auto results = - rewriter.create(loc, allocFuncOp.value(), sizeBytes); - - Value allocatedPtr = - castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, - elementPtrType, *getTypeConverter()); - if (!allocatedPtr) - return std::make_tuple(Value(), Value()); - Value alignedPtr = allocatedPtr; - if (alignment) { - // Compute the aligned pointer. 
- Value allocatedInt = - rewriter.create(loc, getIndexType(), allocatedPtr); - Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); - alignedPtr = - rewriter.create(loc, elementPtrType, alignmentInt); - } - - return std::make_tuple(allocatedPtr, alignedPtr); -} - -unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes( - MemRefType memRefType, Operation *op, - const DataLayout *defaultLayout) const { - const DataLayout *layout = defaultLayout; - if (const DataLayoutAnalysis *analysis = - getTypeConverter()->getDataLayoutAnalysis()) { - layout = &analysis->getAbove(op); - } - Type elementType = memRefType.getElementType(); - if (auto memRefElementType = dyn_cast(elementType)) - return getTypeConverter()->getMemRefDescriptorSize(memRefElementType, - *layout); - if (auto memRefElementType = dyn_cast(elementType)) - return getTypeConverter()->getUnrankedMemRefDescriptorSize( - memRefElementType, *layout); - return layout->getTypeSize(elementType); -} - -bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf( - MemRefType type, uint64_t factor, Operation *op, - const DataLayout *defaultLayout) const { - uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout); - for (unsigned i = 0, e = type.getRank(); i < e; i++) { - if (type.isDynamicDim(i)) - continue; - sizeDivisor = sizeDivisor * type.getDimSize(i); - } - return sizeDivisor % factor == 0; -} - -Value AllocationOpLLVMLowering::allocateBufferAutoAlign( - ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, - Operation *op, const DataLayout *defaultLayout, int64_t alignment) const { - Value allocAlignment = - createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); - - MemRefType memRefType = getMemRefResultType(op); - // Function aligned_alloc requires size to be a multiple of alignment; we pad - // the size to the next multiple if necessary. 
- if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout)) - sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); - - Type elementPtrType = this->getElementPtrType(memRefType); - FailureOr allocFuncOp = getAlignedAllocFn( - rewriter, getTypeConverter(), - op->getParentWithTrait(), getIndexType()); - if (failed(allocFuncOp)) - return Value(); - auto results = rewriter.create( - loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); - - return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, - elementPtrType, *getTypeConverter()); -} - -void AllocLikeOpLLVMLowering::setRequiresNumElements() { - requiresNumElements = true; -} - -LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - MemRefType memRefType = getMemRefResultType(op); - if (!isConvertibleAndHasIdentityMaps(memRefType)) - return rewriter.notifyMatchFailure(op, "incompatible memref type"); - auto loc = op->getLoc(); - - // Get actual sizes of the memref as values: static sizes are constant - // values and dynamic sizes are passed to 'alloc' as operands. In case of - // zero-dimensional memref, assume a scalar (size 1). - SmallVector sizes; - SmallVector strides; - Value size; - - this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes, - strides, size, !requiresNumElements); - - // Allocate the underlying buffer. - auto [allocatedPtr, alignedPtr] = - this->allocateBuffer(rewriter, loc, size, op); - - if (!allocatedPtr || !alignedPtr) - return rewriter.notifyMatchFailure(loc, - "underlying buffer allocation failed"); - - // Create the MemRef descriptor. - auto memRefDescriptor = this->createMemRefDescriptor( - loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter); - - // Return the final value of the descriptor. 
- rewriter.replaceOp(op, {memRefDescriptor}); - return success(); -} diff --git a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt index f0d95f5ada290..9da4b23d42f41 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_conversion_library(MLIRMemRefToLLVM - AllocLikeConversion.cpp MemRefToLLVM.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 9c219d8a3d8cb..c8b2c0bdc6c20 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -13,7 +13,6 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" @@ -53,33 +52,247 @@ getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, return LLVM::lookupOrCreateFreeFn(b, module); } -struct AllocOpLowering : public AllocLikeOpLLVMLowering { - AllocOpLowering(const LLVMTypeConverter &converter) - : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), - converter) {} - std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, - Location loc, Value sizeBytes, - Operation *op) const override { - return allocateBufferManuallyAlign( - rewriter, loc, sizeBytes, op, - getAlignment(rewriter, loc, cast(op))); +static FailureOr +getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, + Operation *module, Type indexType) { + bool useGenericFn = typeConverter->getOptions().useGenericFunctions; + if (useGenericFn) + return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType); + + return 
LLVM::lookupOrCreateMallocFn(b, module, indexType); +} + +static FailureOr +getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, + Operation *module, Type indexType) { + bool useGenericFn = typeConverter->getOptions().useGenericFunctions; + + if (useGenericFn) + return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType); + + return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType); +} + +/// Computes the aligned value for 'input' as follows: +/// bumped = input + alignment - 1 +/// aligned = bumped - bumped % alignment +static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, + Value input, Value alignment) { + Value one = rewriter.create(loc, alignment.getType(), + rewriter.getIndexAttr(1)); + Value bump = rewriter.create(loc, alignment, one); + Value bumped = rewriter.create(loc, input, bump); + Value mod = rewriter.create(loc, bumped, alignment); + return rewriter.create(loc, bumped, mod); +} + +/// Computes the byte size for the MemRef element type. 
+static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter, + MemRefType memRefType, Operation *op, + const DataLayout *defaultLayout) { + const DataLayout *layout = defaultLayout; + if (const DataLayoutAnalysis *analysis = + typeConverter->getDataLayoutAnalysis()) { + layout = &analysis->getAbove(op); + } + Type elementType = memRefType.getElementType(); + if (auto memRefElementType = dyn_cast(elementType)) + return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout); + if (auto memRefElementType = dyn_cast(elementType)) + return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType, + *layout); + return layout->getTypeSize(elementType); +} + +static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, + Location loc, Value allocatedPtr, + MemRefType memRefType, Type elementPtrType, + const LLVMTypeConverter &typeConverter) { + auto allocatedPtrTy = cast(allocatedPtr.getType()); + FailureOr maybeMemrefAddrSpace = + typeConverter.getMemRefAddressSpace(memRefType); + assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space"); + unsigned memrefAddrSpace = *maybeMemrefAddrSpace; + if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) + allocatedPtr = rewriter.create( + loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), + allocatedPtr); + return allocatedPtr; +} + +struct AllocOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MemRefType memRefType = op.getType(); + if (!isConvertibleAndHasIdentityMaps(memRefType)) + return rewriter.notifyMatchFailure(op, "incompatible memref type"); + + // Get or insert alloc function into the module. 
+ FailureOr allocFuncOp = getNotalignedAllocFn( + rewriter, getTypeConverter(), + op->getParentWithTrait(), getIndexType()); + if (failed(allocFuncOp)) + return failure(); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. In case of + // zero-dimensional memref, assume a scalar (size 1). + SmallVector sizes; + SmallVector strides; + Value sizeBytes; + + this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(), + rewriter, sizes, strides, sizeBytes, true); + + Value alignment = getAlignment(rewriter, loc, op); + if (alignment) { + // Adjust the allocation size to consider alignment. + sizeBytes = rewriter.create(loc, sizeBytes, alignment); + } + + // Allocate the underlying buffer. + Type elementPtrType = this->getElementPtrType(memRefType); + assert(elementPtrType && "could not compute element ptr type"); + auto results = + rewriter.create(loc, allocFuncOp.value(), sizeBytes); + + Value allocatedPtr = + castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, + elementPtrType, *getTypeConverter()); + Value alignedPtr = allocatedPtr; + if (alignment) { + // Compute the aligned pointer. + Value allocatedInt = + rewriter.create(loc, getIndexType(), allocatedPtr); + Value alignmentInt = + createAligned(rewriter, loc, allocatedInt, alignment); + alignedPtr = + rewriter.create(loc, elementPtrType, alignmentInt); + } + + // Create the MemRef descriptor. + auto memRefDescriptor = this->createMemRefDescriptor( + loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter); + + // Return the final value of the descriptor. + rewriter.replaceOp(op, {memRefDescriptor}); + return success(); + } + + /// Computes the alignment for the given memory allocation op. 
+ template + Value getAlignment(ConversionPatternRewriter &rewriter, Location loc, + OpType op) const { + MemRefType memRefType = op.getType(); + Value alignment; + if (auto alignmentAttr = op.getAlignment()) { + Type indexType = getIndexType(); + alignment = + createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr); + } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { + // In the case where no alignment is specified, we may want to override + // `malloc's` behavior. `malloc` typically aligns at the size of the + // biggest scalar on a target HW. For non-scalars, use the natural + // alignment of the LLVM type given by the LLVM DataLayout. + alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); + } + return alignment; } }; -struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering { - AlignedAllocOpLowering(const LLVMTypeConverter &converter) - : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), - converter) {} - std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, - Location loc, Value sizeBytes, - Operation *op) const override { - Value ptr = allocateBufferAutoAlign( - rewriter, loc, sizeBytes, op, &defaultLayout, - alignedAllocationGetAlignment(rewriter, loc, cast(op), - &defaultLayout)); - if (!ptr) - return std::make_tuple(Value(), Value()); - return std::make_tuple(ptr, ptr); +struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MemRefType memRefType = op.getType(); + if (!isConvertibleAndHasIdentityMaps(memRefType)) + return rewriter.notifyMatchFailure(op, "incompatible memref type"); + + // Get or insert alloc function into module. 
+ FailureOr allocFuncOp = getAlignedAllocFn( + rewriter, getTypeConverter(), + op->getParentWithTrait(), getIndexType()); + if (failed(allocFuncOp)) + return failure(); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. In case of + // zero-dimensional memref, assume a scalar (size 1). + SmallVector sizes; + SmallVector strides; + Value sizeBytes; + + this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(), + rewriter, sizes, strides, sizeBytes, !false); + + int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout); + + Value allocAlignment = + createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); + + // Function aligned_alloc requires size to be a multiple of alignment; we + // pad the size to the next multiple if necessary. + if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout)) + sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); + + Type elementPtrType = this->getElementPtrType(memRefType); + auto results = rewriter.create( + loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); + + Value ptr = + castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, + elementPtrType, *getTypeConverter()); + + // Create the MemRef descriptor. + auto memRefDescriptor = this->createMemRefDescriptor( + loc, memRefType, ptr, ptr, sizes, strides, rewriter); + + // Return the final value of the descriptor. + rewriter.replaceOp(op, {memRefDescriptor}); + return success(); + } + + /// The minimum alignment to use with aligned_alloc (has to be a power of 2). + static constexpr uint64_t kMinAlignedAllocAlignment = 16UL; + + /// Computes the alignment for aligned_alloc used to allocate the buffer for + /// the memory allocation op. + /// + /// Aligned_alloc requires the allocation size to be a power of two, and the + /// allocation size to be a multiple of the alignment. 
+ int64_t alignedAllocationGetAlignment(memref::AllocOp op, + const DataLayout *defaultLayout) const { + if (std::optional alignment = op.getAlignment()) + return *alignment; + + // Whenever we don't have alignment set, we will use an alignment + // consistent with the element type; since the allocation size has to be a + // power of two, we will bump to the next power of two if it isn't. + unsigned eltSizeBytes = getMemRefEltSizeInBytes( + getTypeConverter(), op.getType(), op, defaultLayout); + return std::max(kMinAlignedAllocAlignment, + llvm::PowerOf2Ceil(eltSizeBytes)); + } + + /// Returns true if the memref size in bytes is known to be a multiple of + /// factor. + bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op, + const DataLayout *defaultLayout) const { + uint64_t sizeDivisor = + getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout); + for (unsigned i = 0, e = type.getRank(); i < e; i++) { + if (type.isDynamicDim(i)) + continue; + sizeDivisor = sizeDivisor * type.getDimSize(i); + } + return sizeDivisor % factor == 0; } private: @@ -87,38 +300,52 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering { DataLayout defaultLayout; }; -struct AllocaOpLowering : public AllocLikeOpLLVMLowering { - AllocaOpLowering(const LLVMTypeConverter &converter) - : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(), - converter) { - setRequiresNumElements(); - } +struct AllocaOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; /// Allocates the underlying buffer using the right call. `allocatedBytePtr` /// is set to null for stack allocations. `accessAlignment` is set if /// alignment is needed post allocation (for eg. in conjunction with malloc). 
- std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, - Location loc, Value size, - Operation *op) const override { + LogicalResult + matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MemRefType memRefType = op.getType(); + if (!isConvertibleAndHasIdentityMaps(memRefType)) + return rewriter.notifyMatchFailure(op, "incompatible memref type"); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. In case of + // zero-dimensional memref, assume a scalar (size 1). + SmallVector sizes; + SmallVector strides; + Value size; + + this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(), + rewriter, sizes, strides, size, !true); // With alloca, one gets a pointer to the element type right away. // For stack allocations. - auto allocaOp = cast(op); auto elementType = - typeConverter->convertType(allocaOp.getType().getElementType()); + typeConverter->convertType(op.getType().getElementType()); FailureOr maybeAddressSpace = - getTypeConverter()->getMemRefAddressSpace(allocaOp.getType()); - if (failed(maybeAddressSpace)) - return std::make_tuple(Value(), Value()); + getTypeConverter()->getMemRefAddressSpace(op.getType()); + assert(succeeded(maybeAddressSpace) && "unsupported address space"); unsigned addrSpace = *maybeAddressSpace; auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace); - auto allocatedElementPtr = - rewriter.create(loc, elementPtrType, elementType, size, - allocaOp.getAlignment().value_or(0)); + auto allocatedElementPtr = rewriter.create( + loc, elementPtrType, elementType, size, op.getAlignment().value_or(0)); - return std::make_tuple(allocatedElementPtr, allocatedElementPtr); + // Create the MemRef descriptor. 
+ auto memRefDescriptor = this->createMemRefDescriptor( + loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes, + strides, rewriter); + + // Return the final value of the descriptor. + rewriter.replaceOp(op, {memRefDescriptor}); + return success(); } }; @@ -527,31 +754,43 @@ struct GlobalMemrefOpLowering /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to /// the first element stashed into the descriptor. This reuses /// `AllocLikeOpLowering` to reuse the Memref descriptor construction. -struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering { - GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter) - : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(), - converter) {} +struct GetGlobalMemrefOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; /// Buffer "allocation" for memref.get_global op is getting the address of /// the global variable referenced. - std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, - Location loc, Value sizeBytes, - Operation *op) const override { - auto getGlobalOp = cast(op); - MemRefType type = cast(getGlobalOp.getResult().getType()); + LogicalResult + matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MemRefType memRefType = op.getType(); + if (!isConvertibleAndHasIdentityMaps(memRefType)) + return rewriter.notifyMatchFailure(op, "incompatible memref type"); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. In case of + // zero-dimensional memref, assume a scalar (size 1). 
+ SmallVector sizes; + SmallVector strides; + Value sizeBytes; + + this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(), + rewriter, sizes, strides, sizeBytes, !false); + + MemRefType type = cast(op.getResult().getType()); // This is called after a type conversion, which would have failed if this // call fails. FailureOr maybeAddressSpace = getTypeConverter()->getMemRefAddressSpace(type); - if (failed(maybeAddressSpace)) - return std::make_tuple(Value(), Value()); + assert(succeeded(maybeAddressSpace) && "unsupported address space"); unsigned memSpace = *maybeAddressSpace; Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace); auto addressOf = - rewriter.create(loc, ptrTy, getGlobalOp.getName()); + rewriter.create(loc, ptrTy, op.getName()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. @@ -570,7 +809,13 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering { // Both allocated and aligned pointers are same. We could potentially stash // a nullptr for the allocated pointer since we do not expect any dealloc. - return std::make_tuple(deadBeefPtr, gep); + // Create the MemRef descriptor. + auto memRefDescriptor = this->createMemRefDescriptor( + loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter); + + // Return the final value of the descriptor. + rewriter.replaceOp(op, {memRefDescriptor}); + return success(); } };