@@ -1128,9 +1128,10 @@ struct DeferredStore {
11281128} // namespace
11291129
11301130// / Check whether allocations for the given operation might potentially have to
1131- // / be done in device shared memory. That means we're compiling for a offloading
1132- // / target, the operation is an `omp::TargetOp` or nested inside of one and that
1133- // / target region represents a Generic (non-SPMD) kernel.
1131+ // / be done in device shared memory. That means we're compiling for an
1132+ // / offloading target and either the operation is neither an `omp::TargetOp`
1133+ // / nor nested inside of one, or it is and that target region represents a
1134+ // / Generic (non-SPMD) kernel.
11341135// /
11351136// / This represents a necessary but not sufficient set of conditions to use
11361137// / device shared memory in place of regular allocas. For some variables, the
@@ -1146,7 +1147,7 @@ mightAllocInDeviceSharedMemory(Operation &op,
11461147 if (!targetOp)
11471148 targetOp = op.getParentOfType <omp::TargetOp>();
11481149
1149- return targetOp &&
1150+ return ! targetOp ||
11501151 targetOp.getKernelExecFlags (targetOp.getInnermostCapturedOmpOp ()) ==
11511152 omp::TargetExecMode::generic;
11521153}
@@ -1160,18 +1161,36 @@ mightAllocInDeviceSharedMemory(Operation &op,
11601161// / operation that owns the specified block argument.
11611162static bool mustAllocPrivateVarInDeviceSharedMemory (BlockArgument value) {
11621163 Operation *parentOp = value.getOwner ()->getParentOp ();
1163- auto targetOp = dyn_cast<omp::TargetOp>(parentOp);
1164- if (!targetOp)
1165- targetOp = parentOp->getParentOfType <omp::TargetOp>();
1166- assert (targetOp && " expected a parent omp.target operation" );
1167-
1164+ auto moduleOp = parentOp->getParentOfType <ModuleOp>();
11681165 for (auto *user : value.getUsers ()) {
11691166 if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) {
11701167 if (llvm::is_contained (parallelOp.getReductionVars (), value))
11711168 return true ;
11721169 } else if (auto parallelOp = user->getParentOfType <omp::ParallelOp>()) {
1173- if (parentOp->isProperAncestor (parallelOp))
1174- return true ;
1170+ if (parentOp->isProperAncestor (parallelOp)) {
1171+ // If it is used directly inside of a parallel region, skip private
1172+ // clause uses.
1173+ bool isPrivateClauseUse = false ;
1174+ if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(user)) {
1175+ if (auto privateSyms = llvm::cast_or_null<ArrayAttr>(
1176+ user->getAttr (" private_syms" ))) {
1177+ for (auto [var, sym] :
1178+ llvm::zip_equal (argIface.getPrivateVars (), privateSyms)) {
1179+ if (var != value)
1180+ continue ;
1181+
1182+ auto privateOp = cast<omp::PrivateClauseOp>(
1183+ moduleOp.lookupSymbol (cast<SymbolRefAttr>(sym)));
1184+ if (privateOp.getCopyRegion ().empty ()) {
1185+ isPrivateClauseUse = true ;
1186+ break ;
1187+ }
1188+ }
1189+ }
1190+ }
1191+ if (!isPrivateClauseUse)
1192+ return true ;
1193+ }
11751194 }
11761195 }
11771196
@@ -1196,8 +1215,8 @@ allocReductionVars(T op, ArrayRef<BlockArgument> reductionArgs,
11961215 builder.SetInsertPoint (allocaIP.getBlock ()->getTerminator ());
11971216
11981217 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
1199- bool useDeviceSharedMem =
1200- isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1218+ bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
1219+ mightAllocInDeviceSharedMemory (*op, *ompBuilder);
12011220
12021221 // delay creating stores until after all allocas
12031222 deferredStores.reserve (op.getNumReductionVars ());
@@ -1318,8 +1337,8 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
13181337 return success ();
13191338
13201339 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
1321- bool useDeviceSharedMem =
1322- isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1340+ bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
1341+ mightAllocInDeviceSharedMemory (*op, *ompBuilder);
13231342
13241343 llvm::BasicBlock *initBlock = splitBB (builder, true , " omp.reduction.init" );
13251344 auto allocaIP = llvm::IRBuilderBase::InsertPoint (
@@ -1540,8 +1559,8 @@ static LogicalResult createReductionsAndCleanup(
15401559 reductionRegions, privateReductionVariables, moduleTranslation, builder,
15411560 " omp.reduction.cleanup" );
15421561
1543- bool useDeviceSharedMem =
1544- isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory (*op, *ompBuilder);
1562+ bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
1563+ mightAllocInDeviceSharedMemory (*op, *ompBuilder);
15451564 if (useDeviceSharedMem) {
15461565 for (auto [var, reductionDecl] :
15471566 llvm::zip_equal (privateReductionVariables, reductionDecls))
@@ -1721,7 +1740,7 @@ allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
17211740
17221741 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
17231742 bool mightUseDeviceSharedMem =
1724- isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
1743+ isa<omp::TargetOp, omp:: TeamsOp, omp::DistributeOp>(*op) &&
17251744 mightAllocInDeviceSharedMemory (*op, *ompBuilder);
17261745 unsigned int allocaAS =
17271746 moduleTranslation.getLLVMModule ()->getDataLayout ().getAllocaAddrSpace ();
@@ -1839,7 +1858,7 @@ cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
18391858
18401859 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
18411860 bool mightUseDeviceSharedMem =
1842- isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
1861+ isa<omp::TargetOp, omp:: TeamsOp, omp::DistributeOp>(*op) &&
18431862 mightAllocInDeviceSharedMemory (*op, *ompBuilder);
18441863 for (auto [privDecl, llvmPrivVar, blockArg] :
18451864 llvm::zip_equal (privateVarsInfo.privatizers , privateVarsInfo.llvmVars ,
@@ -5265,42 +5284,68 @@ handleDeclareTargetMapVar(MapInfoData &mapData,
52655284// a store of the kernel argument into this allocated memory which
52665285// will then be loaded from, ByCopy will use the allocated memory
52675286// directly.
5268- static llvm::IRBuilderBase::InsertPoint
5269- createDeviceArgumentAccessor ( MapInfoData &mapData, llvm::Argument &arg,
5270- llvm::Value *input, llvm::Value *&retVal,
5271- llvm::IRBuilderBase &builder ,
5272- llvm::OpenMPIRBuilder &ompBuilder ,
5273- LLVM::ModuleTranslation &moduleTranslation ,
5274- llvm::IRBuilderBase::InsertPoint allocaIP ,
5275- llvm::IRBuilderBase::InsertPoint codeGenIP ) {
5287+ static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor (
5288+ omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
5289+ llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder ,
5290+ llvm::OpenMPIRBuilder &ompBuilder ,
5291+ LLVM::ModuleTranslation &moduleTranslation ,
5292+ llvm::IRBuilderBase::InsertPoint allocIP ,
5293+ llvm::IRBuilderBase::InsertPoint codeGenIP ,
5294+ llvm::ArrayRef<llvm:: IRBuilderBase::InsertPoint> deallocIPs ) {
52765295 assert (ompBuilder.Config .isTargetDevice () &&
52775296 " function only supported for target device codegen" );
5278- builder.restoreIP (allocaIP );
5297+ builder.restoreIP (allocIP );
52795298
52805299 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
52815300 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator (
52825301 ompBuilder.M .getContext ());
52835302 unsigned alignmentValue = 0 ;
5303+ BlockArgument mlirArg;
52845304 // Find the associated MapInfoData entry for the current input
5285- for (size_t i = 0 ; i < mapData.MapClause .size (); ++i)
5305+ for (size_t i = 0 ; i < mapData.MapClause .size (); ++i) {
52865306 if (mapData.OriginalValue [i] == input) {
52875307 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause [i]);
52885308 capture = mapOp.getMapCaptureType ();
52895309 // Get information of alignment of mapped object
52905310 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment (
52915311 mapOp.getVarType (), ompBuilder.M .getDataLayout ());
5312+ // Get the corresponding target entry block argument
5313+ mlirArg =
5314+ cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getMapBlockArgs ()[i];
52925315 break ;
52935316 }
5317+ }
52945318
52955319 unsigned int allocaAS = ompBuilder.M .getDataLayout ().getAllocaAddrSpace ();
52965320 unsigned int defaultAS =
52975321 ompBuilder.M .getDataLayout ().getProgramAddressSpace ();
52985322
5299- // Create the alloca for the argument the current point.
5300- llvm::Value *v = builder.CreateAlloca (arg.getType (), allocaAS);
5323+ // Create the allocation for the argument.
5324+ llvm::Value *v = nullptr ;
5325+ if (mightAllocInDeviceSharedMemory (*targetOp, ompBuilder) &&
5326+ mustAllocPrivateVarInDeviceSharedMemory (mlirArg)) {
5327+ // Use the beginning of the codeGenIP rather than the usual allocation point
5328+ // for shared memory allocations because otherwise these would be done prior
5329+ // to the target initialization call. Also, the exit block (where the
5330+ // deallocation is placed) is only executed if the initialization call
5331+ // succeeds.
5332+ builder.SetInsertPoint (codeGenIP.getBlock ()->getFirstInsertionPt ());
5333+ v = ompBuilder.createOMPAllocShared (builder, arg.getType ());
5334+
5335+ // Create deallocations in all provided deallocation points and then restore
5336+ // the insertion point to right after the new allocations.
5337+ llvm::IRBuilderBase::InsertPointGuard guard (builder);
5338+ for (auto deallocIP : deallocIPs) {
5339+ builder.SetInsertPoint (deallocIP.getBlock (), deallocIP.getPoint ());
5340+ ompBuilder.createOMPFreeShared (builder, v, arg.getType ());
5341+ }
5342+ } else {
5343+ // Use the current point, which was previously set to allocIP.
5344+ v = builder.CreateAlloca (arg.getType (), allocaAS);
53015345
5302- if (allocaAS != defaultAS && arg.getType ()->isPointerTy ())
5303- v = builder.CreateAddrSpaceCast (v, builder.getPtrTy (defaultAS));
5346+ if (allocaAS != defaultAS && arg.getType ()->isPointerTy ())
5347+ v = builder.CreateAddrSpaceCast (v, builder.getPtrTy (defaultAS));
5348+ }
53045349
53055350 builder.CreateStore (&arg, v);
53065351
@@ -5890,8 +5935,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
58905935 };
58915936
58925937 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
5893- llvm::Value *&retVal, InsertPointTy allocaIP,
5894- InsertPointTy codeGenIP)
5938+ llvm::Value *&retVal, InsertPointTy allocIP,
5939+ InsertPointTy codeGenIP,
5940+ llvm::ArrayRef<InsertPointTy> deallocIPs)
58955941 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
58965942 llvm::IRBuilderBase::InsertPointGuard guard (builder);
58975943 builder.SetCurrentDebugLocation (llvm::DebugLoc ());
@@ -5905,9 +5951,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
59055951 return codeGenIP;
59065952 }
59075953
5908- return createDeviceArgumentAccessor (mapData, arg, input, retVal, builder ,
5909- *ompBuilder, moduleTranslation,
5910- allocaIP , codeGenIP);
5954+ return createDeviceArgumentAccessor (targetOp, mapData, arg, input, retVal,
5955+ builder, *ompBuilder, moduleTranslation,
5956+ allocIP , codeGenIP, deallocIPs );
59115957 };
59125958
59135959 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
0 commit comments