@@ -176,15 +176,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
176176 if (op.getHint ())
177177 op.emitWarning (" hint clause discarded" );
178178 };
179- auto checkHostEval = [](auto op, LogicalResult &result) {
180- // Host evaluated clauses are supported, except for loop bounds.
181- for (BlockArgument arg :
182- cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
183- for (Operation *user : arg.getUsers ())
184- if (isa<omp::LoopNestOp>(user))
185- result = op.emitError (" not yet implemented: host evaluation of loop "
186- " bounds in omp.target operation" );
187- };
188179 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
189180 if (!op.getInReductionVars ().empty () || op.getInReductionByref () ||
190181 op.getInReductionSyms ())
@@ -321,7 +312,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
321312 checkBare (op, result);
322313 checkDevice (op, result);
323314 checkHasDeviceAddr (op, result);
324- checkHostEval (op, result);
325315 checkInReduction (op, result);
326316 checkIsDevicePtr (op, result);
327317 checkPrivate (op, result);
@@ -4054,9 +4044,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
40544044// /
40554045// / Loop bounds and steps are only optionally populated, if output vectors are
40564046// / provided.
4057- static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4058- Value &numTeamsLower, Value &numTeamsUpper,
4059- Value &threadLimit) {
4047+ static void
4048+ extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4049+ Value &numTeamsLower, Value &numTeamsUpper,
4050+ Value &threadLimit,
4051+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr ,
4052+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr ,
4053+ llvm::SmallVectorImpl<Value> *steps = nullptr ) {
40604054 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
40614055 for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
40624056 blockArgIface.getHostEvalBlockArgs ())) {
@@ -4081,11 +4075,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
40814075 llvm_unreachable (" unsupported host_eval use" );
40824076 })
40834077 .Case ([&](omp::LoopNestOp loopOp) {
4084- // TODO: Extract bounds and step values. Currently, this cannot be
4085- // reached because translation would have been stopped earlier as a
4086- // result of `checkImplementationStatus` detecting and reporting
4087- // this situation.
4088- llvm_unreachable (" unsupported host_eval use" );
4078+ auto processBounds =
4079+ [&](OperandRange opBounds,
4080+ llvm::SmallVectorImpl<Value> *outBounds) -> bool {
4081+ bool found = false ;
4082+ for (auto [i, lb] : llvm::enumerate (opBounds)) {
4083+ if (lb == blockArg) {
4084+ found = true ;
4085+ if (outBounds)
4086+ (*outBounds)[i] = hostEvalVar;
4087+ }
4088+ }
4089+ return found;
4090+ };
4091+ bool found =
4092+ processBounds (loopOp.getLoopLowerBounds (), lowerBounds);
4093+ found = processBounds (loopOp.getLoopUpperBounds (), upperBounds) ||
4094+ found;
4095+ found = processBounds (loopOp.getLoopSteps (), steps) || found;
4096+ if (!found)
4097+ llvm_unreachable (" unsupported host_eval use" );
40894098 })
40904099 .Default ([](Operation *) {
40914100 llvm_unreachable (" unsupported host_eval use" );
@@ -4222,6 +4231,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
42224231 combinedMaxThreadsVal = maxThreadsVal;
42234232
42244233 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4234+ attrs.ExecFlags = targetOp.getKernelExecFlags ();
42254235 attrs.MinTeams = minTeamsVal;
42264236 attrs.MaxTeams .front () = maxTeamsVal;
42274237 attrs.MinThreads = 1 ;
@@ -4239,9 +4249,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
42394249 LLVM::ModuleTranslation &moduleTranslation,
42404250 omp::TargetOp targetOp,
42414251 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4252+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4253+ targetOp.getInnermostCapturedOmpOp ());
4254+ unsigned numLoops = loopOp ? loopOp.getNumLoops () : 0 ;
4255+
42424256 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4257+ llvm::SmallVector<Value> lowerBounds (numLoops), upperBounds (numLoops),
4258+ steps (numLoops);
42434259 extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4244- teamsThreadLimit);
4260+ teamsThreadLimit, &lowerBounds, &upperBounds, &steps );
42454261
42464262 // TODO: Handle constant 'if' clauses.
42474263 if (Value targetThreadLimit = targetOp.getThreadLimit ())
@@ -4261,7 +4277,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
42614277 if (numThreads)
42624278 attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
42634279
4264- // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4280+ if (targetOp.getKernelExecFlags () != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4281+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
4282+ attrs.LoopTripCount = nullptr ;
4283+
4284+ // To calculate the trip count, we multiply together the trip counts of
4285+ // every collapsed canonical loop. We don't need to create the loop nests
4286+ // here, since we're only interested in the trip count.
4287+ for (auto [loopLower, loopUpper, loopStep] :
4288+ llvm::zip_equal (lowerBounds, upperBounds, steps)) {
4289+ llvm::Value *lowerBound = moduleTranslation.lookupValue (loopLower);
4290+ llvm::Value *upperBound = moduleTranslation.lookupValue (loopUpper);
4291+ llvm::Value *step = moduleTranslation.lookupValue (loopStep);
4292+
4293+ llvm::OpenMPIRBuilder::LocationDescription loc (builder);
4294+ llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount (
4295+ loc, lowerBound, upperBound, step, /* IsSigned=*/ true ,
4296+ loopOp.getLoopInclusive ());
4297+
4298+ if (!attrs.LoopTripCount ) {
4299+ attrs.LoopTripCount = tripCount;
4300+ continue ;
4301+ }
4302+
4303+ // TODO: Enable UndefinedBehaviorSanitizer to diagnose an overflow here.
4304+ attrs.LoopTripCount = builder.CreateMul (attrs.LoopTripCount , tripCount,
4305+ {}, /* HasNUW=*/ true );
4306+ }
4307+ }
42654308}
42664309
42674310static LogicalResult
0 commit comments