diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 27c43e0daad07..c046ea1b824fc 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -273,7 +273,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, static void restoreByValRefArgumentType( ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, ArrayRef> byValRefNonPtrAttrs, - LLVM::LLVMFuncOp funcOp) { + ArrayRef oldBlockArgs, LLVM::LLVMFuncOp funcOp) { // Nothing to do for function declarations. if (funcOp.isExternal()) return; @@ -281,8 +281,8 @@ static void restoreByValRefArgumentType( ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front()); - for (const auto &[arg, byValRefAttr] : - llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) { + for (const auto &[arg, oldArg, byValRefAttr] : + llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) { // Skip argument if no `llvm.byval` or `llvm.byref` attribute. if (!byValRefAttr) continue; @@ -295,7 +295,7 @@ static void restoreByValRefArgumentType( cast(byValRefAttr->getValue()).getValue()); auto valueArg = rewriter.create(arg.getLoc(), resTy, arg); - rewriter.replaceAllUsesExcept(arg, valueArg, valueArg); + rewriter.replaceUsesOfBlockArgument(oldArg, valueArg); } } @@ -309,6 +309,10 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, return rewriter.notifyMatchFailure( funcOp, "Only support FunctionOpInterface with FunctionType"); + // Keep track of the entry block arguments. They will be needed later. + SmallVector oldBlockArgs = + llvm::to_vector(funcOp.getArguments()); + // Convert the original function arguments. They are converted using the // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp->getAttrOfType(varargsAttrName); @@ -438,7 +442,7 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, // pointee type in the function body when converting `llvm.byval`/`llvm.byref` // function arguments. restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs, - newFuncOp); + oldBlockArgs, newFuncOp); if (!shouldUseBarePtrCallConv(funcOp, &converter)) { if (funcOp->getAttrOfType(