diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h index 88f18022da9bb..2dfb6b03bcfcd 100644 --- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h +++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h @@ -20,6 +20,7 @@ class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; class Pass; +class SymbolTableCollection; #define GEN_PASS_DECL_CONVERTCONTROLFLOWTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" @@ -39,9 +40,9 @@ void populateControlFlowToLLVMConversionPatterns( /// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is /// set to false, the program execution continues when a condition is /// unsatisfied. -void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, - RewritePatternSet &patterns, - bool abortOnFailure = true); +void populateAssertToLLVMConversionPattern( + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + bool abortOnFailure = true, SymbolTableCollection *symbolTables = nullptr); void registerConvertControlFlowToLLVMInterface(DialectRegistry ®istry); diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h index b1ea2740c0605..e530b0a43b8e0 100644 --- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h +++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h @@ -27,20 +27,23 @@ class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; class SymbolTable; +class SymbolTableCollection; /// Convert input FunctionOpInterface operation to LLVMFuncOp by using the /// provided LLVMTypeConverter. Return failure if failed to so. FailureOr convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &converter); + const LLVMTypeConverter &converter, + SymbolTableCollection *symbolTables = nullptr); /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If /// `emitCWrappers` is set, the pattern will also produce functions /// that pass memref descriptors by pointer-to-structure in addition to the /// default unpacked form. void populateFuncToLLVMFuncOpConversionPattern( - const LLVMTypeConverter &converter, RewritePatternSet &patterns); + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + SymbolTableCollection *symbolTables = nullptr); /// Collect the patterns to convert from the Func dialect to LLVM. The /// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions @@ -57,7 +60,7 @@ void populateFuncToLLVMFuncOpConversionPattern( /// not an error to provide it anyway. void populateFuncToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, - const SymbolTable *symbolTable = nullptr); + SymbolTableCollection *symbolTables = nullptr); void registerConvertFuncToLLVMInterface(DialectRegistry ®istry); diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h index 33402301115b7..d7de40555bb6a 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h @@ -17,6 +17,7 @@ namespace mlir { class OpBuilder; class LLVMTypeConverter; +class SymbolTableCollection; namespace LLVM { @@ -26,7 +27,8 @@ namespace LLVM { LogicalResult createPrintStrCall( OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, - bool addNewline = true, std::optional runtimeFunctionName = {}); + bool addNewline = true, std::optional runtimeFunctionName = {}, + SymbolTableCollection *symbolTables = nullptr); } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h index 996a64baf9dd5..e93d5bdce7bf2 100644 --- a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h +++ b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h @@ -16,6 +16,7 @@ class DialectRegistry; class Pass; class LLVMTypeConverter; class RewritePatternSet; +class SymbolTableCollection; #define GEN_PASS_DECL_FINALIZEMEMREFTOLLVMCONVERSIONPASS #include "mlir/Conversion/Passes.h.inc" @@ -23,7 +24,8 @@ class RewritePatternSet; /// Collect a set of patterns to convert memory-related operations from the /// MemRef dialect to the LLVM dialect. void populateFinalizeMemRefToLLVMConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns); + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + SymbolTableCollection *symbolTables = nullptr); void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry); diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 4a7ec6f2efe64..8ad9ed18acebd 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -24,6 +24,7 @@ class OpBuilder; class Operation; class Type; class ValueRange; +class SymbolTableCollection; namespace LLVM { class LLVMFuncOp; @@ -33,55 +34,73 @@ class LLVMFuncOp; /// implemented separately (e.g. as part of a support runtime library or as part /// of the libc). /// Failure if an unexpected version of function is found. -FailureOr lookupOrCreatePrintI64Fn(OpBuilder &b, - Operation *moduleOp); -FailureOr lookupOrCreatePrintU64Fn(OpBuilder &b, - Operation *moduleOp); -FailureOr lookupOrCreatePrintF16Fn(OpBuilder &b, - Operation *moduleOp); -FailureOr lookupOrCreatePrintBF16Fn(OpBuilder &b, - Operation *moduleOp); -FailureOr lookupOrCreatePrintF32Fn(OpBuilder &b, - Operation *moduleOp); -FailureOr lookupOrCreatePrintF64Fn(OpBuilder &b, - Operation *moduleOp); +FailureOr +lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. FailureOr lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp, - std::optional runtimeFunctionName = {}); -FailureOr lookupOrCreatePrintOpenFn(OpBuilder &b, - Operation *moduleOp); -FailureOr lookupOrCreatePrintCloseFn(OpBuilder &b, - Operation *moduleOp); -FailureOr lookupOrCreatePrintCommaFn(OpBuilder &b, - Operation *moduleOp); -FailureOr lookupOrCreatePrintNewlineFn(OpBuilder &b, - Operation *moduleOp); + std::optional runtimeFunctionName = {}, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, + SymbolTableCollection *symbolTables = nullptr); FailureOr -lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType); +lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, + SymbolTableCollection *symbolTables = nullptr); FailureOr -lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType); -FailureOr lookupOrCreateFreeFn(OpBuilder &b, - Operation *moduleOp); +lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); FailureOr -lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType); +lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, + SymbolTableCollection *symbolTables = nullptr); +FailureOr lookupOrCreateGenericAlignedAllocFn( + OpBuilder &b, Operation *moduleOp, Type indexType, + SymbolTableCollection *symbolTables = nullptr); FailureOr -lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, - Type indexType); -FailureOr lookupOrCreateGenericFreeFn(OpBuilder &b, - Operation *moduleOp); +lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); FailureOr lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, - Type unrankedDescriptorType); + Type unrankedDescriptorType, + SymbolTableCollection *symbolTables = nullptr); /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`. /// Return a failure if the FuncOp found has unexpected signature. FailureOr lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef paramTypes = {}, Type resultType = {}, - bool isVarArg = false, bool isReserved = false); + bool isVarArg = false, bool isReserved = false, + SymbolTableCollection *symbolTables = nullptr); } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index d31d7d801e149..3d0804fd11b6b 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -44,9 +44,10 @@ namespace { /// lowering. struct AssertOpLowering : public ConvertOpToLLVMPattern { explicit AssertOpLowering(const LLVMTypeConverter &typeConverter, - bool abortOnFailedAssert = true) + bool abortOnFailedAssert = true, + SymbolTableCollection *symbolTables = nullptr) : ConvertOpToLLVMPattern(typeConverter, /*benefit=*/1), - abortOnFailedAssert(abortOnFailedAssert) {} + abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {} LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, @@ -64,7 +65,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { auto createResult = LLVM::createPrintStrCall( rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(), /*addNewLine=*/false, - /*runtimeFunctionName=*/"puts"); + /*runtimeFunctionName=*/"puts", symbolTables); if (createResult.failed()) return failure(); @@ -96,6 +97,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { /// If set to `false`, messages are printed but program execution continues. /// This is useful for testing asserts. bool abortOnFailedAssert = true; + + SymbolTableCollection *symbolTables = nullptr; }; /// Helper function for converting branch ops. This function converts the @@ -232,8 +235,8 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns( void mlir::cf::populateAssertToLLVMConversionPattern( const LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool abortOnFailure) { - patterns.add(converter, abortOnFailure); + bool abortOnFailure, SymbolTableCollection *symbolTables) { + patterns.add(converter, abortOnFailure, symbolTables); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 538016927256b..4499cbd4d1a20 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -299,10 +299,9 @@ static void restoreByValRefArgumentType( } } -FailureOr -mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, - ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &converter) { +FailureOr mlir::convertFuncOpToLLVMFuncOp( + FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) { // Check the funcOp has `FunctionType`. auto funcTy = dyn_cast(funcOp.getFunctionType()); if (!funcTy) @@ -361,10 +360,25 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, SmallVector attributes; filterFuncAttributes(funcOp, attributes); + + Operation *symbolTableOp = funcOp->getParentWithTrait(); + + if (symbolTables && symbolTableOp) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp); + symbolTable.remove(funcOp); + } + auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, linkage, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); + + if (symbolTables && symbolTableOp) { + auto ip = rewriter.getInsertionPoint(); + SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp); + symbolTable.insert(newFuncOp, ip); + } + cast(newFuncOp.getOperation()) .setVisibility(funcOp.getVisibility()); @@ -473,16 +487,20 @@ namespace { /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. -struct FuncOpConversion : public ConvertOpToLLVMPattern { - FuncOpConversion(const LLVMTypeConverter &converter) - : ConvertOpToLLVMPattern(converter) {} +class FuncOpConversion : public ConvertOpToLLVMPattern { + SymbolTableCollection *symbolTables = nullptr; + +public: + explicit FuncOpConversion(const LLVMTypeConverter &converter, + SymbolTableCollection *symbolTables = nullptr) + : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {} LogicalResult matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { FailureOr newFuncOp = mlir::convertFuncOpToLLVMFuncOp( cast(funcOp.getOperation()), rewriter, - *getTypeConverter()); + *getTypeConverter(), symbolTables); if (failed(newFuncOp)) return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop"); @@ -591,11 +609,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { class CallOpLowering : public CallOpInterfaceLowering { public: - CallOpLowering(const LLVMTypeConverter &typeConverter, - // Can be nullptr. - const SymbolTable *symbolTable, PatternBenefit benefit = 1) + explicit CallOpLowering(const LLVMTypeConverter &typeConverter, + SymbolTableCollection *symbolTables = nullptr, + PatternBenefit benefit = 1) : CallOpInterfaceLowering(typeConverter, benefit), - symbolTable(symbolTable) {} + symbolTables(symbolTables) {} LogicalResult matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, @@ -603,10 +621,10 @@ class CallOpLowering : public CallOpInterfaceLowering { bool useBarePtrCallConv = false; if (getTypeConverter()->getOptions().useBarePtrCallConv) { useBarePtrCallConv = true; - } else if (symbolTable != nullptr) { + } else if (symbolTables != nullptr) { // Fast lookup. Operation *callee = - symbolTable->lookup(callOp.getCalleeAttr().getValue()); + symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr()); useBarePtrCallConv = callee != nullptr && callee->hasAttr(barePtrAttrName); } else { @@ -620,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering { } private: - const SymbolTable *symbolTable = nullptr; + SymbolTableCollection *symbolTables = nullptr; }; struct CallIndirectOpLowering @@ -731,16 +749,17 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { } // namespace void mlir::populateFuncToLLVMFuncOpConversionPattern( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter); + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + SymbolTableCollection *symbolTables) { + patterns.add(converter, symbolTables); } void mlir::populateFuncToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, - const SymbolTable *symbolTable) { - populateFuncToLLVMFuncOpConversionPattern(converter, patterns); + SymbolTableCollection *symbolTables) { + populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables); patterns.add(converter); - patterns.add(converter, symbolTable); + patterns.add(converter, symbolTables); patterns.add(converter); patterns.add(converter); } @@ -780,15 +799,11 @@ struct ConvertFuncToLLVMPass LLVMTypeConverter typeConverter(&getContext(), options, &dataLayoutAnalysis); - std::optional optSymbolTable = std::nullopt; - const SymbolTable *symbolTable = nullptr; - if (!options.useBarePtrCallConv) { - optSymbolTable.emplace(m); - symbolTable = &optSymbolTable.value(); - } - RewritePatternSet patterns(&getContext()); - populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable); + SymbolTableCollection symbolTables; + + populateFuncToLLVMConversionPatterns(typeConverter, patterns, + &symbolTables); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 2815e05b3e11b..49c73fbc9dd79 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -17,8 +17,26 @@ using namespace mlir; using namespace llvm; -static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, - StringRef symbolName) { +/// Check if a given symbol name is already in use within the module operation. +/// If no symbol with such name is present, then the same identifier is +/// returned. Otherwise, a unique and yet unused identifier is computed starting +/// from the requested one. +static std::string +ensureSymbolNameIsUnique(ModuleOp moduleOp, StringRef symbolName, + SymbolTableCollection *symbolTables = nullptr) { + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp); + unsigned counter = 0; + SmallString<128> uniqueName = symbolTable.generateSymbolName<128>( + symbolName, + [&](const SmallString<128> &tentativeName) { + return symbolTable.lookupSymbolIn(moduleOp, tentativeName) != nullptr; + }, + counter); + + return static_cast(uniqueName); + } + static int counter = 0; std::string uniqueName = std::string(symbolName); while (moduleOp.lookupSymbol(uniqueName)) { @@ -30,7 +48,8 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, LogicalResult mlir::LLVM::createPrintStrCall( OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline, - std::optional runtimeFunctionName) { + std::optional runtimeFunctionName, + SymbolTableCollection *symbolTables) { auto ip = builder.saveInsertionPoint(); builder.setInsertionPointToStart(moduleOp.getBody()); MLIRContext *ctx = builder.getContext(); @@ -49,7 +68,7 @@ LogicalResult mlir::LLVM::createPrintStrCall( LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); auto globalOp = builder.create( loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, - ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr); + ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr); auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); // Emit call to `printStr` in runtime library. diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 8ccf1bfc292d5..e8294a5234c4f 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -48,35 +48,39 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) { } static FailureOr -getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, - ModuleOp module) { +getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module, + SymbolTableCollection *symbolTables) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) - return LLVM::lookupOrCreateGenericFreeFn(b, module); + return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables); - return LLVM::lookupOrCreateFreeFn(b, module); + return LLVM::lookupOrCreateFreeFn(b, module, symbolTables); } static FailureOr getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, - Operation *module, Type indexType) { + Operation *module, Type indexType, + SymbolTableCollection *symbolTables) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) - return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType); + return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType, + symbolTables); - return LLVM::lookupOrCreateMallocFn(b, module, indexType); + return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables); } static FailureOr getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, - Operation *module, Type indexType) { + Operation *module, Type indexType, + SymbolTableCollection *symbolTables) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) - return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType); + return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType, + symbolTables); - return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType); + return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables); } /// Computes the aligned value for 'input' as follows: @@ -126,8 +130,15 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, return allocatedPtr; } -struct AllocOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +class AllocOpLowering : public ConvertOpToLLVMPattern { + SymbolTableCollection *symbolTables = nullptr; + +public: + explicit AllocOpLowering(const LLVMTypeConverter &typeConverter, + SymbolTableCollection *symbolTables = nullptr, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + symbolTables(symbolTables) {} LogicalResult matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, @@ -138,9 +149,10 @@ struct AllocOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "incompatible memref type"); // Get or insert alloc function into the module. - FailureOr allocFuncOp = getNotalignedAllocFn( - rewriter, getTypeConverter(), - op->getParentWithTrait(), getIndexType()); + FailureOr allocFuncOp = + getNotalignedAllocFn(rewriter, getTypeConverter(), + op->getParentWithTrait(), + getIndexType(), symbolTables); if (failed(allocFuncOp)) return failure(); @@ -210,8 +222,15 @@ struct AllocOpLowering : public ConvertOpToLLVMPattern { } }; -struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +class AlignedAllocOpLowering : public ConvertOpToLLVMPattern { + SymbolTableCollection *symbolTables = nullptr; + +public: + explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter, + SymbolTableCollection *symbolTables = nullptr, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + symbolTables(symbolTables) {} LogicalResult matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, @@ -222,9 +241,10 @@ struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "incompatible memref type"); // Get or insert alloc function into module. - FailureOr allocFuncOp = getAlignedAllocFn( - rewriter, getTypeConverter(), - op->getParentWithTrait(), getIndexType()); + FailureOr allocFuncOp = + getAlignedAllocFn(rewriter, getTypeConverter(), + op->getParentWithTrait(), + getIndexType(), symbolTables); if (failed(allocFuncOp)) return failure(); @@ -446,18 +466,23 @@ struct AssumeAlignmentOpLowering // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. -struct DeallocOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +class DeallocOpLowering : public ConvertOpToLLVMPattern { + SymbolTableCollection *symbolTables = nullptr; - explicit DeallocOpLowering(const LLVMTypeConverter &converter) - : ConvertOpToLLVMPattern(converter) {} +public: + explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter, + SymbolTableCollection *symbolTables = nullptr, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + symbolTables(symbolTables) {} LogicalResult matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Insert the `free` declaration if it is not already present. - FailureOr freeFunc = getFreeFn( - rewriter, getTypeConverter(), op->getParentOfType()); + FailureOr freeFunc = + getFreeFn(rewriter, getTypeConverter(), op->getParentOfType(), + symbolTables); if (failed(freeFunc)) return failure(); Value allocatedPtr; @@ -710,9 +735,15 @@ convertGlobalMemrefTypeToLLVM(MemRefType type, } /// GlobalMemrefOp is lowered to a LLVM Global Variable. -struct GlobalMemrefOpLowering - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern { + SymbolTableCollection *symbolTables = nullptr; + +public: + explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter, + SymbolTableCollection *symbolTables = nullptr, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + symbolTables(symbolTables) {} LogicalResult matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor, @@ -743,9 +774,31 @@ struct GlobalMemrefOpLowering if (failed(addressSpace)) return global.emitOpError( "memory space cannot be converted to an integer address space"); + + if (symbolTables) { + Operation *symbolTableOp = + global->getParentWithTrait(); + + if (symbolTableOp) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp); + symbolTable.remove(global); + } + } + auto newGlobal = rewriter.replaceOpWithNewOp( global, arrayTy, global.getConstant(), linkage, global.getSymName(), initialValue, alignment, *addressSpace); + + if (symbolTables) { + Operation *symbolTableOp = + global->getParentWithTrait(); + + if (symbolTableOp) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp); + symbolTable.insert(newGlobal, rewriter.getInsertionPoint()); + } + } + if (!global.isExternal() && global.isUninitialized()) { rewriter.createBlock(&newGlobal.getInitializerRegion()); Value undef[] = { @@ -997,8 +1050,15 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { /// For memrefs with identity layouts, the copy is lowered to the llvm /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call /// to the generic `MemrefCopyFn`. -struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { + SymbolTableCollection *symbolTables = nullptr; + +public: + explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter, + SymbolTableCollection *symbolTables = nullptr, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + symbolTables(symbolTables) {} LogicalResult lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor, @@ -1093,7 +1153,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter); auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( rewriter, op->getParentOfType(), getIndexType(), - sourcePtr.getType()); + sourcePtr.getType(), symbolTables); if (failed(copyFn)) return failure(); rewriter.create(loc, copyFn.value(), @@ -1928,7 +1988,8 @@ class ExtractStridedMetadataOpLowering } // namespace void mlir::populateFinalizeMemRefToLLVMConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + SymbolTableCollection *symbolTables) { // clang-format off patterns.add< AllocaOpLowering, @@ -1939,11 +2000,9 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns( DimOpLowering, ExtractStridedMetadataOpLowering, GenericAtomicRMWOpLowering, - GlobalMemrefOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, - MemRefCopyOpLowering, MemorySpaceCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, @@ -1956,11 +2015,14 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns( TransposeOpLowering, ViewOpLowering>(converter); // clang-format on + patterns.add(converter, + symbolTables); auto allocLowering = converter.getOptions().allocLowering; if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) - patterns.add(converter); + patterns.add(converter, + symbolTables); else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) - patterns.add(converter); + patterns.add(converter, symbolTables); } namespace { @@ -1987,7 +2049,9 @@ struct FinalizeMemRefToLLVMConversionPass LLVMTypeConverter typeConverter(&getContext(), options, &dataLayoutAnalysis); RewritePatternSet patterns(&getContext()); - populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); + SymbolTableCollection symbolTables; + populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns, + &symbolTables); LLVMConversionTarget target(getContext()); target.addLegalOp(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index f725993635672..d53d11f87efe8 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1595,8 +1595,14 @@ class VectorCreateMaskOpConversion }; class VectorPrintOpConversion : public ConvertOpToLLVMPattern { + SymbolTableCollection *symbolTables = nullptr; + public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + explicit VectorPrintOpConversion( + const LLVMTypeConverter &typeConverter, + SymbolTableCollection *symbolTables = nullptr) + : ConvertOpToLLVMPattern(typeConverter), + symbolTables(symbolTables) {} // Lowering implementation that relies on a small runtime support library, // which only needs to provide a few printing methods (single value for all @@ -1643,13 +1649,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { FailureOr op = [&]() { switch (punct) { case PrintPunctuation::Close: - return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent); + return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent, + symbolTables); case PrintPunctuation::Open: - return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent); + return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent, + symbolTables); case PrintPunctuation::Comma: - return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent); + return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent, + symbolTables); case PrintPunctuation::NewLine: - return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent); + return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent, + symbolTables); default: llvm_unreachable("unexpected punctuation"); } @@ -1683,17 +1693,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { PrintConversion conversion = PrintConversion::None; FailureOr printer; if (printType.isF32()) { - printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent); + printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables); } else if (printType.isF64()) { - printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent); + printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables); } else if (printType.isF16()) { conversion = PrintConversion::Bitcast16; // bits! - printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent); + printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables); } else if (printType.isBF16()) { conversion = PrintConversion::Bitcast16; // bits! - printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent); + printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables); } else if (printType.isIndex()) { - printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent); + printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables); } else if (auto intTy = dyn_cast(printType)) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or @@ -1703,7 +1713,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { if (width <= 64) { if (width < 64) conversion = PrintConversion::ZeroExt64; - printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent); + printer = + LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables); } else { return failure(); } @@ -1716,7 +1727,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { conversion = PrintConversion::ZeroExt64; else if (width < 64) conversion = PrintConversion::SignExt64; - printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent); + printer = + LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables); } else { return failure(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 1b4a8f496d3d0..89f765dacda35 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -44,15 +44,31 @@ static constexpr llvm::StringRef kGenericAlignedAlloc = static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free"; static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; +namespace { +/// Search for an LLVMFuncOp with a given name within an operation with the +/// SymbolTable trait. An optional collection of cached symbol tables can be +/// given to avoid a linear scan of the symbol table operation. +LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp, + SymbolTableCollection *symbolTables = nullptr) { + if (symbolTables) { + return symbolTables->lookupSymbolIn( + symbolTableOp, StringAttr::get(symbolTableOp->getContext(), name)); + } + + return llvm::dyn_cast_or_null( + SymbolTable::lookupSymbolIn(symbolTableOp, name)); +} +} // namespace + /// Generic print function lookupOrCreate helper. FailureOr mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef paramTypes, Type resultType, - bool isVarArg, bool isReserved) { + bool isVarArg, bool isReserved, + SymbolTableCollection *symbolTables) { assert(moduleOp->hasTrait() && "expected SymbolTable operation"); - auto func = llvm::dyn_cast_or_null( - SymbolTable::lookupSymbolIn(moduleOp, name)); + auto func = lookupFuncOp(name, moduleOp, symbolTables); auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg); // Assert the signature of the found function is same as expected if (func) { @@ -73,60 +89,75 @@ mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, OpBuilder::InsertionGuard g(b); assert(!moduleOp->getRegion(0).empty() && "expected non-empty region"); b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); - return b.create( + auto funcOp = b.create( moduleOp->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); + + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp); + symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin()); + } + + return funcOp; } static FailureOr lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name, - ArrayRef paramTypes, Type resultType) { + ArrayRef paramTypes, Type resultType, + SymbolTableCollection *symbolTables) { return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType, - /*isVarArg=*/false, /*isReserved=*/true); + /*isVarArg=*/false, /*isReserved=*/true, + symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintF16, IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintBF16, IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { @@ -140,90 +171,102 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) { FailureOr mlir::LLVM::lookupOrCreatePrintStringFn( OpBuilder &b, Operation *moduleOp, - std::optional runtimeFunctionName) { + std::optional runtimeFunctionName, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, runtimeFunctionName.value_or(kPrintString), getCharPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintOpen, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintClose, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintComma, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kPrintNewline, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, - Type indexType) { + Type indexType, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType, - getVoidPtr(moduleOp->getContext())); + getVoidPtr(moduleOp->getContext()), + symbolTables); } FailureOr mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, - Type indexType) { - return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc, - {indexType, indexType}, - getVoidPtr(moduleOp->getContext())); + Type indexType, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kAlignedAlloc, {indexType, indexType}, + getVoidPtr(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } FailureOr mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, - Type indexType) { + Type indexType, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType, - getVoidPtr(moduleOp->getContext())); + getVoidPtr(moduleOp->getContext()), + symbolTables); } FailureOr mlir::LLVM::lookupOrCreateGenericAlignedAllocFn( - OpBuilder &b, Operation *moduleOp, Type indexType) { - return lookupOrCreateReservedFn(b, moduleOp, kGenericAlignedAlloc, - {indexType, indexType}, - getVoidPtr(moduleOp->getContext())); + OpBuilder &b, Operation *moduleOp, Type indexType, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kGenericAlignedAlloc, {indexType, indexType}, + getVoidPtr(moduleOp->getContext()), symbolTables); } FailureOr -mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp) { +mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } -FailureOr -mlir::LLVM::lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, - Type indexType, - Type unrankedDescriptorType) { +FailureOr mlir::LLVM::lookupOrCreateMemRefCopyFn( + OpBuilder &b, Operation *moduleOp, Type indexType, + Type unrankedDescriptorType, SymbolTableCollection *symbolTables) { return lookupOrCreateReservedFn( b, moduleOp, kMemRefCopy, ArrayRef{indexType, unrankedDescriptorType, unrankedDescriptorType}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); }