@@ -5994,7 +5994,7 @@ static bool isTargetDeviceOp(Operation *op) {
59945994 // by taking it in as an operand, so we must always lower these in
59955995 // some manner or result in an ICE (whether they end up in a no-op
59965996 // or otherwise).
5997- if (mlir::isa<omp::ThreadprivateOp>(op))
5997+ if (mlir::isa<omp::ThreadprivateOp, omp::GroupprivateOp >(op))
59985998 return true ;
59995999
60006000 if (mlir::isa<omp::TargetAllocMemOp>(op) ||
@@ -6095,8 +6095,7 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
60956095// / Converts an OpenMP Groupprivate operation into LLVM IR.
60966096static LogicalResult
60976097convertOmpGroupprivate (Operation &opInst, llvm::IRBuilderBase &builder,
6098- LLVM::ModuleTranslation &moduleTranslation) {
6099- llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
6098+ LLVM::ModuleTranslation &moduleTranslation) {
61006099 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
61016100 auto groupprivateOp = cast<omp::GroupprivateOp>(opInst);
61026101
@@ -6117,12 +6116,20 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61176116 addressOfOp.getGlobal (moduleTranslation.symbolTable ());
61186117 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal (global);
61196118
6120- if (!ompBuilder->Config .isTargetDevice ()) {
6121- llvm_unreachable (" NYI" );
6122- } else {
6123- moduleTranslation.mapValue (opInst.getResult (0 ), globalValue);
6124- }
6125-
6119+ // Get the size of the variable
6120+ llvm::Type *varType = globalValue->getValueType ();
6121+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule ();
6122+ llvm::DataLayout DL = llvmModule->getDataLayout ();
6123+ uint64_t typeSize = DL.getTypeAllocSize (varType);
6124+ // Call omp_alloc_shared to allocate memory for groupprivate variable.
6125+ llvm::FunctionCallee allocSharedFn = ompBuilder->getOrCreateRuntimeFunction (
6126+ *llvmModule, llvm::omp::OMPRTL___kmpc_alloc_shared);
6127+ // Call runtime to allocate shared memory for this group
6128+ llvm::Value *groupPrivatePtr =
6129+ builder.CreateCall (allocSharedFn, {builder.getInt64 (typeSize)});
6130+ groupPrivatePtr =
6131+ builder.CreateBitCast (groupPrivatePtr, globalValue->getType ());
6132+ moduleTranslation.mapValue (opInst.getResult (0 ), groupPrivatePtr);
61266133 return success ();
61276134}
61286135
0 commit comments