@@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
173173 if (op.getHint ())
174174 op.emitWarning (" hint clause discarded" );
175175 };
176- auto checkHostEval = [](auto op, LogicalResult &result) {
177- // Host evaluated clauses are supported, except for loop bounds.
178- for (BlockArgument arg :
179- cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
180- for (Operation *user : arg.getUsers ())
181- if (isa<omp::LoopNestOp>(user))
182- result = op.emitError (" not yet implemented: host evaluation of loop "
183- " bounds in omp.target operation" );
184- };
185176 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
186177 if (!op.getInReductionVars ().empty () || op.getInReductionByref () ||
187178 op.getInReductionSyms ())
@@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
318309 checkBare (op, result);
319310 checkDevice (op, result);
320311 checkHasDeviceAddr (op, result);
321- checkHostEval (op, result);
322312 checkInReduction (op, result);
323313 checkIsDevicePtr (op, result);
324314 checkPrivate (op, result);
@@ -4158,9 +4148,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
41584148// /
41594149// / Loop bounds and steps are only optionally populated, if output vectors are
41604150// / provided.
4161- static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4162- Value &numTeamsLower, Value &numTeamsUpper,
4163- Value &threadLimit) {
4151+ static void
4152+ extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4153+ Value &numTeamsLower, Value &numTeamsUpper,
4154+ Value &threadLimit,
4155+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr ,
4156+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr ,
4157+ llvm::SmallVectorImpl<Value> *steps = nullptr ) {
41644158 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
41654159 for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
41664160 blockArgIface.getHostEvalBlockArgs ())) {
@@ -4185,11 +4179,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
41854179 llvm_unreachable (" unsupported host_eval use" );
41864180 })
41874181 .Case ([&](omp::LoopNestOp loopOp) {
4188- // TODO: Extract bounds and step values. Currently, this cannot be
4189- // reached because translation would have been stopped earlier as a
4190- // result of `checkImplementationStatus` detecting and reporting
4191- // this situation.
4192- llvm_unreachable (" unsupported host_eval use" );
4182+ auto processBounds =
4183+ [&](OperandRange opBounds,
4184+ llvm::SmallVectorImpl<Value> *outBounds) -> bool {
4185+ bool found = false ;
4186+ for (auto [i, lb] : llvm::enumerate (opBounds)) {
4187+ if (lb == blockArg) {
4188+ found = true ;
4189+ if (outBounds)
4190+ (*outBounds)[i] = hostEvalVar;
4191+ }
4192+ }
4193+ return found;
4194+ };
4195+ bool found =
4196+ processBounds (loopOp.getLoopLowerBounds (), lowerBounds);
4197+ found = processBounds (loopOp.getLoopUpperBounds (), upperBounds) ||
4198+ found;
4199+ found = processBounds (loopOp.getLoopSteps (), steps) || found;
4200+ (void )found;
4201+ assert (found && " unsupported host_eval use" );
41934202 })
41944203 .Default ([](Operation *) {
41954204 llvm_unreachable (" unsupported host_eval use" );
@@ -4326,6 +4335,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
43264335 combinedMaxThreadsVal = maxThreadsVal;
43274336
43284337 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4338+ attrs.ExecFlags = targetOp.getKernelExecFlags ();
43294339 attrs.MinTeams = minTeamsVal;
43304340 attrs.MaxTeams .front () = maxTeamsVal;
43314341 attrs.MinThreads = 1 ;
@@ -4343,9 +4353,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
43434353 LLVM::ModuleTranslation &moduleTranslation,
43444354 omp::TargetOp targetOp,
43454355 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4356+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4357+ targetOp.getInnermostCapturedOmpOp ());
4358+ unsigned numLoops = loopOp ? loopOp.getNumLoops () : 0 ;
4359+
43464360 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4361+ llvm::SmallVector<Value> lowerBounds (numLoops), upperBounds (numLoops),
4362+ steps (numLoops);
43474363 extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4348- teamsThreadLimit);
4364+ teamsThreadLimit, &lowerBounds, &upperBounds, &steps );
43494365
43504366 // TODO: Handle constant 'if' clauses.
43514367 if (Value targetThreadLimit = targetOp.getThreadLimit ())
@@ -4365,7 +4381,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
43654381 if (numThreads)
43664382 attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
43674383
4368- // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4384+ if (targetOp.getKernelExecFlags () != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4385+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
4386+ attrs.LoopTripCount = nullptr ;
4387+
4388+ // To calculate the trip count, we multiply together the trip counts of
4389+ // every collapsed canonical loop. We don't need to create the loop nests
4390+ // here, since we're only interested in the trip count.
4391+ for (auto [loopLower, loopUpper, loopStep] :
4392+ llvm::zip_equal (lowerBounds, upperBounds, steps)) {
4393+ llvm::Value *lowerBound = moduleTranslation.lookupValue (loopLower);
4394+ llvm::Value *upperBound = moduleTranslation.lookupValue (loopUpper);
4395+ llvm::Value *step = moduleTranslation.lookupValue (loopStep);
4396+
4397+ llvm::OpenMPIRBuilder::LocationDescription loc (builder);
4398+ llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount (
4399+ loc, lowerBound, upperBound, step, /* IsSigned=*/ true ,
4400+ loopOp.getLoopInclusive ());
4401+
4402+ if (!attrs.LoopTripCount ) {
4403+ attrs.LoopTripCount = tripCount;
4404+ continue ;
4405+ }
4406+
4407+ // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
4408+ attrs.LoopTripCount = builder.CreateMul (attrs.LoopTripCount , tripCount,
4409+ {}, /* HasNUW=*/ true );
4410+ }
4411+ }
43694412}
43704413
43714414static LogicalResult
0 commit comments