@@ -1242,6 +1242,44 @@ static LogicalResult allocAndInitializeReductionVars(
12421242 return success ();
12431243}
12441244
1245+ // / Return the llvm::Value * corresponding to the `privateVar` that
1246+ // / is being privatized. It isn't always as simple as looking up
1247+ // / moduleTranslation with privateVar. For instance, in case of
1248+ // / an allocatable, the descriptor for the allocatable is privatized.
1249+ // / This descriptor is mapped using an MapInfoOp. So, this function
1250+ // / will return a pointer to the llvm::Value corresponding to the
1251+ // / block argument for the mapped descriptor.
1252+ static llvm::Value *
1253+ findAssociatedValue (Value privateVar, llvm::IRBuilderBase &builder,
1254+ LLVM::ModuleTranslation &moduleTranslation,
1255+ omp::TargetOp targetOp = nullptr ,
1256+ llvm::DenseMap<Value, int > *mappedPrivateVars = nullptr ) {
1257+ if (mappedPrivateVars != nullptr && mappedPrivateVars->contains (privateVar)) {
1258+ int blockArgIndex = (*mappedPrivateVars)[privateVar];
1259+ Value blockArg = targetOp.getRegion ().getArgument (blockArgIndex);
1260+ Type privVarType = privateVar.getType ();
1261+ Type blockArgType = blockArg.getType ();
1262+ assert (isa<LLVM::LLVMPointerType>(blockArgType) &&
1263+ " A block argument corresponding to a mapped var should have "
1264+ " !llvm.ptr type" );
1265+
1266+ if (privVarType == blockArg.getType ())
1267+ return moduleTranslation.lookupValue (blockArg);
1268+
1269+ if (!isa<LLVM::LLVMPointerType>(privVarType)) {
1270+ // This typically happens when the privatized type is lowered from
1271+ // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1272+ // struct/pair is passed by value. But, mapped values are passed only as
1273+ // pointers, so before we privatize, we must load the pointer.
1274+ llvm::Value *load =
1275+ builder.CreateLoad (moduleTranslation.convertType (privVarType),
1276+ moduleTranslation.lookupValue (blockArg));
1277+ return load;
1278+ }
1279+ }
1280+ return moduleTranslation.lookupValue (privateVar);
1281+ }
1282+
12451283// / Allocate delayed private variables. Returns the basic block which comes
12461284// / after all of these allocations. llvm::Value * for each of these private
12471285// / variables are populated in llvmPrivateVars.
@@ -1252,7 +1290,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
12521290 MutableArrayRef<omp::PrivateClauseOp> privateDecls,
12531291 MutableArrayRef<mlir::Value> mlirPrivateVars,
12541292 llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1255- const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
1293+ const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1294+ omp::TargetOp targetOp = nullptr ,
1295+ llvm::DenseMap<Value, int > *mappedPrivateVars = nullptr ) {
12561296 // Allocate private vars
12571297 llvm::BranchInst *allocaTerminator =
12581298 llvm::cast<llvm::BranchInst>(allocaIP.getBlock ()->getTerminator ());
@@ -1281,7 +1321,8 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
12811321 Region &allocRegion = privDecl.getAllocRegion ();
12821322
12831323 // map allocation region block argument
1284- llvm::Value *nonPrivateVar = moduleTranslation.lookupValue (mlirPrivVar);
1324+ llvm::Value *nonPrivateVar = findAssociatedValue (
1325+ mlirPrivVar, builder, moduleTranslation, targetOp, mappedPrivateVars);
12851326 assert (nonPrivateVar);
12861327 moduleTranslation.mapValue (privDecl.getAllocMoldArg (), nonPrivateVar);
12871328
@@ -1296,6 +1337,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
12961337 } else {
12971338 builder.SetInsertPoint (privAllocBlock->getTerminator ());
12981339 }
1340+
12991341 if (failed (inlineConvertOmpRegions (allocRegion, " omp.private.alloc" ,
13001342 builder, moduleTranslation, &phis)))
13011343 return llvm::createStringError (
@@ -3806,43 +3848,6 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
38063848 return builder.saveIP ();
38073849}
38083850
3809- // / Return the llvm::Value * corresponding to the `privateVar` that
3810- // / is being privatized. It isn't always as simple as looking up
3811- // / moduleTranslation with privateVar. For instance, in case of
3812- // / an allocatable, the descriptor for the allocatable is privatized.
3813- // / This descriptor is mapped using an MapInfoOp. So, this function
3814- // / will return a pointer to the llvm::Value corresponding to the
3815- // / block argument for the mapped descriptor.
3816- static llvm::Value *
3817- findHostAssociatedValue (Value privateVar, omp::TargetOp targetOp,
3818- llvm::DenseMap<Value, int > &mappedPrivateVars,
3819- llvm::IRBuilderBase &builder,
3820- LLVM::ModuleTranslation &moduleTranslation) {
3821- if (mappedPrivateVars.contains (privateVar)) {
3822- int blockArgIndex = mappedPrivateVars[privateVar];
3823- Value blockArg = targetOp.getRegion ().getArgument (blockArgIndex);
3824- Type privVarType = privateVar.getType ();
3825- Type blockArgType = blockArg.getType ();
3826- assert (isa<LLVM::LLVMPointerType>(blockArgType) &&
3827- " A block argument corresponding to a mapped var should have "
3828- " !llvm.ptr type" );
3829-
3830- if (privVarType == blockArg.getType ())
3831- return moduleTranslation.lookupValue (blockArg);
3832-
3833- if (!isa<LLVM::LLVMPointerType>(privVarType)) {
3834- // This typically happens when the privatized type is lowered from
3835- // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
3836- // struct/pair is passed by value. But, mapped values are passed only as
3837- // pointers, so before we privatize, we must load the pointer.
3838- llvm::Value *load =
3839- builder.CreateLoad (moduleTranslation.convertType (privVarType),
3840- moduleTranslation.lookupValue (blockArg));
3841- return load;
3842- }
3843- }
3844- return moduleTranslation.lookupValue (privateVar);
3845- }
38463851static LogicalResult
38473852convertOmpTarget (Operation &opInst, llvm::IRBuilderBase &builder,
38483853 LLVM::ModuleTranslation &moduleTranslation) {
@@ -3946,68 +3951,49 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39463951 attr.isStringAttribute ())
39473952 llvmOutlinedFn->addFnAttr (attr);
39483953
3949- builder.restoreIP (codeGenIP);
39503954 for (auto [arg, mapOp] : llvm::zip_equal (mapBlockArgs, mapVars)) {
39513955 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp ());
39523956 llvm::Value *mapOpValue =
39533957 moduleTranslation.lookupValue (mapInfoOp.getVarPtr ());
39543958 moduleTranslation.mapValue (arg, mapOpValue);
39553959 }
3960+
39563961 // Do privatization after moduleTranslation has already recorded
39573962 // mapped values.
3963+ MutableArrayRef<BlockArgument> privateBlockArgs =
3964+ cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs ();
3965+ SmallVector<mlir::Value> mlirPrivateVars;
39583966 SmallVector<llvm::Value *> llvmPrivateVars;
3967+ SmallVector<omp::PrivateClauseOp> privateDecls;
3968+ mlirPrivateVars.reserve (privateBlockArgs.size ());
3969+ llvmPrivateVars.reserve (privateBlockArgs.size ());
3970+ collectPrivatizationDecls (targetOp, privateDecls);
3971+ for (mlir::Value privateVar : targetOp.getPrivateVars ())
3972+ mlirPrivateVars.push_back (privateVar);
3973+
3974+ llvm::Expected<llvm::BasicBlock *> afterAllocas =
3975+ allocatePrivateVars (builder, moduleTranslation, privateBlockArgs,
3976+ privateDecls, mlirPrivateVars, llvmPrivateVars,
3977+ allocaIP, targetOp, &mappedPrivateVars);
3978+
3979+ if (handleError (afterAllocas, *targetOp).failed ())
3980+ return llvm::make_error<PreviouslyReportedError>();
3981+
39593982 SmallVector<Region *> privateCleanupRegions;
3960- if (!targetOp.getPrivateVars ().empty ()) {
3961- builder.restoreIP (allocaIP);
3962-
3963- OperandRange privateVars = targetOp.getPrivateVars ();
3964- std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
3965- MutableArrayRef<BlockArgument> privateBlockArgs =
3966- cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs ();
3967-
3968- for (auto [privVar, privatizerNameAttr, privBlockArg] :
3969- llvm::zip_equal (privateVars, *privateSyms, privateBlockArgs)) {
3970-
3971- SymbolRefAttr privSym = cast<SymbolRefAttr>(privatizerNameAttr);
3972- omp::PrivateClauseOp privatizer = findPrivatizer (&opInst, privSym);
3973- assert (privatizer.getDataSharingType () !=
3974- omp::DataSharingClauseType::FirstPrivate &&
3975- " unsupported privatizer" );
3976- Region &allocRegion = privatizer.getAllocRegion ();
3977- BlockArgument allocRegionArg = allocRegion.getArgument (0 );
3978- moduleTranslation.mapValue (
3979- allocRegionArg,
3980- findHostAssociatedValue (privVar, targetOp, mappedPrivateVars,
3981- builder, moduleTranslation));
3982- SmallVector<llvm::Value *, 1 > yieldedValues;
3983- if (failed (inlineConvertOmpRegions (
3984- allocRegion, " omp.targetop.privatizer" , builder,
3985- moduleTranslation, &yieldedValues))) {
3986- return llvm::createStringError (
3987- " failed to inline `alloc` region of `omp.private`" );
3988- }
3989- assert (yieldedValues.size () == 1 );
3990- llvm::Value *llvmReplacementValue = yieldedValues.front ();
3991- moduleTranslation.mapValue (privBlockArg, llvmReplacementValue);
3992- if (!privatizer.getDeallocRegion ().empty ()) {
3993- llvmPrivateVars.push_back (llvmReplacementValue);
3994- privateCleanupRegions.push_back (&privatizer.getDeallocRegion ());
3995- }
3996- moduleTranslation.forgetMapping (allocRegion);
3997- builder.restoreIP (builder.saveIP ());
3998- }
3999- }
3983+ llvm::transform (privateDecls, std::back_inserter (privateCleanupRegions),
3984+ [](omp::PrivateClauseOp privatizer) {
3985+ return &privatizer.getDeallocRegion ();
3986+ });
40003987
3988+ builder.restoreIP (codeGenIP);
40013989 llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions (
40023990 targetRegion, " omp.target" , builder, moduleTranslation);
3991+
40033992 if (!exitBlock)
40043993 return exitBlock.takeError ();
40053994
40063995 builder.SetInsertPoint (*exitBlock);
4007- if (!llvmPrivateVars.empty ()) {
4008- assert (llvmPrivateVars.size () == privateCleanupRegions.size () &&
4009- " Number of private variables needing cleanup not equal to number"
4010- " of privatizers with dealloc regions" );
3996+ if (!privateCleanupRegions.empty ()) {
40113997 if (failed (inlineOmpRegionCleanup (
40123998 privateCleanupRegions, llvmPrivateVars, moduleTranslation,
40133999 builder, " omp.targetop.private.cleanup" ,
@@ -4017,10 +4003,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40174003 " op in the target region" );
40184004 }
40194005 }
4020- return builder.saveIP ();
4006+
4007+ return InsertPointTy (exitBlock.get (), exitBlock.get ()->end ());
40214008 };
40224009
4023- llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
40244010 StringRef parentName = parentFn.getName ();
40254011
40264012 llvm::TargetRegionEntryInfo entryInfo;
@@ -4031,9 +4017,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40314017 int32_t defaultValTeams = -1 ;
40324018 int32_t defaultValThreads = 0 ;
40334019
4034- llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4035- findAllocaInsertPoint (builder, moduleTranslation);
4036-
40374020 MapInfoData mapData;
40384021 collectMapDataFromMapOperands (mapData, mapVars, moduleTranslation, dl,
40394022 builder);
@@ -4081,6 +4064,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40814064 buildDependData (targetOp.getDependKinds (), targetOp.getDependVars (),
40824065 moduleTranslation, dds);
40834066
4067+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4068+ findAllocaInsertPoint (builder, moduleTranslation);
4069+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
4070+
40844071 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
40854072 moduleTranslation.getOpenMPBuilder ()->createTarget (
40864073 ompLoc, isOffloadEntry, allocaIP, builder.saveIP (), entryInfo,
0 commit comments