@@ -3564,6 +3564,26 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
35643564 return success ();
35653565}
35663566
3567+ static LLVM::GlobalOp
3568+ getGlobalFromSymbol (Operation *symOp,
3569+ LLVM::ModuleTranslation &moduleTranslation,
3570+ Operation *opInst) {
3571+
3572+ // Handle potential address space cast
3573+ if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3574+ symOp = asCast.getOperand ().getDefiningOp ();
3575+
3576+ // Check if we have an AddressOfOp
3577+ if (!isa<LLVM::AddressOfOp>(symOp)) {
3578+ if (opInst)
3579+ opInst->emitError (" Addressing symbol not found" );
3580+ return nullptr ;
3581+ }
3582+
3583+ LLVM::AddressOfOp addressOfOp = cast<LLVM::AddressOfOp>(symOp);
3584+ return addressOfOp.getGlobal (moduleTranslation.symbolTable ());
3585+ }
3586+
35673587// / Converts an OpenMP Threadprivate operation into LLVM IR using
35683588// / OpenMPIRBuilder.
35693589static LogicalResult
@@ -3579,15 +3599,10 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
35793599 Value symAddr = threadprivateOp.getSymAddr ();
35803600 auto *symOp = symAddr.getDefiningOp ();
35813601
3582- if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3583- symOp = asCast.getOperand ().getDefiningOp ();
3584-
3585- if (!isa<LLVM::AddressOfOp>(symOp))
3586- return opInst.emitError (" Addressing symbol not found" );
3587- LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3588-
35893602 LLVM::GlobalOp global =
3590- addressOfOp.getGlobal (moduleTranslation.symbolTable ());
3603+ getGlobalFromSymbol (symOp, moduleTranslation, &opInst);
3604+ if (!global)
3605+ return failure ();
35913606 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal (global);
35923607
35933608 if (!ompBuilder->Config .isTargetDevice ()) {
@@ -6161,17 +6176,13 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61616176 }
61626177
61636178 Value symAddr = groupprivateOp.getSymAddr ();
6164- auto *symOp = symAddr.getDefiningOp ();
6165-
6166- if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
6167- symOp = asCast.getOperand ().getDefiningOp ();
6168-
6169- if (!isa<LLVM::AddressOfOp>(symOp))
6170- return opInst.emitError (" Addressing symbol not found" );
6171- LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
6179+ Operation *symOp = symAddr.getDefiningOp ();
61726180
61736181 LLVM::GlobalOp global =
6174- addressOfOp.getGlobal (moduleTranslation.symbolTable ());
6182+ getGlobalFromSymbol (symOp, moduleTranslation, &opInst);
6183+ if (!global)
6184+ return failure ();
6185+
61756186 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal (global);
61766187 llvm::Value *resultPtr;
61776188
@@ -6192,6 +6203,11 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61926203 if (!resultPtr)
61936204 resultPtr = globalValue;
61946205 }
6206+
6207+ llvm::Type *ptrTy = builder.getPtrTy ();
6208+ if (resultPtr->getType () != ptrTy)
6209+ resultPtr = builder.CreateBitCast (resultPtr, ptrTy);
6210+
61956211 moduleTranslation.mapValue (opInst.getResult (0 ), resultPtr);
61966212 return success ();
61976213}
0 commit comments