@@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
173173 if (op.getHint ())
174174 op.emitWarning (" hint clause discarded" );
175175 };
176- auto checkHostEval = [](auto op, LogicalResult &result) {
177- // Host evaluated clauses are supported, except for loop bounds.
178- for (BlockArgument arg :
179- cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
180- for (Operation *user : arg.getUsers ())
181- if (isa<omp::LoopNestOp>(user))
182- result = op.emitError (" not yet implemented: host evaluation of loop "
183- " bounds in omp.target operation" );
184- };
185176 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
186177 if (!op.getInReductionVars ().empty () || op.getInReductionByref () ||
187178 op.getInReductionSyms ())
@@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
318309 checkBare (op, result);
319310 checkDevice (op, result);
320311 checkHasDeviceAddr (op, result);
321- checkHostEval (op, result);
322312 checkInReduction (op, result);
323313 checkIsDevicePtr (op, result);
324314 checkPrivate (op, result);
@@ -4058,9 +4048,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
40584048// /
40594049// / Loop bounds and steps are only optionally populated, if output vectors are
40604050// / provided.
4061- static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4062- Value &numTeamsLower, Value &numTeamsUpper,
4063- Value &threadLimit) {
4051+ static void
4052+ extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4053+ Value &numTeamsLower, Value &numTeamsUpper,
4054+ Value &threadLimit,
4055+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr ,
4056+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr ,
4057+ llvm::SmallVectorImpl<Value> *steps = nullptr ) {
40644058 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
40654059 for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
40664060 blockArgIface.getHostEvalBlockArgs ())) {
@@ -4085,11 +4079,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
40854079 llvm_unreachable (" unsupported host_eval use" );
40864080 })
40874081 .Case ([&](omp::LoopNestOp loopOp) {
4088- // TODO: Extract bounds and step values. Currently, this cannot be
4089- // reached because translation would have been stopped earlier as a
4090- // result of `checkImplementationStatus` detecting and reporting
4091- // this situation.
4092- llvm_unreachable (" unsupported host_eval use" );
4082+ auto processBounds =
4083+ [&](OperandRange opBounds,
4084+ llvm::SmallVectorImpl<Value> *outBounds) -> bool {
4085+ bool found = false ;
4086+ for (auto [i, lb] : llvm::enumerate (opBounds)) {
4087+ if (lb == blockArg) {
4088+ found = true ;
4089+ if (outBounds)
4090+ (*outBounds)[i] = hostEvalVar;
4091+ }
4092+ }
4093+ return found;
4094+ };
4095+ bool found =
4096+ processBounds (loopOp.getLoopLowerBounds (), lowerBounds);
4097+ found = processBounds (loopOp.getLoopUpperBounds (), upperBounds) ||
4098+ found;
4099+ found = processBounds (loopOp.getLoopSteps (), steps) || found;
4100+ if (!found)
4101+ llvm_unreachable (" unsupported host_eval use" );
40934102 })
40944103 .Default ([](Operation *) {
40954104 llvm_unreachable (" unsupported host_eval use" );
@@ -4226,6 +4235,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
42264235 combinedMaxThreadsVal = maxThreadsVal;
42274236
42284237 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4238+ attrs.ExecFlags = targetOp.getKernelExecFlags ();
42294239 attrs.MinTeams = minTeamsVal;
42304240 attrs.MaxTeams .front () = maxTeamsVal;
42314241 attrs.MinThreads = 1 ;
@@ -4243,9 +4253,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
42434253 LLVM::ModuleTranslation &moduleTranslation,
42444254 omp::TargetOp targetOp,
42454255 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4256+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4257+ targetOp.getInnermostCapturedOmpOp ());
4258+ unsigned numLoops = loopOp ? loopOp.getNumLoops () : 0 ;
4259+
42464260 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4261+ llvm::SmallVector<Value> lowerBounds (numLoops), upperBounds (numLoops),
4262+ steps (numLoops);
42474263 extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4248- teamsThreadLimit);
4264+ teamsThreadLimit, &lowerBounds, &upperBounds, &steps );
42494265
42504266 // TODO: Handle constant 'if' clauses.
42514267 if (Value targetThreadLimit = targetOp.getThreadLimit ())
@@ -4265,7 +4281,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
42654281 if (numThreads)
42664282 attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
42674283
4268- // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4284+ if (targetOp.getKernelExecFlags () != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4285+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
4286+ attrs.LoopTripCount = nullptr ;
4287+
4288+ // To calculate the trip count, we multiply together the trip counts of
4289+ // every collapsed canonical loop. We don't need to create the loop nests
4290+ // here, since we're only interested in the trip count.
4291+ for (auto [loopLower, loopUpper, loopStep] :
4292+ llvm::zip_equal (lowerBounds, upperBounds, steps)) {
4293+ llvm::Value *lowerBound = moduleTranslation.lookupValue (loopLower);
4294+ llvm::Value *upperBound = moduleTranslation.lookupValue (loopUpper);
4295+ llvm::Value *step = moduleTranslation.lookupValue (loopStep);
4296+
4297+ llvm::OpenMPIRBuilder::LocationDescription loc (builder);
4298+ llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount (
4299+ loc, lowerBound, upperBound, step, /* IsSigned=*/ true ,
4300+ loopOp.getLoopInclusive ());
4301+
4302+ if (!attrs.LoopTripCount ) {
4303+ attrs.LoopTripCount = tripCount;
4304+ continue ;
4305+ }
4306+
4307+ // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
4308+ attrs.LoopTripCount = builder.CreateMul (attrs.LoopTripCount , tripCount,
4309+ {}, /* HasNUW=*/ true );
4310+ }
4311+ }
42694312}
42704313
42714314static LogicalResult
0 commit comments