@@ -982,7 +982,7 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
982982template <typename ModuleOp>
983983static mlir::SymbolRefAttr
984984getMallocInModule (ModuleOp mod, fir::AllocMemOp op,
985- mlir::ConversionPatternRewriter &rewriter) {
985+ mlir::ConversionPatternRewriter &rewriter, bool addr32 ) {
986986 static constexpr char mallocName[] = " malloc" ;
987987 if (auto mallocFunc =
988988 mod.template lookupSymbol <mlir::LLVM::LLVMFuncOp>(mallocName))
@@ -992,7 +992,7 @@ getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
992992 return mlir::SymbolRefAttr::get (userMalloc);
993993
994994 mlir::OpBuilder moduleBuilder (mod.getBodyRegion ());
995- auto indexType = mlir::IntegerType::get (op.getContext (), 64 );
995+ auto indexType = mlir::IntegerType::get (op.getContext (), addr32 ? 32 : 64 );
996996 auto mallocDecl = moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
997997 op.getLoc (), mallocName,
998998 mlir::LLVM::LLVMFunctionType::get (getLlvmPtrType (op.getContext ()),
@@ -1002,12 +1002,13 @@ getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
10021002}
10031003
10041004// / Return the LLVMFuncOp corresponding to the standard malloc call.
1005- static mlir::SymbolRefAttr
1006- getMalloc (fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1005+ static mlir::SymbolRefAttr getMalloc (fir::AllocMemOp op,
1006+ mlir::ConversionPatternRewriter &rewriter,
1007+ bool addr32) {
10071008 if (auto mod = op->getParentOfType <mlir::gpu::GPUModuleOp>())
1008- return getMallocInModule (mod, op, rewriter);
1009+ return getMallocInModule (mod, op, rewriter, addr32 );
10091010 auto mod = op->getParentOfType <mlir::ModuleOp>();
1010- return getMallocInModule (mod, op, rewriter);
1011+ return getMallocInModule (mod, op, rewriter, addr32 );
10111012}
10121013
10131014// / Helper function for generating the LLVM IR that computes the distance
@@ -1057,6 +1058,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
10571058 mlir::Type heapTy = heap.getType ();
10581059 mlir::Location loc = heap.getLoc ();
10591060 auto ity = lowerTy ().indexType ();
1061+ auto addr32 = lowerTy ().getPointerBitwidth (0 ) == 32 ;
10601062 mlir::Type dataTy = fir::unwrapRefType (heapTy);
10611063 mlir::Type llvmObjectTy = convertObjectType (dataTy);
10621064 if (fir::isRecordWithTypeParameters (fir::unwrapSequenceType (dataTy)))
@@ -1067,7 +1069,11 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
10671069 for (mlir::Value opnd : adaptor.getOperands ())
10681070 size = rewriter.create <mlir::LLVM::MulOp>(
10691071 loc, ity, size, integerCast (loc, rewriter, ity, opnd));
1070- heap->setAttr (" callee" , getMalloc (heap, rewriter));
1072+ if (addr32) {
1073+ auto i32ty = mlir::IntegerType::get (rewriter.getContext (), 32 );
1074+ size = integerCast (loc, rewriter, i32ty, size);
1075+ }
1076+ heap->setAttr (" callee" , getMalloc (heap, rewriter, addr32));
10711077 rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
10721078 heap, ::getLlvmPtrType (heap.getContext ()), size,
10731079 addLLVMOpBundleAttrs (rewriter, heap->getAttrs (), 1 ));
0 commit comments