@@ -137,16 +137,54 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
137137}
138138
139139namespace {
140+
141+ mlir::Value replaceWithAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
142+ mlir::Location loc,
143+ std::uint64_t globalAS,
144+ std::uint64_t programAS,
145+ llvm::StringRef symName, mlir::Type type,
146+ mlir::Operation *replaceOp = nullptr ) {
147+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
148+ if (globalAS != programAS) {
149+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
150+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
151+ if (replaceOp)
152+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddrSpaceCastOp>(
153+ replaceOp, ::getLlvmPtrType (rewriter.getContext (), programAS),
154+ llvmAddrOp);
155+ return rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
156+ loc, getLlvmPtrType (rewriter.getContext (), programAS), llvmAddrOp);
157+ }
158+
159+ if (replaceOp)
160+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
161+ replaceOp, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
162+ return rewriter.create <mlir::LLVM::AddressOfOp>(
163+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
164+ }
165+
166+ if (replaceOp)
167+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(replaceOp, type,
168+ symName);
169+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, type, symName);
170+ }
171+
140172// / Lower `fir.address_of` operation to `llvm.address_of` operation.
141173struct AddrOfOpConversion : public fir ::FIROpConversion<fir::AddrOfOp> {
142174 using FIROpConversion::FIROpConversion;
143175
144176 llvm::LogicalResult
145177 matchAndRewrite (fir::AddrOfOp addr, OpAdaptor adaptor,
146178 mlir::ConversionPatternRewriter &rewriter) const override {
147- auto ty = convertType (addr.getType ());
148- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
149- addr, ty, addr.getSymbol ().getRootReference ().getValue ());
179+ auto global = addr->getParentOfType <mlir::ModuleOp>()
180+ .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
181+ replaceWithAddrOfOrASCast (
182+ rewriter, addr->getLoc (),
183+ global ? global.getAddrSpace () : getGlobalAddressSpace (rewriter),
184+ getProgramAddressSpace (rewriter),
185+ global ? global.getSymName ()
186+ : addr.getSymbol ().getRootReference ().getValue (),
187+ convertType (addr.getType ()), addr);
150188 return mlir::success ();
151189 }
152190};
@@ -1306,13 +1344,18 @@ getTypeDescriptor(ModOpTy mod, mlir::ConversionPatternRewriter &rewriter,
13061344 ? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
13071345 : fir::NameUniquer::getTypeDescriptorName (recType.getName ());
13081346 mlir::Type llvmPtrTy = ::getLlvmPtrType (mod.getContext ());
1347+ mlir::DataLayout dataLayout (mod);
13091348 if (auto global = mod.template lookupSymbol <fir::GlobalOp>(name))
1310- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1311- global.getSymName ());
1349+ return replaceWithAddrOfOrASCast (
1350+ rewriter, loc, fir::factory::getGlobalAddressSpace (&dataLayout),
1351+ fir::factory::getProgramAddressSpace (&dataLayout), global.getSymName (),
1352+ llvmPtrTy);
13121353 // The global may have already been translated to LLVM.
13131354 if (auto global = mod.template lookupSymbol <mlir::LLVM::GlobalOp>(name))
1314- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1315- global.getSymName ());
1355+ return replaceWithAddrOfOrASCast (
1356+ rewriter, loc, global.getAddrSpace (),
1357+ fir::factory::getProgramAddressSpace (&dataLayout), global.getSymName (),
1358+ llvmPtrTy);
13161359 // Type info derived types do not have type descriptors since they are the
13171360 // types defining type descriptors.
13181361 if (options.ignoreMissingTypeDescriptors ||
@@ -3130,8 +3173,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
31303173 mlir::SymbolRefAttr comdat;
31313174 llvm::ArrayRef<mlir::NamedAttribute> attrs;
31323175 auto g = rewriter.create <mlir::LLVM::GlobalOp>(
3133- loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 , 0 ,
3134- false , false , comdat, attrs, dbgExprs);
3176+ loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 ,
3177+ getGlobalAddressSpace (rewriter), false , false , comdat, attrs, dbgExprs);
31353178
31363179 if (global.getAlignment () && *global.getAlignment () > 0 )
31373180 g.setAlignment (*global.getAlignment ());
0 commit comments