@@ -6137,6 +6137,28 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61376137
61386138 if (failed (checkImplementationStatus (opInst)))
61396139 return failure ();
6140+
6141+ bool isTargetDevice = ompBuilder->Config .isTargetDevice ();
6142+ auto deviceType = groupprivateOp.getDeviceType ();
6143+
6144+ // skip allocation based on device_type
6145+ bool shouldAllocate = true ;
6146+ if (deviceType.has_value ()) {
6147+ switch (*deviceType) {
6148+ case mlir::omp::DeclareTargetDeviceType::host:
6149+ // Only allocate on host
6150+ shouldAllocate = !isTargetDevice;
6151+ break ;
6152+ case mlir::omp::DeclareTargetDeviceType::nohost:
6153+ // Only allocate on device
6154+ shouldAllocate = isTargetDevice;
6155+ break ;
6156+ case mlir::omp::DeclareTargetDeviceType::any:
6157+ // Allocate on both
6158+ shouldAllocate = true ;
6159+ break ;
6160+ }
6161+ }
61406162
61416163 Value symAddr = groupprivateOp.getSymAddr ();
61426164 auto *symOp = symAddr.getDefiningOp ();
@@ -6151,21 +6173,32 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61516173 LLVM::GlobalOp global =
61526174 addressOfOp.getGlobal (moduleTranslation.symbolTable ());
61536175 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal (global);
6176+ llvm::Value *resultPtr;
61546177
6155- // Get the size of the variable
6156- llvm::Type *varType = globalValue->getValueType ();
6157- llvm::Module *llvmModule = moduleTranslation.getLLVMModule ();
6158- llvm::DataLayout DL = llvmModule->getDataLayout ();
6159- uint64_t typeSize = DL.getTypeAllocSize (varType);
6160- // Call omp_alloc_shared to allocate memory for groupprivate variable.
6161- llvm::FunctionCallee allocSharedFn = ompBuilder->getOrCreateRuntimeFunction (
6162- *llvmModule, llvm::omp::OMPRTL___kmpc_alloc_shared);
6163- // Call runtime to allocate shared memory for this group
6164- llvm::Value *groupPrivatePtr =
6165- builder.CreateCall (allocSharedFn, {builder.getInt64 (typeSize)});
6166- groupPrivatePtr =
6167- builder.CreateBitCast (groupPrivatePtr, globalValue->getType ());
6168- moduleTranslation.mapValue (opInst.getResult (0 ), groupPrivatePtr);
6178+ if (shouldAllocate) {
6179+ // Get the size of the variable
6180+ llvm::Type *varType = globalValue->getValueType ();
6181+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule ();
6182+ llvm::DataLayout DL = llvmModule->getDataLayout ();
6183+ uint64_t typeSize = DL.getTypeAllocSize (varType);
6184+ // Call __kmpc_alloc_shared to allocate memory for groupprivate variable.
6185+ llvm::FunctionCallee allocSharedFn = ompBuilder->getOrCreateRuntimeFunction (
6186+ *llvmModule, llvm::omp::OMPRTL___kmpc_alloc_shared);
6187+ // Call runtime to allocate shared memory for this group
6188+ llvm::Value *groupPrivatePtr =
6189+ builder.CreateCall (allocSharedFn, {builder.getInt64 (typeSize)});
6190+ resultPtr =
6191+ builder.CreateBitCast (groupPrivatePtr, globalValue->getType ());
6192+ }
6193+ else {
6194+ // Use original global address when not allocating group-private storage
6195+ resultPtr = moduleTranslation.lookupValue (symAddr);
6196+ if (!resultPtr) {
6197+ // Fallback: create address-of for the global
6198+ resultPtr = builder.CreateBitCast (globalValue, globalValue->getType ());
6199+ }
6200+ }
6201+ moduleTranslation.mapValue (opInst.getResult (0 ), resultPtr);
61696202 return success ();
61706203}
61716204
0 commit comments