@@ -3564,6 +3564,26 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
35643564 return success ();
35653565}
35663566
3567+ static LLVM::GlobalOp
3568+ getGlobalFromSymbol (Operation *symOp,
3569+ LLVM::ModuleTranslation &moduleTranslation,
3570+ Operation *opInst) {
3571+
3572+ // Handle potential address space cast
3573+ if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3574+ symOp = asCast.getOperand ().getDefiningOp ();
3575+
3576+ // Check if we have an AddressOfOp
3577+ if (!isa<LLVM::AddressOfOp>(symOp)) {
3578+ if (opInst)
3579+ opInst->emitError (" Addressing symbol not found" );
3580+ return nullptr ;
3581+ }
3582+
3583+ LLVM::AddressOfOp addressOfOp = cast<LLVM::AddressOfOp>(symOp);
3584+ return addressOfOp.getGlobal (moduleTranslation.symbolTable ());
3585+ }
3586+
35673587// / Converts an OpenMP Threadprivate operation into LLVM IR using
35683588// / OpenMPIRBuilder.
35693589static LogicalResult
@@ -3579,15 +3599,10 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
35793599 Value symAddr = threadprivateOp.getSymAddr ();
35803600 auto *symOp = symAddr.getDefiningOp ();
35813601
3582- if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3583- symOp = asCast.getOperand ().getDefiningOp ();
3584-
3585- if (!isa<LLVM::AddressOfOp>(symOp))
3586- return opInst.emitError (" Addressing symbol not found" );
3587- LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3588-
35893602 LLVM::GlobalOp global =
3590- addressOfOp.getGlobal (moduleTranslation.symbolTable ());
3603+ getGlobalFromSymbol (symOp, moduleTranslation, &opInst);
3604+ if (!global)
3605+ return failure ();
35913606 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal (global);
35923607
35933608 if (!ompBuilder->Config .isTargetDevice ()) {
@@ -6144,17 +6159,13 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61446159 }
61456160
61466161 Value symAddr = groupprivateOp.getSymAddr ();
6147- auto *symOp = symAddr.getDefiningOp ();
6148-
6149- if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
6150- symOp = asCast.getOperand ().getDefiningOp ();
6151-
6152- if (!isa<LLVM::AddressOfOp>(symOp))
6153- return opInst.emitError (" Addressing symbol not found" );
6154- LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
6162+ Operation *symOp = symAddr.getDefiningOp ();
61556163
61566164 LLVM::GlobalOp global =
6157- addressOfOp.getGlobal (moduleTranslation.symbolTable ());
6165+ getGlobalFromSymbol (symOp, moduleTranslation, &opInst);
6166+ if (!global)
6167+ return failure ();
6168+
61586169 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal (global);
61596170 llvm::Value *resultPtr;
61606171
@@ -6175,6 +6186,11 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61756186 if (!resultPtr)
61766187 resultPtr = globalValue;
61776188 }
6189+
6190+ llvm::Type *ptrTy = builder.getPtrTy ();
6191+ if (resultPtr->getType () != ptrTy)
6192+ resultPtr = builder.CreateBitCast (resultPtr, ptrTy);
6193+
61786194 moduleTranslation.mapValue (opInst.getResult (0 ), resultPtr);
61796195 return success ();
61806196}
0 commit comments