From 3b892e6a0618cd3092a16aa2e56bb80594c9c4ba Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Sat, 18 Jan 2025 01:43:09 +0800 Subject: [PATCH 1/9] [mlir] Add assertion on reserved function's type --- .../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 3 +- .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 56 ++++++++++++------- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 852490cf7428f..3095c83b90db9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`. LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name, ArrayRef paramTypes = {}, - Type resultType = {}, bool isVarArg = false); + Type resultType = {}, bool isVarArg = false, + bool isReserved = false); } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 88421a16ccf9f..ecc31df40ea52 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, ArrayRef paramTypes, - Type resultType, bool isVarArg) { + Type resultType, bool isVarArg, bool isReserved) { assert(moduleOp->hasTrait() && "expected SymbolTable operation"); auto func = llvm::dyn_cast_or_null( SymbolTable::lookupSymbolIn(moduleOp, name)); - if (func) + auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg); + // Assert the signature of the found function is same as expected + if (func) { + if (funcT != func.getFunctionType()) { + if (isReserved) { + func.emitError("redefinition of reserved function '" + name + "' of different type ") + .append(func.getFunctionType()) + .append(" is prohibited"); + exit(0); + } else { + func.emitError("redefinition of function '" + name + "' of different type ") + .append(funcT) + .append(" is prohibited"); + exit(0); + } + } return func; + } OpBuilder b(moduleOp->getRegion(0)); return b.create( moduleOp->getLoc(), name, @@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintF16, IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintBF16, IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { @@ -110,51 +126,51 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn( Operation *moduleOp, std::optional runtimeFunctionName) { return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString), getCharPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintOpen, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintClose, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintComma, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintNewline, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) { return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType, - getVoidPtr(moduleOp->getContext())); + getVoidPtr(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) { return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType}, - getVoidPtr(moduleOp->getContext())); + getVoidPtr(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { return LLVM::lookupOrCreateFn( moduleOp, kFree, getVoidPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) { return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType, - getVoidPtr(moduleOp->getContext())); + getVoidPtr(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp @@ -162,13 +178,13 @@ mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType) { return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc, {indexType, indexType}, - getVoidPtr(moduleOp->getContext())); + getVoidPtr(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) { return LLVM::lookupOrCreateFn( moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp @@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, return LLVM::lookupOrCreateFn( moduleOp, kMemRefCopy, ArrayRef{indexType, unrankedDescriptorType, unrankedDescriptorType}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } From bdb69fc838254c35ea55542c39dbf9392cd6d4b2 Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Sat, 18 Jan 2025 01:43:28 +0800 Subject: [PATCH 2/9] [mlir] Add test --- mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir new file mode 100644 index 0000000000000..f744e4f7635ea --- /dev/null +++ b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s + +#map = affine_map<(d0) -> (d0 + 1)> +module { + // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func' is prohibited + llvm.func @malloc(i64) + func.func @issue_120950() { + %alloc = memref.alloc() : memref<1024x64xf32, 1> + llvm.return + } +} From d26a77d431b4b18ed5b320185a55f0984c6f3aea Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Sat, 18 Jan 2025 02:12:39 +0800 Subject: [PATCH 3/9] [mlir] Reformat code --- .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 77 +++++++++++-------- 1 file changed, 45 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index ecc31df40ea52..757a1acf3626f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -48,7 +48,8 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, ArrayRef paramTypes, - Type resultType, bool isVarArg, bool isReserved) { + Type resultType, bool isVarArg, + bool isReserved) { assert(moduleOp->hasTrait() && "expected SymbolTable operation"); auto func = llvm::dyn_cast_or_null( @@ -58,14 +59,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, if (func) { if (funcT != func.getFunctionType()) { if (isReserved) { - func.emitError("redefinition of reserved function '" + name + "' of different type ") - .append(func.getFunctionType()) - .append(" is prohibited"); + func.emitError("redefinition of reserved function '" + name + + "' of different type ") + .append(func.getFunctionType()) + .append(" is prohibited"); exit(0); } else { - func.emitError("redefinition of function '" + name + "' of different type ") - .append(funcT) - .append(" is prohibited"); + func.emitError("redefinition of function '" + name + + "' of different type ") + .append(funcT) + .append(" is prohibited"); exit(0); } } @@ -78,39 +81,41 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintI64, - IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + return lookupOrCreateFn( + moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintU64, - IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + return lookupOrCreateFn( + moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintF16, IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext()), + false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintBF16, IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext()), + false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintF32, - Float32Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + return lookupOrCreateFn( + moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintF64, - Float64Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + return lookupOrCreateFn( + moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), + LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); } static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { @@ -126,39 +131,46 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn( Operation *moduleOp, std::optional runtimeFunctionName) { return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString), getCharPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext()), + false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintOpen, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext()), + false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintClose, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext()), + false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintComma, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext()), + false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { return lookupOrCreateFn(moduleOp, kPrintNewline, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext()), + false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) { return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType, - getVoidPtr(moduleOp->getContext()), false, true); + getVoidPtr(moduleOp->getContext()), false, + true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) { return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType}, - getVoidPtr(moduleOp->getContext()), false, true); + getVoidPtr(moduleOp->getContext()), false, + true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { @@ -170,15 +182,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) { return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType, - getVoidPtr(moduleOp->getContext()), false, true); + getVoidPtr(moduleOp->getContext()), false, + true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType) { - return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc, - {indexType, indexType}, - getVoidPtr(moduleOp->getContext()), false, true); + return LLVM::lookupOrCreateFn( + moduleOp, kGenericAlignedAlloc, {indexType, indexType}, + getVoidPtr(moduleOp->getContext()), false, true); } LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) { From 5ec38d6c09ff700dab72d27e3610c7866a23c7fc Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Tue, 21 Jan 2025 17:39:13 +0800 Subject: [PATCH 4/9] [mlir] Wrapped return value of function lookup in `FailureOr` for error handling --- .../Conversion/LLVMCommon/PrintCallHelper.h | 2 +- .../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 43 ++--- .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 8 +- .../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 6 +- mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 16 +- .../Conversion/LLVMCommon/PrintCallHelper.cpp | 14 +- .../MemRefToLLVM/AllocLikeConversion.cpp | 14 +- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 15 +- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 48 +++-- .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 174 ++++++++++-------- 10 files changed, 196 insertions(+), 144 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h index c2742b6fc1d73..5af86956c0ad9 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h @@ -23,7 +23,7 @@ namespace LLVM { /// Generate IR that prints the given string to stdout. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. -void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, +LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline = true, diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 3095c83b90db9..473a69019d239 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -16,7 +16,6 @@ #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" -#include namespace mlir { class Location; @@ -29,40 +28,42 @@ class ValueRange; namespace LLVM { class LLVMFuncOp; -/// Helper functions to lookup or create the declaration for commonly used +/// Helper functions to look up or create the declaration for commonly used /// external C function calls. The list of functions provided here must be /// implemented separately (e.g. as part of a support runtime library or as part /// of the libc). -LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp); +/// Failure if an unexpected version of function is found. +FailureOr lookupOrCreatePrintI64Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintU64Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintF16Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintBF16Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintF32Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintF64Fn(Operation *moduleOp); /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. -LLVM::LLVMFuncOp +FailureOr lookupOrCreatePrintStringFn(Operation *moduleOp, std::optional runtimeFunctionName = {}); -LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType); -LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp, +FailureOr lookupOrCreatePrintOpenFn(Operation *moduleOp); +FailureOr lookupOrCreatePrintCloseFn(Operation *moduleOp); +FailureOr lookupOrCreatePrintCommaFn(Operation *moduleOp); +FailureOr lookupOrCreatePrintNewlineFn(Operation *moduleOp); +FailureOr lookupOrCreateMallocFn(Operation *moduleOp, Type indexType); +FailureOr lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType); -LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp, +FailureOr lookupOrCreateFreeFn(Operation *moduleOp); +FailureOr lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType); -LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, +FailureOr lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType); -LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, +FailureOr lookupOrCreateGenericFreeFn(Operation *moduleOp); +FailureOr lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, Type unrankedDescriptorType); /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`. -LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name, +/// Return a failure if the FuncOp found has unexpected signature. +FailureOr lookupOrCreateFn(Operation *moduleOp, StringRef name, ArrayRef paramTypes = {}, Type resultType = {}, bool isVarArg = false, bool isReserved = false); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 9b5aeb3fef30b..47d4474a5c28d 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern { // Allocate memory for the coroutine frame. auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( op->getParentOfType(), rewriter.getI64Type()); + if (failed(allocFuncOp)) + return failure(); auto coroAlloc = rewriter.create( - loc, allocFuncOp, ValueRange{coroAlign, coroSize}); + loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); // Begin a coroutine: @llvm.coro.begin. auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); @@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern { // Free the memory. auto freeFuncOp = LLVM::lookupOrCreateFreeFn(op->getParentOfType()); - rewriter.replaceOpWithNewOp(op, freeFuncOp, + if (failed(freeFuncOp)) + return failure(); + rewriter.replaceOpWithNewOp(op, freeFuncOp.value(), ValueRange(coroMem.getResult())); return success(); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index d0ffb94f3f96a..cdcb613e04ab1 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -61,9 +61,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { // Failed block: Generate IR to print the message and call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(), + if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(), /*addNewLine=*/false, - /*runtimeFunctionName=*/"puts"); + /*runtimeFunctionName=*/"puts").failed()) { + return failure(); + } if (abortOnFailedAssert) { // Insert the `abort` declaration if necessary. auto abortFunc = module.lookupSymbol("abort"); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index a47a2872ceb07..10f72cda7706d 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Find the malloc and free, or declare them if necessary. auto module = builder.getInsertionPoint()->getParentOfType(); - LLVM::LLVMFuncOp freeFunc, mallocFunc; - if (toDynamic) + FailureOr freeFunc, mallocFunc; + if (toDynamic) { mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); - if (!toDynamic) + if (failed(mallocFunc)) + return failure(); + } + if (!toDynamic) { freeFunc = LLVM::lookupOrCreateFreeFn(module); + if (failed(freeFunc)) + return failure(); + } unsigned unrankedMemrefPos = 0; for (unsigned i = 0, e = operands.size(); i < e; ++i) { @@ -293,7 +299,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Allocate memory, copy, and free the source if necessary. Value memory = toDynamic - ? builder.create(loc, mallocFunc, allocationSize) + ? builder.create(loc, mallocFunc.value(), allocationSize) .getResult() : builder.create(loc, getVoidPtrType(), IntegerType::get(getContext(), 8), @@ -302,7 +308,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( Value source = desc.memRefDescPtr(builder, loc); builder.create(loc, memory, source, allocationSize, false); if (!toDynamic) - builder.create(loc, freeFunc, source); + builder.create(loc, freeFunc.value(), source); // Create a new descriptor. The same descriptor can be returned multiple // times, attempting to modify its pointer can lead to memory leaks diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index bd7b401efec17..607e1d6504552 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, return uniqueName; } -void mlir::LLVM::createPrintStrCall( +LogicalResult mlir::LLVM::createPrintStrCall( OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline, std::optional runtimeFunctionName) { @@ -59,8 +59,12 @@ void mlir::LLVM::createPrintStrCall( SmallVector indices(1, 0); Value gep = builder.create(loc, ptrTy, arrayTy, msgAddr, indices); - Operation *printer = - LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); - builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), - gep); + if (auto printer = + LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); succeeded(printer)) { + builder.create(loc, TypeRange(), + SymbolRefAttr::get(printer.value()), gep); + } else { + return failure(); + } + return success(); } diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp index a6408391b1330..0ee92722157f3 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -15,7 +15,7 @@ using namespace mlir; namespace { -LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, +FailureOr getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) @@ -24,7 +24,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, return LLVM::lookupOrCreateMallocFn(module, indexType); } -LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter, +FailureOr getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; @@ -80,10 +80,11 @@ std::tuple AllocationOpLLVMLowering::allocateBufferManuallyAlign( << " to integer address space " "failed. Consider adding memory space conversions."; } - LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn( + FailureOr allocFuncOp = getNotalignedAllocFn( getTypeConverter(), op->getParentWithTrait(), getIndexType()); - auto results = rewriter.create(loc, allocFuncOp, sizeBytes); + if (failed(allocFuncOp)) return std::make_tuple(Value(), Value()); + auto results = rewriter.create(loc, allocFuncOp.value(), sizeBytes); Value allocatedPtr = castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, @@ -146,11 +147,12 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign( sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); Type elementPtrType = this->getElementPtrType(memRefType); - LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn( + FailureOr allocFuncOp = getAlignedAllocFn( getTypeConverter(), op->getParentWithTrait(), getIndexType()); + if (failed(allocFuncOp)) return Value(); auto results = rewriter.create( - loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes})); + loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, elementPtrType, *getTypeConverter()); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index f7542b8b3bc5c..ac27e0dd09bdc 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -42,8 +42,8 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) { return !ShapedType::isDynamic(strideOrOffset); } -LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter, - ModuleOp module) { +FailureOr getFreeFn(const LLVMTypeConverter *typeConverter, + ModuleOp module) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) @@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Insert the `free` declaration if it is not already present. - LLVM::LLVMFuncOp freeFunc = + auto freeFunc = getFreeFn(getTypeConverter(), op->getParentOfType()); + if (failed(freeFunc)) + return failure(); Value allocatedPtr; if (auto unrankedTy = llvm::dyn_cast(op.getMemref().getType())) { @@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { allocatedPtr = MemRefDescriptor(adaptor.getMemref()) .allocatedPtr(rewriter, op.getLoc()); } - rewriter.replaceOpWithNewOp(op, freeFunc, allocatedPtr); + rewriter.replaceOpWithNewOp(op, freeFunc.value(), + allocatedPtr); return success(); } }; @@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter); auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( op->getParentOfType(), getIndexType(), sourcePtr.getType()); - rewriter.create(loc, copyFn, + if (failed(copyFn)) + return failure(); + rewriter.create(loc, copyFn.value(), ValueRange{elemSize, sourcePtr, targetPtr}); // Restore stack used for descriptors diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a1e21cb524bd9..79617506008fa 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1546,24 +1546,32 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { auto punct = printOp.getPunctuation(); if (auto stringLiteral = printOp.getStringLiteral()) { - LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", - *stringLiteral, *getTypeConverter(), - /*addNewline=*/false); + if (LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", + *stringLiteral, *getTypeConverter(), + /*addNewline=*/false) + .failed()) { + return failure(); + } } else if (punct != PrintPunctuation::NoPunctuation) { - emitCall(rewriter, printOp->getLoc(), [&] { - switch (punct) { - case PrintPunctuation::Close: - return LLVM::lookupOrCreatePrintCloseFn(parent); - case PrintPunctuation::Open: - return LLVM::lookupOrCreatePrintOpenFn(parent); - case PrintPunctuation::Comma: - return LLVM::lookupOrCreatePrintCommaFn(parent); - case PrintPunctuation::NewLine: - return LLVM::lookupOrCreatePrintNewlineFn(parent); - default: - llvm_unreachable("unexpected punctuation"); - } - }()); + if (auto op = [&] -> FailureOr { + switch (punct) { + case PrintPunctuation::Close: + return LLVM::lookupOrCreatePrintCloseFn(parent); + case PrintPunctuation::Open: + return LLVM::lookupOrCreatePrintOpenFn(parent); + case PrintPunctuation::Comma: + return LLVM::lookupOrCreatePrintCommaFn(parent); + case PrintPunctuation::NewLine: + return LLVM::lookupOrCreatePrintNewlineFn(parent); + default: + llvm_unreachable("unexpected punctuation"); + } + }(); + succeeded(op)) + emitCall(rewriter, printOp->getLoc(), op.value()); + else { + return failure(); + } } rewriter.eraseOp(printOp); @@ -1588,7 +1596,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { // Make sure element type has runtime support. PrintConversion conversion = PrintConversion::None; - Operation *printer; + FailureOr printer; if (printType.isF32()) { printer = LLVM::lookupOrCreatePrintF32Fn(parent); } else if (printType.isF64()) { @@ -1631,6 +1639,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { } else { return failure(); } + if (failed(printer)) + return failure(); switch (conversion) { case PrintConversion::ZeroExt64: @@ -1648,7 +1658,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { case PrintConversion::None: break; } - emitCall(rewriter, loc, printer, value); + emitCall(rewriter, loc, printer.value(), value); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 757a1acf3626f..c2c87bc7544bd 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -45,11 +45,10 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free"; static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; /// Generic print function lookupOrCreate helper. -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, - StringRef name, - ArrayRef paramTypes, - Type resultType, bool isVarArg, - bool isReserved) { +FailureOr +mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, + ArrayRef paramTypes, Type resultType, + bool isVarArg, bool isReserved) { assert(moduleOp->hasTrait() && "expected SymbolTable operation"); auto func = llvm::dyn_cast_or_null( @@ -63,14 +62,13 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, "' of different type ") .append(func.getFunctionType()) .append(" is prohibited"); - exit(0); } else { func.emitError("redefinition of function '" + name + "' of different type ") .append(funcT) .append(" is prohibited"); - exit(0); } + return failure(); } return func; } @@ -80,42 +78,58 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { - return lookupOrCreateFn( +namespace { +FailureOr lookupOrCreateReservedFn(Operation *moduleOp, + StringRef name, + ArrayRef paramTypes, + Type resultType) { + return lookupOrCreateFn(moduleOp, name, paramTypes, resultType, + /*isVarArg=*/false, /*isReserved=*/true); +} +} // namespace + +FailureOr +mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) { - return lookupOrCreateFn( +FailureOr +mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintF16, - IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext()), - false, true); +FailureOr +mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintBF16, - IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext()), - false, true); +FailureOr +mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintBF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) { - return lookupOrCreateFn( +FailureOr +mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) { - return lookupOrCreateFn( +FailureOr +mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext())); } static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { @@ -127,84 +141,88 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) { return getCharPtr(context); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn( +FailureOr mlir::LLVM::lookupOrCreatePrintStringFn( Operation *moduleOp, std::optional runtimeFunctionName) { - return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString), - getCharPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext()), - false, true); + return lookupOrCreateReservedFn( + moduleOp, runtimeFunctionName.value_or(kPrintString), + getCharPtr(moduleOp->getContext()), + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintOpen, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), - false, true); +FailureOr +mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintOpen, {}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintClose, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), - false, true); +FailureOr +mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintClose, {}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintComma, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), - false, true); +FailureOr +mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintComma, {}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintNewline, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), - false, true); +FailureOr +mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintNewline, {}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, - Type indexType) { - return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType, - getVoidPtr(moduleOp->getContext()), false, - true); +FailureOr +mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) { + return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType, + getVoidPtr(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, - Type indexType) { - return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType}, - getVoidPtr(moduleOp->getContext()), false, - true); +FailureOr +mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) { + return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc, + {indexType, indexType}, + getVoidPtr(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { - return LLVM::lookupOrCreateFn( +FailureOr +mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( moduleOp, kFree, getVoidPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, - Type indexType) { - return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType, - getVoidPtr(moduleOp->getContext()), false, - true); +FailureOr +mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) { + return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType, + getVoidPtr(moduleOp->getContext())); } -LLVM::LLVMFuncOp +FailureOr mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType) { - return LLVM::lookupOrCreateFn( - moduleOp, kGenericAlignedAlloc, {indexType, indexType}, - getVoidPtr(moduleOp->getContext()), false, true); + return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc, + {indexType, indexType}, + getVoidPtr(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) { - return LLVM::lookupOrCreateFn( +FailureOr +mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp +FailureOr mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, Type unrankedDescriptorType) { - return LLVM::lookupOrCreateFn( + return lookupOrCreateReservedFn( moduleOp, kMemRefCopy, ArrayRef{indexType, unrankedDescriptorType, unrankedDescriptorType}, - LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true); + LLVM::LLVMVoidType::get(moduleOp->getContext())); } From 9079caff8b7bfa80bef19b2d33bdd7e361a8eb66 Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Tue, 21 Jan 2025 18:16:52 +0800 Subject: [PATCH 5/9] [mlir] [Test] Moved & renamed test case --- mlir/test/Conversion/MemRefToLLVM/invalid.mlir | 7 +++++++ mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 ----------- 2 files changed, 7 insertions(+), 11 deletions(-) delete mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir diff --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir index 40dd75af1dd77..1e12b83a24b5a 100644 --- a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir @@ -2,6 +2,13 @@ // Since the error is at an unknown location, we use FileCheck instead of // -veri-y-diagnostics here +// CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func' is prohibited +llvm.func @malloc(i64) +func.func @redef_reserved() { + %alloc = memref.alloc() : memref<1024x64xf32, 1> + llvm.return +} + // CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions. // CHECK-LABEL: @bad_address_space func.func @bad_address_space(%a: memref<2xindex, "foo">) { diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir deleted file mode 100644 index f744e4f7635ea..0000000000000 --- a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s - -#map = affine_map<(d0) -> (d0 + 1)> -module { - // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func' is prohibited - llvm.func @malloc(i64) - func.func @issue_120950() { - %alloc = memref.alloc() : memref<1024x64xf32, 1> - llvm.return - } -} From 475409c4fac865f26c94b751a9b0dbfcc937b83a Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Tue, 21 Jan 2025 18:29:22 +0800 Subject: [PATCH 6/9] Reformat code --- .../Conversion/LLVMCommon/PrintCallHelper.h | 9 ++++--- .../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 24 ++++++++++--------- .../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 8 ++++--- mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 3 ++- .../Conversion/LLVMCommon/PrintCallHelper.cpp | 3 ++- .../MemRefToLLVM/AllocLikeConversion.cpp | 19 +++++++++------ .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 12 +++++----- 7 files changed, 44 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h index 5af86956c0ad9..33402301115b7 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h @@ -23,11 +23,10 @@ namespace LLVM { /// Generate IR that prints the given string to stdout. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. -LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, - StringRef symbolName, StringRef string, - const LLVMTypeConverter &typeConverter, - bool addNewline = true, - std::optional runtimeFunctionName = {}); +LogicalResult createPrintStrCall( + OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, + StringRef string, const LLVMTypeConverter &typeConverter, + bool addNewline = true, std::optional runtimeFunctionName = {}); } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 473a69019d239..05e9fe9d58859 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -49,24 +49,26 @@ FailureOr lookupOrCreatePrintOpenFn(Operation *moduleOp); FailureOr lookupOrCreatePrintCloseFn(Operation *moduleOp); FailureOr lookupOrCreatePrintCommaFn(Operation *moduleOp); FailureOr lookupOrCreatePrintNewlineFn(Operation *moduleOp); -FailureOr lookupOrCreateMallocFn(Operation *moduleOp, Type indexType); +FailureOr lookupOrCreateMallocFn(Operation *moduleOp, + Type indexType); FailureOr lookupOrCreateAlignedAllocFn(Operation *moduleOp, - Type indexType); + Type indexType); FailureOr lookupOrCreateFreeFn(Operation *moduleOp); FailureOr lookupOrCreateGenericAllocFn(Operation *moduleOp, - Type indexType); -FailureOr lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, - Type indexType); + Type indexType); +FailureOr +lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType); FailureOr lookupOrCreateGenericFreeFn(Operation *moduleOp); -FailureOr lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, - Type unrankedDescriptorType); +FailureOr +lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, + Type unrankedDescriptorType); /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`. /// Return a failure if the FuncOp found has unexpected signature. -FailureOr lookupOrCreateFn(Operation *moduleOp, StringRef name, - ArrayRef paramTypes = {}, - Type resultType = {}, bool isVarArg = false, - bool isReserved = false); +FailureOr +lookupOrCreateFn(Operation *moduleOp, StringRef name, + ArrayRef paramTypes = {}, Type resultType = {}, + bool isVarArg = false, bool isReserved = false); } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index cdcb613e04ab1..f2fc235fecb28 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -61,9 +61,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { // Failed block: Generate IR to print the message and call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(), - *getTypeConverter(), /*addNewLine=*/false, - /*runtimeFunctionName=*/"puts").failed()) { + if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", + op.getMsg(), *getTypeConverter(), + /*addNewLine=*/false, + /*runtimeFunctionName=*/"puts") + .failed()) { return failure(); } if (abortOnFailedAssert) { diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 10f72cda7706d..840bd3df61a06 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -299,7 +299,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Allocate memory, copy, and free the source if necessary. Value memory = toDynamic - ? builder.create(loc, mallocFunc.value(), allocationSize) + ? builder + .create(loc, mallocFunc.value(), allocationSize) .getResult() : builder.create(loc, getVoidPtrType(), IntegerType::get(getContext(), 8), diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 607e1d6504552..381e2ffea8eb2 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -60,7 +60,8 @@ LogicalResult mlir::LLVM::createPrintStrCall( Value gep = builder.create(loc, ptrTy, arrayTy, msgAddr, indices); if (auto printer = - LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); succeeded(printer)) { + LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); + succeeded(printer)) { builder.create(loc, TypeRange(), SymbolRefAttr::get(printer.value()), gep); } else { diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp index 0ee92722157f3..1712d0b5844b8 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -15,8 +15,9 @@ using namespace mlir; namespace { -FailureOr getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, - Operation *module, Type indexType) { +FailureOr +getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, + Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) return LLVM::lookupOrCreateGenericAllocFn(module, indexType); @@ -24,8 +25,9 @@ FailureOr getNotalignedAllocFn(const LLVMTypeConverter *typeCo return LLVM::lookupOrCreateMallocFn(module, indexType); } -FailureOr getAlignedAllocFn(const LLVMTypeConverter *typeConverter, - Operation *module, Type indexType) { +FailureOr +getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, + Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) @@ -83,8 +85,10 @@ std::tuple AllocationOpLLVMLowering::allocateBufferManuallyAlign( FailureOr allocFuncOp = getNotalignedAllocFn( getTypeConverter(), op->getParentWithTrait(), getIndexType()); - if (failed(allocFuncOp)) return std::make_tuple(Value(), Value()); - auto results = rewriter.create(loc, allocFuncOp.value(), sizeBytes); + if (failed(allocFuncOp)) + return std::make_tuple(Value(), Value()); + auto results = + rewriter.create(loc, allocFuncOp.value(), sizeBytes); Value allocatedPtr = castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, @@ -150,7 +154,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign( FailureOr allocFuncOp = getAlignedAllocFn( getTypeConverter(), op->getParentWithTrait(), getIndexType()); - if (failed(allocFuncOp)) return Value(); + if (failed(allocFuncOp)) + return Value(); auto results = rewriter.create( loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index c2c87bc7544bd..9df5c4554c236 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -180,14 +180,14 @@ mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { FailureOr mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) { return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType, - getVoidPtr(moduleOp->getContext())); + getVoidPtr(moduleOp->getContext())); } FailureOr mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) { return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc, - {indexType, indexType}, - getVoidPtr(moduleOp->getContext())); + {indexType, indexType}, + getVoidPtr(moduleOp->getContext())); } FailureOr @@ -200,15 +200,15 @@ mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { FailureOr mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) { return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType, - getVoidPtr(moduleOp->getContext())); + getVoidPtr(moduleOp->getContext())); } FailureOr mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType) { return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc, - {indexType, indexType}, - getVoidPtr(moduleOp->getContext())); + {indexType, indexType}, + getVoidPtr(moduleOp->getContext())); } FailureOr From 874a1cd68dd800c0b8a5f5263a24a7a36d41eae5 Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Sun, 26 Jan 2025 10:47:13 +0800 Subject: [PATCH 7/9] Make stylish fixes --- .../Conversion/LLVMCommon/PrintCallHelper.cpp | 10 +++--- .../MemRefToLLVM/AllocLikeConversion.cpp | 7 ++-- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 8 ++--- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 34 +++++++++---------- .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 22 +++++------- 5 files changed, 34 insertions(+), 47 deletions(-) diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 381e2ffea8eb2..deabb748b5652 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -59,13 +59,11 @@ LogicalResult mlir::LLVM::createPrintStrCall( SmallVector indices(1, 0); Value gep = builder.create(loc, ptrTy, arrayTy, msgAddr, indices); - if (auto printer = + FailureOr printer = LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); - succeeded(printer)) { - builder.create(loc, TypeRange(), - SymbolRefAttr::get(printer.value()), gep); - } else { + if(failed(printer)) return failure(); - } + builder.create(loc, TypeRange(), + SymbolRefAttr::get(printer.value()), gep); return success(); } diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp index 1712d0b5844b8..c5b2e83df93dc 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -14,8 +14,7 @@ using namespace mlir; -namespace { -FailureOr +static FailureOr getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; @@ -25,7 +24,7 @@ getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, return LLVM::lookupOrCreateMallocFn(module, indexType); } -FailureOr +static FailureOr getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; @@ -36,8 +35,6 @@ getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, return LLVM::lookupOrCreateAlignedAllocFn(module, indexType); } -} // end namespace - Value AllocationOpLLVMLowering::createAligned( ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment) { diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index ac27e0dd09bdc..af1dba4587dc1 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -38,12 +38,12 @@ using namespace mlir; namespace { -bool isStaticStrideOrOffset(int64_t strideOrOffset) { +static bool isStaticStrideOrOffset(int64_t strideOrOffset) { return !ShapedType::isDynamic(strideOrOffset); } -FailureOr getFreeFn(const LLVMTypeConverter *typeConverter, - ModuleOp module) { +static FailureOr +getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) @@ -220,7 +220,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Insert the `free` declaration if it is not already present. - auto freeFunc = + FailureOr freeFunc = getFreeFn(getTypeConverter(), op->getParentOfType()); if (failed(freeFunc)) return failure(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 79617506008fa..258374f71c7d5 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1553,25 +1553,23 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { return failure(); } } else if (punct != PrintPunctuation::NoPunctuation) { - if (auto op = [&] -> FailureOr { - switch (punct) { - case PrintPunctuation::Close: - return LLVM::lookupOrCreatePrintCloseFn(parent); - case PrintPunctuation::Open: - return LLVM::lookupOrCreatePrintOpenFn(parent); - case PrintPunctuation::Comma: - return LLVM::lookupOrCreatePrintCommaFn(parent); - case PrintPunctuation::NewLine: - return LLVM::lookupOrCreatePrintNewlineFn(parent); - default: - llvm_unreachable("unexpected punctuation"); - } - }(); - succeeded(op)) - emitCall(rewriter, printOp->getLoc(), op.value()); - else { + FailureOr op = [&]() { + switch (punct) { + case PrintPunctuation::Close: + return LLVM::lookupOrCreatePrintCloseFn(parent); + case PrintPunctuation::Open: + return LLVM::lookupOrCreatePrintOpenFn(parent); + case PrintPunctuation::Comma: + return LLVM::lookupOrCreatePrintCommaFn(parent); + case PrintPunctuation::NewLine: + return LLVM::lookupOrCreatePrintNewlineFn(parent); + default: + llvm_unreachable("unexpected punctuation"); + } + }(); + if (failed(op)) return failure(); - } + emitCall(rewriter, printOp->getLoc(), op.value()); } rewriter.eraseOp(printOp); diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 9df5c4554c236..68d4426e65301 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -58,15 +58,12 @@ mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, if (func) { if (funcT != func.getFunctionType()) { if (isReserved) { - func.emitError("redefinition of reserved function '" + name + - "' of different type ") - .append(func.getFunctionType()) - .append(" is prohibited"); + func.emitError("redefinition of reserved function '") + << name << "' of different type " << func.getFunctionType() + << " is prohibited"; } else { - func.emitError("redefinition of function '" + name + - "' of different type ") - .append(funcT) - .append(" is prohibited"); + func.emitError("redefinition of function '") + << name << "' of different type " << funcT << " is prohibited"; } return failure(); } @@ -78,15 +75,12 @@ mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); } -namespace { -FailureOr lookupOrCreateReservedFn(Operation *moduleOp, - StringRef name, - ArrayRef paramTypes, - Type resultType) { +static FailureOr +lookupOrCreateReservedFn(Operation *moduleOp, StringRef name, + ArrayRef paramTypes, Type resultType) { return lookupOrCreateFn(moduleOp, name, paramTypes, resultType, /*isVarArg=*/false, /*isReserved=*/true); } -} // namespace FailureOr mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { From 2d7dc5d888146a6fddb9c717bacd975dcf008553 Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Sun, 26 Jan 2025 15:22:35 +0800 Subject: [PATCH 8/9] More stylish fixes --- .../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 12 ++++++------ .../Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 9 +++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index f2fc235fecb28..debfd003bd5b5 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -61,13 +61,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { // Failed block: Generate IR to print the message and call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", - op.getMsg(), *getTypeConverter(), - /*addNewLine=*/false, - /*runtimeFunctionName=*/"puts") - .failed()) { + auto createResult = LLVM::createPrintStrCall( + rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(), + /*addNewLine=*/false, + /*runtimeFunctionName=*/"puts"); + if (createResult.failed()) return failure(); - } + if (abortOnFailedAssert) { // Insert the `abort` declaration if necessary. auto abortFunc = module.lookupSymbol("abort"); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 258374f71c7d5..baed98c13adc7 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1546,12 +1546,13 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { auto punct = printOp.getPunctuation(); if (auto stringLiteral = printOp.getStringLiteral()) { - if (LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", + auto createResult = + LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", *stringLiteral, *getTypeConverter(), - /*addNewline=*/false) - .failed()) { + /*addNewline=*/false); + if (createResult.failed()) return failure(); - } + } else if (punct != PrintPunctuation::NoPunctuation) { FailureOr op = [&]() { switch (punct) { From a7f3308d74e5f4b4ce3bd7e5a4f35cb92f4f633d Mon Sep 17 00:00:00 2001 From: Luohao Wang Date: Sun, 26 Jan 2025 15:25:44 +0800 Subject: [PATCH 9/9] Format code --- mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index deabb748b5652..337c01f01a7cc 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -60,10 +60,10 @@ LogicalResult mlir::LLVM::createPrintStrCall( Value gep = builder.create(loc, ptrTy, arrayTy, msgAddr, indices); FailureOr printer = - LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); - if(failed(printer)) + LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); + if (failed(printer)) return failure(); builder.create(loc, TypeRange(), - SymbolRefAttr::get(printer.value()), gep); + SymbolRefAttr::get(printer.value()), gep); return success(); }