Skip to content

Commit 0258a52

Browse files
authored
[MLIR][OpenMP] Fix handling of constant num_teams/threads (#204)
The PR stack transitioning to the `host_eval` representation for introduced an issue causing the `num_teams`, `thread_limit` and `num_threads` clauses to not properly initialize kernel attributes during target device compilation when these are constant. This patch fixes that issue by using host-evaluated values instead of their corresponding block argument when extracting these constant values.
1 parent 27e3c3a commit 0258a52

File tree

1 file changed

+13
-29
lines changed

1 file changed

+13
-29
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4103,8 +4103,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
41034103
/// corresponding global `ConfigurationEnvironmentTy` structure.
41044104
static void initTargetDefaultBounds(
41054105
omp::TargetOp targetOp,
4106-
llvm::OpenMPIRBuilder::TargetKernelDefaultBounds &bounds,
4107-
bool isTargetDevice, bool isGPU) {
4106+
llvm::OpenMPIRBuilder::TargetKernelDefaultBounds &bounds, bool isGPU) {
41084107
Value hostNumThreads, hostNumTeamsLower, hostNumTeamsUpper, hostThreadLimit;
41094108
extractHostEvalClauses(targetOp, hostNumThreads, hostNumTeamsLower,
41104109
hostNumTeamsUpper, hostThreadLimit);
@@ -4114,15 +4113,12 @@ static void initTargetDefaultBounds(
41144113

41154114
// Handle clauses impacting the number of teams.
41164115
int32_t minTeamsVal = 1, maxTeamsVal = -1;
4117-
if (auto teamsOp =
4118-
castOrGetParentOfType<omp::TeamsOp>(innermostCapturedOmpOp)) {
4119-
// TODO Use teamsOp.getNumTeamsLower() to initialize `minTeamsVal`. For now,
4120-
// just match clang and set min and max to the same value.
4121-
Value numTeamsClause =
4122-
isTargetDevice ? teamsOp.getNumTeamsUpper() : hostNumTeamsUpper;
4123-
if (numTeamsClause) {
4116+
if (castOrGetParentOfType<omp::TeamsOp>(innermostCapturedOmpOp)) {
4117+
// TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
4118+
// clang and set min and max to the same value.
4119+
if (hostNumTeamsUpper) {
41244120
if (auto constOp = dyn_cast_if_present<LLVM::ConstantOp>(
4125-
numTeamsClause.getDefiningOp())) {
4121+
hostNumTeamsUpper.getDefiningOp())) {
41264122
if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
41274123
minTeamsVal = maxTeamsVal = constAttr.getInt();
41284124
}
@@ -4159,26 +4155,14 @@ static void initTargetDefaultBounds(
41594155

41604156
// Extract THREAD_LIMIT clause from TARGET and TEAMS directives.
41614157
setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
4162-
4163-
if (auto teamsOp =
4164-
castOrGetParentOfType<omp::TeamsOp>(innermostCapturedOmpOp)) {
4165-
Value threadLimitClause =
4166-
isTargetDevice ? teamsOp.getThreadLimit() : hostThreadLimit;
4167-
setMaxValueFromClause(threadLimitClause, teamsThreadLimitVal);
4168-
}
4158+
setMaxValueFromClause(hostThreadLimit, teamsThreadLimitVal);
41694159

41704160
// Extract MAX_THREADS clause from PARALLEL or set to 1 if it's SIMD.
4171-
if (innermostCapturedOmpOp) {
4172-
if (auto parallelOp =
4173-
castOrGetParentOfType<omp::ParallelOp>(innermostCapturedOmpOp)) {
4174-
Value numThreadsClause =
4175-
isTargetDevice ? parallelOp.getNumThreads() : hostNumThreads;
4176-
setMaxValueFromClause(numThreadsClause, maxThreadsVal);
4177-
} else if (castOrGetParentOfType<omp::SimdOp>(innermostCapturedOmpOp,
4178-
/*immediateParent=*/true)) {
4179-
maxThreadsVal = 1;
4180-
}
4181-
}
4161+
if (castOrGetParentOfType<omp::ParallelOp>(innermostCapturedOmpOp))
4162+
setMaxValueFromClause(hostNumThreads, maxThreadsVal);
4163+
else if (castOrGetParentOfType<omp::SimdOp>(innermostCapturedOmpOp,
4164+
/*immediateParent=*/true))
4165+
maxThreadsVal = 1;
41824166

41834167
// For max values, < 0 means unset, == 0 means set but unknown. Select the
41844168
// minimum value between MAX_THREADS and THREAD_LIMIT clauses that were set.
@@ -4423,7 +4407,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
44234407

44244408
llvm::SmallVector<llvm::Value *, 4> kernelInput;
44254409
llvm::OpenMPIRBuilder::TargetKernelDefaultBounds defaultBounds;
4426-
initTargetDefaultBounds(targetOp, defaultBounds, isTargetDevice, isGPU);
4410+
initTargetDefaultBounds(targetOp, defaultBounds, isGPU);
44274411

44284412
// Collect host-evaluated values needed to properly launch the kernel from the
44294413
// host.

0 commit comments

Comments
 (0)