@@ -1286,6 +1286,44 @@ static LogicalResult allocAndInitializeReductionVars(
12861286 isByRef, deferredStores);
12871287}
12881288
1289+ // / Return the llvm::Value * corresponding to the `privateVar` that
1290+ // / is being privatized. It isn't always as simple as looking up
1291+ // / moduleTranslation with privateVar. For instance, in case of
1292+ // / an allocatable, the descriptor for the allocatable is privatized.
1293+ // / This descriptor is mapped using an MapInfoOp. So, this function
1294+ // / will return a pointer to the llvm::Value corresponding to the
1295+ // / block argument for the mapped descriptor.
1296+ static llvm::Value *
1297+ findAssociatedValue (Value privateVar, llvm::IRBuilderBase &builder,
1298+ LLVM::ModuleTranslation &moduleTranslation,
1299+ omp::TargetOp targetOp = nullptr ,
1300+ llvm::DenseMap<Value, int > *mappedPrivateVars = nullptr ) {
1301+ if (mappedPrivateVars != nullptr && mappedPrivateVars->contains (privateVar)) {
1302+ int blockArgIndex = (*mappedPrivateVars)[privateVar];
1303+ Value blockArg = targetOp.getRegion ().getArgument (blockArgIndex);
1304+ Type privVarType = privateVar.getType ();
1305+ Type blockArgType = blockArg.getType ();
1306+ assert (isa<LLVM::LLVMPointerType>(blockArgType) &&
1307+ " A block argument corresponding to a mapped var should have "
1308+ " !llvm.ptr type" );
1309+
1310+ if (privVarType == blockArg.getType ())
1311+ return moduleTranslation.lookupValue (blockArg);
1312+
1313+ if (!isa<LLVM::LLVMPointerType>(privVarType)) {
1314+ // This typically happens when the privatized type is lowered from
1315+ // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1316+ // struct/pair is passed by value. But, mapped values are passed only as
1317+ // pointers, so before we privatize, we must load the pointer.
1318+ llvm::Value *load =
1319+ builder.CreateLoad (moduleTranslation.convertType (privVarType),
1320+ moduleTranslation.lookupValue (blockArg));
1321+ return load;
1322+ }
1323+ }
1324+ return moduleTranslation.lookupValue (privateVar);
1325+ }
1326+
12891327// / Allocate delayed private variables. Returns the basic block which comes
12901328// / after all of these allocations. llvm::Value * for each of these private
12911329// / variables are populated in llvmPrivateVars.
@@ -1296,7 +1334,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
12961334 MutableArrayRef<omp::PrivateClauseOp> privateDecls,
12971335 MutableArrayRef<mlir::Value> mlirPrivateVars,
12981336 llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1299- const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
1337+ const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1338+ omp::TargetOp targetOp = nullptr ,
1339+ llvm::DenseMap<Value, int > *mappedPrivateVars = nullptr ) {
13001340 llvm::IRBuilderBase::InsertPointGuard guard (builder);
13011341 // Allocate private vars
13021342 llvm::BranchInst *allocaTerminator =
@@ -1326,7 +1366,8 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
13261366 Region &allocRegion = privDecl.getAllocRegion ();
13271367
13281368 // map allocation region block argument
1329- llvm::Value *nonPrivateVar = moduleTranslation.lookupValue (mlirPrivVar);
1369+ llvm::Value *nonPrivateVar = findAssociatedValue (
1370+ mlirPrivVar, builder, moduleTranslation, targetOp, mappedPrivateVars);
13301371 assert (nonPrivateVar);
13311372 moduleTranslation.mapValue (privDecl.getAllocMoldArg (), nonPrivateVar);
13321373
@@ -1341,6 +1382,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
13411382 } else {
13421383 builder.SetInsertPoint (privAllocBlock->getTerminator ());
13431384 }
1385+
13441386 if (failed (inlineConvertOmpRegions (allocRegion, " omp.private.alloc" ,
13451387 builder, moduleTranslation, &phis)))
13461388 return llvm::createStringError (
@@ -3809,43 +3851,6 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
38093851 return builder.saveIP ();
38103852}
38113853
3812- // / Return the llvm::Value * corresponding to the `privateVar` that
3813- // / is being privatized. It isn't always as simple as looking up
3814- // / moduleTranslation with privateVar. For instance, in case of
3815- // / an allocatable, the descriptor for the allocatable is privatized.
3816- // / This descriptor is mapped using an MapInfoOp. So, this function
3817- // / will return a pointer to the llvm::Value corresponding to the
3818- // / block argument for the mapped descriptor.
3819- static llvm::Value *
3820- findHostAssociatedValue (Value privateVar, omp::TargetOp targetOp,
3821- llvm::DenseMap<Value, int > &mappedPrivateVars,
3822- llvm::IRBuilderBase &builder,
3823- LLVM::ModuleTranslation &moduleTranslation) {
3824- if (mappedPrivateVars.contains (privateVar)) {
3825- int blockArgIndex = mappedPrivateVars[privateVar];
3826- Value blockArg = targetOp.getRegion ().getArgument (blockArgIndex);
3827- Type privVarType = privateVar.getType ();
3828- Type blockArgType = blockArg.getType ();
3829- assert (isa<LLVM::LLVMPointerType>(blockArgType) &&
3830- " A block argument corresponding to a mapped var should have "
3831- " !llvm.ptr type" );
3832-
3833- if (privVarType == blockArg.getType ())
3834- return moduleTranslation.lookupValue (blockArg);
3835-
3836- if (!isa<LLVM::LLVMPointerType>(privVarType)) {
3837- // This typically happens when the privatized type is lowered from
3838- // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
3839- // struct/pair is passed by value. But, mapped values are passed only as
3840- // pointers, so before we privatize, we must load the pointer.
3841- llvm::Value *load =
3842- builder.CreateLoad (moduleTranslation.convertType (privVarType),
3843- moduleTranslation.lookupValue (blockArg));
3844- return load;
3845- }
3846- }
3847- return moduleTranslation.lookupValue (privateVar);
3848- }
38493854static LogicalResult
38503855convertOmpTarget (Operation &opInst, llvm::IRBuilderBase &builder,
38513856 LLVM::ModuleTranslation &moduleTranslation) {
@@ -3949,68 +3954,49 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39493954 attr.isStringAttribute ())
39503955 llvmOutlinedFn->addFnAttr (attr);
39513956
3952- builder.restoreIP (codeGenIP);
39533957 for (auto [arg, mapOp] : llvm::zip_equal (mapBlockArgs, mapVars)) {
39543958 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp ());
39553959 llvm::Value *mapOpValue =
39563960 moduleTranslation.lookupValue (mapInfoOp.getVarPtr ());
39573961 moduleTranslation.mapValue (arg, mapOpValue);
39583962 }
3963+
39593964 // Do privatization after moduleTranslation has already recorded
39603965 // mapped values.
3966+ MutableArrayRef<BlockArgument> privateBlockArgs =
3967+ cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs ();
3968+ SmallVector<mlir::Value> mlirPrivateVars;
39613969 SmallVector<llvm::Value *> llvmPrivateVars;
3970+ SmallVector<omp::PrivateClauseOp> privateDecls;
3971+ mlirPrivateVars.reserve (privateBlockArgs.size ());
3972+ llvmPrivateVars.reserve (privateBlockArgs.size ());
3973+ collectPrivatizationDecls (targetOp, privateDecls);
3974+ for (mlir::Value privateVar : targetOp.getPrivateVars ())
3975+ mlirPrivateVars.push_back (privateVar);
3976+
3977+ llvm::Expected<llvm::BasicBlock *> afterAllocas =
3978+ allocatePrivateVars (builder, moduleTranslation, privateBlockArgs,
3979+ privateDecls, mlirPrivateVars, llvmPrivateVars,
3980+ allocaIP, targetOp, &mappedPrivateVars);
3981+
3982+ if (handleError (afterAllocas, *targetOp).failed ())
3983+ return llvm::make_error<PreviouslyReportedError>();
3984+
39623985 SmallVector<Region *> privateCleanupRegions;
3963- if (!targetOp.getPrivateVars ().empty ()) {
3964- builder.restoreIP (allocaIP);
3965-
3966- OperandRange privateVars = targetOp.getPrivateVars ();
3967- std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
3968- MutableArrayRef<BlockArgument> privateBlockArgs =
3969- cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs ();
3970-
3971- for (auto [privVar, privatizerNameAttr, privBlockArg] :
3972- llvm::zip_equal (privateVars, *privateSyms, privateBlockArgs)) {
3973-
3974- SymbolRefAttr privSym = cast<SymbolRefAttr>(privatizerNameAttr);
3975- omp::PrivateClauseOp privatizer = findPrivatizer (&opInst, privSym);
3976- assert (privatizer.getDataSharingType () !=
3977- omp::DataSharingClauseType::FirstPrivate &&
3978- " unsupported privatizer" );
3979- Region &allocRegion = privatizer.getAllocRegion ();
3980- BlockArgument allocRegionArg = allocRegion.getArgument (0 );
3981- moduleTranslation.mapValue (
3982- allocRegionArg,
3983- findHostAssociatedValue (privVar, targetOp, mappedPrivateVars,
3984- builder, moduleTranslation));
3985- SmallVector<llvm::Value *, 1 > yieldedValues;
3986- if (failed (inlineConvertOmpRegions (
3987- allocRegion, " omp.targetop.privatizer" , builder,
3988- moduleTranslation, &yieldedValues))) {
3989- return llvm::createStringError (
3990- " failed to inline `alloc` region of `omp.private`" );
3991- }
3992- assert (yieldedValues.size () == 1 );
3993- llvm::Value *llvmReplacementValue = yieldedValues.front ();
3994- moduleTranslation.mapValue (privBlockArg, llvmReplacementValue);
3995- if (!privatizer.getDeallocRegion ().empty ()) {
3996- llvmPrivateVars.push_back (llvmReplacementValue);
3997- privateCleanupRegions.push_back (&privatizer.getDeallocRegion ());
3998- }
3999- moduleTranslation.forgetMapping (allocRegion);
4000- builder.restoreIP (builder.saveIP ());
4001- }
4002- }
3986+ llvm::transform (privateDecls, std::back_inserter (privateCleanupRegions),
3987+ [](omp::PrivateClauseOp privatizer) {
3988+ return &privatizer.getDeallocRegion ();
3989+ });
40033990
3991+ builder.restoreIP (codeGenIP);
40043992 llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions (
40053993 targetRegion, " omp.target" , builder, moduleTranslation);
3994+
40063995 if (!exitBlock)
40073996 return exitBlock.takeError ();
40083997
40093998 builder.SetInsertPoint (*exitBlock);
4010- if (!llvmPrivateVars.empty ()) {
4011- assert (llvmPrivateVars.size () == privateCleanupRegions.size () &&
4012- " Number of private variables needing cleanup not equal to number"
4013- " of privatizers with dealloc regions" );
3999+ if (!privateCleanupRegions.empty ()) {
40144000 if (failed (inlineOmpRegionCleanup (
40154001 privateCleanupRegions, llvmPrivateVars, moduleTranslation,
40164002 builder, " omp.targetop.private.cleanup" ,
@@ -4020,10 +4006,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40204006 " op in the target region" );
40214007 }
40224008 }
4023- return builder.saveIP ();
4009+
4010+ return InsertPointTy (exitBlock.get (), exitBlock.get ()->end ());
40244011 };
40254012
4026- llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
40274013 StringRef parentName = parentFn.getName ();
40284014
40294015 llvm::TargetRegionEntryInfo entryInfo;
@@ -4034,9 +4020,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40344020 int32_t defaultValTeams = -1 ;
40354021 int32_t defaultValThreads = 0 ;
40364022
4037- llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4038- findAllocaInsertPoint (builder, moduleTranslation);
4039-
40404023 MapInfoData mapData;
40414024 collectMapDataFromMapOperands (mapData, mapVars, moduleTranslation, dl,
40424025 builder);
@@ -4084,6 +4067,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40844067 buildDependData (targetOp.getDependKinds (), targetOp.getDependVars (),
40854068 moduleTranslation, dds);
40864069
4070+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4071+ findAllocaInsertPoint (builder, moduleTranslation);
4072+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
4073+
40874074 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
40884075 moduleTranslation.getOpenMPBuilder ()->createTarget (
40894076 ompLoc, isOffloadEntry, allocaIP, builder.saveIP (), entryInfo,
0 commit comments