@@ -134,16 +134,65 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
134134}
135135
136136namespace {
137+
138+ // Creates an existing operation with an AddressOfOp or an AddrSpaceCastOp
139+ // depending on the existing address spaces of the type.
140+ mlir::Value createAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
141+ mlir::Location loc, std::uint64_t globalAS,
142+ std::uint64_t programAS,
143+ llvm::StringRef symName, mlir::Type type) {
144+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
145+ if (globalAS != programAS) {
146+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
147+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
148+ return rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
149+ loc, getLlvmPtrType (rewriter.getContext (), programAS), llvmAddrOp);
150+ }
151+ return rewriter.create <mlir::LLVM::AddressOfOp>(
152+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
153+ }
154+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, type, symName);
155+ }
156+
157+ // Replaces an existing operation with an AddressOfOp or an AddrSpaceCastOp
158+ // depending on the existing address spaces of the type.
159+ mlir::Value replaceWithAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
160+ mlir::Location loc,
161+ std::uint64_t globalAS,
162+ std::uint64_t programAS,
163+ llvm::StringRef symName, mlir::Type type,
164+ mlir::Operation *replaceOp) {
165+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
166+ if (globalAS != programAS) {
167+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
168+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
169+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddrSpaceCastOp>(
170+ replaceOp, ::getLlvmPtrType (rewriter.getContext (), programAS),
171+ llvmAddrOp);
172+ }
173+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
174+ replaceOp, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
175+ }
176+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(replaceOp, type,
177+ symName);
178+ }
179+
137180// / Lower `fir.address_of` operation to `llvm.address_of` operation.
138181struct AddrOfOpConversion : public fir ::FIROpConversion<fir::AddrOfOp> {
139182 using FIROpConversion::FIROpConversion;
140183
141184 llvm::LogicalResult
142185 matchAndRewrite (fir::AddrOfOp addr, OpAdaptor adaptor,
143186 mlir::ConversionPatternRewriter &rewriter) const override {
144- auto ty = convertType (addr.getType ());
145- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
146- addr, ty, addr.getSymbol ().getRootReference ().getValue ());
187+ auto global = addr->getParentOfType <mlir::ModuleOp>()
188+ .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
189+ replaceWithAddrOfOrASCast (
190+ rewriter, addr->getLoc (),
191+ global ? global.getAddrSpace () : getGlobalAddressSpace (rewriter),
192+ getProgramAddressSpace (rewriter),
193+ global ? global.getSymName ()
194+ : addr.getSymbol ().getRootReference ().getValue (),
195+ convertType (addr.getType ()), addr);
147196 return mlir::success ();
148197 }
149198};
@@ -1350,14 +1399,26 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
13501399 ? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
13511400 : fir::NameUniquer::getTypeDescriptorName (recType.getName ());
13521401 mlir::Type llvmPtrTy = ::getLlvmPtrType (mod.getContext ());
1402+
1403+ // As we set allowDefaultLayout to true, there should be no chance the
1404+ // optional returns null even if the module has no layout information,
1405+ // however, assert just incase.
1406+ std::optional<mlir::DataLayout> dataLayout =
1407+ fir::support::getOrSetDataLayout (mod, /* allowDefaultLayout=*/ true );
1408+ assert (!dataLayout.has_value ());
1409+
13531410 if (auto global = mod.template lookupSymbol <fir::GlobalOp>(name)) {
1354- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1355- global.getSymName ());
1411+ return createAddrOfOrASCast (
1412+ rewriter, loc, fir::factory::getGlobalAddressSpace (&*dataLayout),
1413+ fir::factory::getProgramAddressSpace (&*dataLayout),
1414+ global.getSymName (), llvmPtrTy);
13561415 }
13571416 if (auto global = mod.template lookupSymbol <mlir::LLVM::GlobalOp>(name)) {
13581417 // The global may have already been translated to LLVM.
1359- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1360- global.getSymName ());
1418+ return createAddrOfOrASCast (
1419+ rewriter, loc, global.getAddrSpace (),
1420+ fir::factory::getProgramAddressSpace (&*dataLayout),
1421+ global.getSymName (), llvmPtrTy);
13611422 }
13621423 // Type info derived types do not have type descriptors since they are the
13631424 // types defining type descriptors.
@@ -2896,12 +2957,16 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
28962957 : fir::NameUniquer::getTypeDescriptorName (recordType.getName ());
28972958 auto llvmPtrTy = ::getLlvmPtrType (typeDescOp.getContext ());
28982959 if (auto global = module .lookupSymbol <mlir::LLVM::GlobalOp>(typeDescName)) {
2899- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2900- typeDescOp, llvmPtrTy, global.getSymName ());
2960+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2961+ global.getAddrSpace (),
2962+ getProgramAddressSpace (rewriter),
2963+ global.getSymName (), llvmPtrTy, typeDescOp);
29012964 return mlir::success ();
29022965 } else if (auto global = module .lookupSymbol <fir::GlobalOp>(typeDescName)) {
2903- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2904- typeDescOp, llvmPtrTy, global.getSymName ());
2966+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2967+ getGlobalAddressSpace (rewriter),
2968+ getProgramAddressSpace (rewriter),
2969+ global.getSymName (), llvmPtrTy, typeDescOp);
29052970 return mlir::success ();
29062971 }
29072972 return mlir::failure ();
@@ -2992,8 +3057,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
29923057 mlir::SymbolRefAttr comdat;
29933058 llvm::ArrayRef<mlir::NamedAttribute> attrs;
29943059 auto g = rewriter.create <mlir::LLVM::GlobalOp>(
2995- loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 , 0 ,
2996- false , false , comdat, attrs, dbgExprs);
3060+ loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 ,
3061+ getGlobalAddressSpace (rewriter), false , false , comdat, attrs, dbgExprs);
29973062
29983063 if (global.getAlignment () && *global.getAlignment () > 0 )
29993064 g.setAlignment (*global.getAlignment ());
0 commit comments