@@ -132,16 +132,54 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
132132}
133133
134134namespace {
135+
136+ mlir::Value replaceWithAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
137+ mlir::Location loc,
138+ std::uint64_t globalAS,
139+ std::uint64_t programAS,
140+ llvm::StringRef symName, mlir::Type type,
141+ mlir::Operation *replaceOp = nullptr ) {
142+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
143+ if (globalAS != programAS) {
144+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
145+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
146+ if (replaceOp)
147+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddrSpaceCastOp>(
148+ replaceOp, ::getLlvmPtrType (rewriter.getContext (), programAS),
149+ llvmAddrOp);
150+ return rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
151+ loc, getLlvmPtrType (rewriter.getContext (), programAS), llvmAddrOp);
152+ }
153+
154+ if (replaceOp)
155+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
156+ replaceOp, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
157+ return rewriter.create <mlir::LLVM::AddressOfOp>(
158+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
159+ }
160+
161+ if (replaceOp)
162+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(replaceOp, type,
163+ symName);
164+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, type, symName);
165+ }
166+
135167// / Lower `fir.address_of` operation to `llvm.address_of` operation.
136168struct AddrOfOpConversion : public fir ::FIROpConversion<fir::AddrOfOp> {
137169 using FIROpConversion::FIROpConversion;
138170
139171 llvm::LogicalResult
140172 matchAndRewrite (fir::AddrOfOp addr, OpAdaptor adaptor,
141173 mlir::ConversionPatternRewriter &rewriter) const override {
142- auto ty = convertType (addr.getType ());
143- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
144- addr, ty, addr.getSymbol ().getRootReference ().getValue ());
174+ auto global = addr->getParentOfType <mlir::ModuleOp>()
175+ .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
176+ replaceWithAddrOfOrASCast (
177+ rewriter, addr->getLoc (),
178+ global ? global.getAddrSpace () : getGlobalAddressSpace (rewriter),
179+ getProgramAddressSpace (rewriter),
180+ global ? global.getSymName ()
181+ : addr.getSymbol ().getRootReference ().getValue (),
182+ convertType (addr.getType ()), addr);
145183 return mlir::success ();
146184 }
147185};
@@ -1255,14 +1293,19 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
12551293 ? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
12561294 : fir::NameUniquer::getTypeDescriptorName (recType.getName ());
12571295 mlir::Type llvmPtrTy = ::getLlvmPtrType (mod.getContext ());
1296+ mlir::DataLayout dataLayout (mod);
12581297 if (auto global = mod.template lookupSymbol <fir::GlobalOp>(name)) {
1259- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1260- global.getSymName ());
1298+ return replaceWithAddrOfOrASCast (
1299+ rewriter, loc, fir::factory::getGlobalAddressSpace (&dataLayout),
1300+ fir::factory::getProgramAddressSpace (&dataLayout),
1301+ global.getSymName (), llvmPtrTy);
12611302 }
12621303 if (auto global = mod.template lookupSymbol <mlir::LLVM::GlobalOp>(name)) {
12631304 // The global may have already been translated to LLVM.
1264- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1265- global.getSymName ());
1305+ return replaceWithAddrOfOrASCast (
1306+ rewriter, loc, global.getAddrSpace (),
1307+ fir::factory::getProgramAddressSpace (&dataLayout),
1308+ global.getSymName (), llvmPtrTy);
12661309 }
12671310 // Type info derived types do not have type descriptors since they are the
12681311 // types defining type descriptors.
@@ -2759,12 +2802,16 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
27592802 : fir::NameUniquer::getTypeDescriptorName (recordType.getName ());
27602803 auto llvmPtrTy = ::getLlvmPtrType (typeDescOp.getContext ());
27612804 if (auto global = module .lookupSymbol <mlir::LLVM::GlobalOp>(typeDescName)) {
2762- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2763- typeDescOp, llvmPtrTy, global.getSymName ());
2805+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2806+ global.getAddrSpace (),
2807+ getProgramAddressSpace (rewriter),
2808+ global.getSymName (), llvmPtrTy, typeDescOp);
27642809 return mlir::success ();
27652810 } else if (auto global = module .lookupSymbol <fir::GlobalOp>(typeDescName)) {
2766- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2767- typeDescOp, llvmPtrTy, global.getSymName ());
2811+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2812+ getGlobalAddressSpace (rewriter),
2813+ getProgramAddressSpace (rewriter),
2814+ global.getSymName (), llvmPtrTy, typeDescOp);
27682815 return mlir::success ();
27692816 }
27702817 return mlir::failure ();
@@ -2855,8 +2902,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
28552902 mlir::SymbolRefAttr comdat;
28562903 llvm::ArrayRef<mlir::NamedAttribute> attrs;
28572904 auto g = rewriter.create <mlir::LLVM::GlobalOp>(
2858- loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 , 0 ,
2859- false , false , comdat, attrs, dbgExprs);
2905+ loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 ,
2906+ getGlobalAddressSpace (rewriter), false , false , comdat, attrs, dbgExprs);
28602907
28612908 if (global.getAlignment () && *global.getAlignment () > 0 )
28622909 g.setAlignment (*global.getAlignment ());
0 commit comments