@@ -134,16 +134,65 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
134134}
135135
136136namespace {
137+
138+ // Creates an existing operation with an AddressOfOp or an AddrSpaceCastOp
139+ // depending on the existing address spaces of the type.
140+ mlir::Value createAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
141+ mlir::Location loc, std::uint64_t globalAS,
142+ std::uint64_t programAS,
143+ llvm::StringRef symName, mlir::Type type) {
144+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
145+ if (globalAS != programAS) {
146+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
147+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
148+ return rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
149+ loc, getLlvmPtrType (rewriter.getContext (), programAS), llvmAddrOp);
150+ }
151+ return rewriter.create <mlir::LLVM::AddressOfOp>(
152+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
153+ }
154+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, type, symName);
155+ }
156+
157+ // Replaces an existing operation with an AddressOfOp or an AddrSpaceCastOp
158+ // depending on the existing address spaces of the type.
159+ mlir::Value replaceWithAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
160+ mlir::Location loc,
161+ std::uint64_t globalAS,
162+ std::uint64_t programAS,
163+ llvm::StringRef symName, mlir::Type type,
164+ mlir::Operation *replaceOp) {
165+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
166+ if (globalAS != programAS) {
167+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
168+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
169+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddrSpaceCastOp>(
170+ replaceOp, ::getLlvmPtrType (rewriter.getContext (), programAS),
171+ llvmAddrOp);
172+ }
173+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
174+ replaceOp, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
175+ }
176+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(replaceOp, type,
177+ symName);
178+ }
179+
137180// / Lower `fir.address_of` operation to `llvm.address_of` operation.
138181struct AddrOfOpConversion : public fir ::FIROpConversion<fir::AddrOfOp> {
139182 using FIROpConversion::FIROpConversion;
140183
141184 llvm::LogicalResult
142185 matchAndRewrite (fir::AddrOfOp addr, OpAdaptor adaptor,
143186 mlir::ConversionPatternRewriter &rewriter) const override {
144- auto ty = convertType (addr.getType ());
145- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
146- addr, ty, addr.getSymbol ().getRootReference ().getValue ());
187+ auto global = addr->getParentOfType <mlir::ModuleOp>()
188+ .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
189+ replaceWithAddrOfOrASCast (
190+ rewriter, addr->getLoc (),
191+ global ? global.getAddrSpace () : getGlobalAddressSpace (rewriter),
192+ getProgramAddressSpace (rewriter),
193+ global ? global.getSymName ()
194+ : addr.getSymbol ().getRootReference ().getValue (),
195+ convertType (addr.getType ()), addr);
147196 return mlir::success ();
148197 }
149198};
@@ -1350,14 +1399,34 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
13501399 ? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
13511400 : fir::NameUniquer::getTypeDescriptorName (recType.getName ());
13521401 mlir::Type llvmPtrTy = ::getLlvmPtrType (mod.getContext ());
1402+
1403+ // We allow the module to be set to a default layout if it's a regular module
1404+ // however, we prevent this if it's a GPU module, as the datalayout in these
1405+ // cases will currently be the union of the GPU Module and the parent builtin
1406+ // module, with the GPU module overriding the parent where there are duplicates.
1407+ // However, if we force the default layout onto a GPU module, with no datalayout
1408+ // it'll result in issues as the infrastructure does not support the union of
1409+ // two layouts with builtin data layout entries currently (and it doesn't look
1410+ // like it was intended to).
1411+ std::optional<mlir::DataLayout> dataLayout =
1412+ fir::support::getOrSetMLIRDataLayout (
1413+ mod, /* allowDefaultLayout*/ mlir::isa<mlir::gpu::GPUModuleOp>(mod)
1414+ ? false
1415+ : true );
1416+ assert (dataLayout.has_value () && " Module missing DataLayout information" );
1417+
13531418 if (auto global = mod.template lookupSymbol <fir::GlobalOp>(name)) {
1354- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1355- global.getSymName ());
1419+ return createAddrOfOrASCast (
1420+ rewriter, loc, fir::factory::getGlobalAddressSpace (&*dataLayout),
1421+ fir::factory::getProgramAddressSpace (&*dataLayout),
1422+ global.getSymName (), llvmPtrTy);
13561423 }
13571424 if (auto global = mod.template lookupSymbol <mlir::LLVM::GlobalOp>(name)) {
13581425 // The global may have already been translated to LLVM.
1359- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1360- global.getSymName ());
1426+ return createAddrOfOrASCast (
1427+ rewriter, loc, global.getAddrSpace (),
1428+ fir::factory::getProgramAddressSpace (&*dataLayout),
1429+ global.getSymName (), llvmPtrTy);
13611430 }
13621431 // Type info derived types do not have type descriptors since they are the
13631432 // types defining type descriptors.
@@ -2896,12 +2965,16 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
28962965 : fir::NameUniquer::getTypeDescriptorName (recordType.getName ());
28972966 auto llvmPtrTy = ::getLlvmPtrType (typeDescOp.getContext ());
28982967 if (auto global = module .lookupSymbol <mlir::LLVM::GlobalOp>(typeDescName)) {
2899- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2900- typeDescOp, llvmPtrTy, global.getSymName ());
2968+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2969+ global.getAddrSpace (),
2970+ getProgramAddressSpace (rewriter),
2971+ global.getSymName (), llvmPtrTy, typeDescOp);
29012972 return mlir::success ();
29022973 } else if (auto global = module .lookupSymbol <fir::GlobalOp>(typeDescName)) {
2903- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2904- typeDescOp, llvmPtrTy, global.getSymName ());
2974+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2975+ getGlobalAddressSpace (rewriter),
2976+ getProgramAddressSpace (rewriter),
2977+ global.getSymName (), llvmPtrTy, typeDescOp);
29052978 return mlir::success ();
29062979 }
29072980 return mlir::failure ();
@@ -2992,8 +3065,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
29923065 mlir::SymbolRefAttr comdat;
29933066 llvm::ArrayRef<mlir::NamedAttribute> attrs;
29943067 auto g = rewriter.create <mlir::LLVM::GlobalOp>(
2995- loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 , 0 ,
2996- false , false , comdat, attrs, dbgExprs);
3068+ loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 ,
3069+ getGlobalAddressSpace (rewriter), false , false , comdat, attrs, dbgExprs);
29973070
29983071 if (global.getAlignment () && *global.getAlignment () > 0 )
29993072 g.setAlignment (*global.getAlignment ());
0 commit comments