Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6775,7 +6775,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD &&
Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP);
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);

Expand Down
36 changes: 13 additions & 23 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -223,31 +223,21 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;

//===----------------------------------------------------------------------===//
// target_region_flags enum.
// target_exec_mode enum.
//===----------------------------------------------------------------------===//

def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 3>;

def TargetRegionFlags : OpenMP_BitEnumAttr<
"TargetRegionFlags",
"These flags describe properties of the target kernel. "
"TargetRegionFlagsGeneric - denotes generic kernel. "
"TargetRegionFlagsSpmd - denotes SPMD kernel. "
"TargetRegionFlagsNoLoop - denotes kernel where "
"num_teams * num_threads >= loop_trip_count. It allows the conversion "
"of loops into sequential code by ensuring that each team/thread "
"executes at most one iteration. "
"TargetRegionFlagsTripCount - checks if the loop trip count should be "
"calculated.", [
TargetRegionFlagsNone,
TargetRegionFlagsGeneric,
TargetRegionFlagsSpmd,
TargetRegionFlagsTripCount,
TargetRegionFlagsNoLoop
def TargetExecModeBare : I32EnumAttrCase<"bare", 0>;
def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>;
def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>;
def TargetExecModeSpmdNoLoop : I32EnumAttrCase<"no_loop", 3>;

def TargetExecMode : OpenMP_I32EnumAttr<
"TargetExecMode",
"target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [
TargetExecModeBare,
TargetExecModeGeneric,
TargetExecModeSpmd,
TargetExecModeSpmdNoLoop,
]>;

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 8 additions & 4 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1522,13 +1522,17 @@ def TargetOp : OpenMP_Op<"target", traits = [
/// operations, the top level one will be the one captured.
Operation *getInnermostCapturedOmpOp();

/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
/// contents of the target region.
/// Infers the kernel type (Generic, SPMD or no-loop SPMD) based on the
/// contents of the target region. Bare kernel mode is not detected yet.
///
/// \param capturedOp result of a still valid (no modifications made to any
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
static ::mlir::omp::TargetRegionFlags
getKernelExecFlags(Operation *capturedOp);
/// \param hostEvalTripCount output argument to store whether this kernel
/// wraps a loop whose bounds must be evaluated on the host prior to
/// launching it.
static ::mlir::omp::TargetExecMode
getKernelExecFlags(Operation *capturedOp,
bool *hostEvalTripCount = nullptr);
}] # clausesExtraClassDeclaration;

let assemblyFormat = clausesAssemblyFormat # [{
Expand Down
77 changes: 30 additions & 47 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2205,8 +2205,9 @@ LogicalResult TargetOp::verifyRegions() {
return emitError("target containing multiple 'omp.teams' nested ops");

// Check that host_eval values are only used in legal ways.
bool hostEvalTripCount;
Operation *capturedOp = getInnermostCapturedOmpOp();
TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
for (Value hostEvalArg :
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
Expand All @@ -2221,7 +2222,7 @@ LogicalResult TargetOp::verifyRegions() {
"and 'thread_limit' in 'omp.teams'";
}
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
if (execMode == TargetExecMode::spmd &&
parallelOp->isAncestor(capturedOp) &&
hostEvalArg == parallelOp.getNumThreads())
continue;
Expand All @@ -2231,8 +2232,7 @@ LogicalResult TargetOp::verifyRegions() {
"'omp.parallel' when representing target SPMD";
}
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
loopNestOp.getOperation() == capturedOp &&
if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
Expand Down Expand Up @@ -2362,7 +2362,9 @@ static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
ompFlags.getAssumeThreadsOversubscription();
}

TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
bool *hostEvalTripCount) {
// TODO: Support detection of bare kernel mode.
// A non-null captured op is only valid if it resides inside of a TargetOp
// and is the result of calling getInnermostCapturedOmpOp() on it.
TargetOp targetOp =
Expand All @@ -2371,9 +2373,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
"unexpected captured op");

if (hostEvalTripCount)
*hostEvalTripCount = false;

// If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
return TargetRegionFlags::generic;
return TargetExecMode::generic;

// Get the innermost non-simd loop wrapper.
SmallVector<LoopWrapperInterface> loopWrappers;
Expand All @@ -2386,85 +2391,63 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {

auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
if (numWrappers != 1 && numWrappers != 2)
return TargetRegionFlags::generic;
return TargetExecMode::generic;

// Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
if (!wsloopOp)
return TargetRegionFlags::generic;
return TargetExecMode::generic;

innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
return TargetRegionFlags::generic;
return TargetExecMode::generic;

Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::generic;
return TargetExecMode::generic;

TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
if (!teamsOp)
return TargetRegionFlags::generic;
return TargetExecMode::generic;

if (teamsOp->getParentOp() == targetOp.getOperation()) {
TargetRegionFlags result =
TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
TargetExecMode result = TargetExecMode::spmd;
if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
result = result | TargetRegionFlags::no_loop;
result = TargetExecMode::no_loop;
if (hostEvalTripCount)
*hostEvalTripCount = true;
return result;
}
}
// Detect target-teams-distribute[-simd] and target-teams-loop.
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return TargetRegionFlags::generic;
return TargetExecMode::generic;

if (teamsOp->getParentOp() != targetOp.getOperation())
return TargetRegionFlags::generic;
return TargetExecMode::generic;

if (isa<LoopOp>(innermostWrapper))
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

// Find single immediately nested captured omp.parallel and add spmd flag
// (generic-spmd case).
//
// TODO: This shouldn't have to be done here, as it is too easy to break.
// The openmp-opt pass should be updated to be able to promote kernels like
// this from "Generic" to "Generic-SPMD". However, the use of the
// `kmpc_distribute_static_loop` family of functions produced by the
// OMPIRBuilder for these kernels prevents that from working.
Dialect *ompDialect = targetOp->getDialect();
Operation *nestedCapture = findCapturedOmpOp(
capturedOp, /*checkSingleMandatoryExec=*/false,
[&](Operation *sibling) {
return sibling && (ompDialect != sibling->getDialect() ||
sibling->hasTrait<OpTrait::IsTerminator>());
});

TargetRegionFlags result =
TargetRegionFlags::generic | TargetRegionFlags::trip_count;

if (!nestedCapture)
return result;
if (hostEvalTripCount)
*hostEvalTripCount = true;

while (nestedCapture->getParentOp() != capturedOp)
nestedCapture = nestedCapture->getParentOp();
if (isa<LoopOp>(innermostWrapper))
return TargetExecMode::spmd;

return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
: result;
return TargetExecMode::generic;
}
// Detect target-parallel-wsloop[-simd].
else if (isa<WsloopOp>(innermostWrapper)) {
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::generic;
return TargetExecMode::generic;

if (parallelOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd;
return TargetExecMode::spmd;
}

return TargetRegionFlags::generic;
return TargetExecMode::generic;
}

//===----------------------------------------------------------------------===//
Expand Down
44 changes: 19 additions & 25 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2601,13 +2601,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
// for every omp.wsloop nested inside a no-loop SPMD target region, even if
// that loop is not the top-level SPMD one.
if (loopOp == targetCapturedOp) {
omp::TargetRegionFlags kernelFlags =
targetOp.getKernelExecFlags(targetCapturedOp);
if (omp::bitEnumContainsAll(kernelFlags,
omp::TargetRegionFlags::spmd |
omp::TargetRegionFlags::no_loop) &&
!omp::bitEnumContainsAny(kernelFlags,
omp::TargetRegionFlags::generic))
if (targetOp.getKernelExecFlags(targetCapturedOp) ==
omp::TargetExecMode::no_loop)
noLoopMode = true;
}
}
Expand Down Expand Up @@ -5437,23 +5432,21 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
}

// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
assert(
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
omp::TargetRegionFlags::spmd) &&
"invalid kernel flags");
attrs.ExecFlags =
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
if (omp::bitEnumContainsAll(kernelFlags,
omp::TargetRegionFlags::spmd |
omp::TargetRegionFlags::no_loop) &&
!omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
switch (execMode) {
case omp::TargetExecMode::bare:
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
break;
case omp::TargetExecMode::generic:
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
break;
case omp::TargetExecMode::spmd:
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
break;
case omp::TargetExecMode::no_loop:
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;

break;
}
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
Expand Down Expand Up @@ -5503,8 +5496,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
omp::TargetRegionFlags::trip_count)) {
bool hostEvalTripCount;
targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
if (hostEvalTripCount) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
}
}

// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]]
// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
Expand Down