diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 5980ee35a5cd2..fb501983ca938 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6775,7 +6775,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit( Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags); Constant *UseGenericStateMachineVal = ConstantInt::getSigned( - Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD); + Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD && + Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP); Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true); Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index f693a0737e0fc..95f4e8b0d5673 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -223,31 +223,21 @@ def ScheduleModifier : OpenMP_I32EnumAttr< def ScheduleModifierAttr : OpenMP_EnumAttr; //===----------------------------------------------------------------------===// -// target_region_flags enum. +// target_exec_mode enum. //===----------------------------------------------------------------------===// -def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">; -def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>; -def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>; -def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>; -def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 3>; - -def TargetRegionFlags : OpenMP_BitEnumAttr< - "TargetRegionFlags", - "These flags describe properties of the target kernel. " - "TargetRegionFlagsGeneric - denotes generic kernel. " - "TargetRegionFlagsSpmd - denotes SPMD kernel. " - "TargetRegionFlagsNoLoop - denotes kernel where " - "num_teams * num_threads >= loop_trip_count. It allows the conversion " - "of loops into sequential code by ensuring that each team/thread " - "executes at most one iteration. " - "TargetRegionFlagsTripCount - checks if the loop trip count should be " - "calculated.", [ - TargetRegionFlagsNone, - TargetRegionFlagsGeneric, - TargetRegionFlagsSpmd, - TargetRegionFlagsTripCount, - TargetRegionFlagsNoLoop +def TargetExecModeBare : I32EnumAttrCase<"bare", 0>; +def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>; +def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>; +def TargetExecModeSpmdNoLoop : I32EnumAttrCase<"no_loop", 3>; + +def TargetExecMode : OpenMP_I32EnumAttr< + "TargetExecMode", + "target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [ + TargetExecModeBare, + TargetExecModeGeneric, + TargetExecModeSpmd, + TargetExecModeSpmdNoLoop, ]>; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 5c77e215467e4..9003fb2ef7959 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1522,13 +1522,17 @@ def TargetOp : OpenMP_Op<"target", traits = [ /// operations, the top level one will be the one captured. Operation *getInnermostCapturedOmpOp(); - /// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the - /// contents of the target region. + /// Infers the kernel type (Bare, Generic or SPMD) based on the contents of + /// the target region. /// /// \param capturedOp result of a still valid (no modifications made to any /// nested operations) previous call to `getInnermostCapturedOmpOp()`. - static ::mlir::omp::TargetRegionFlags - getKernelExecFlags(Operation *capturedOp); + /// \param hostEvalTripCount output argument to store whether this kernel + /// wraps a loop whose bounds must be evaluated on the host prior to + /// launching it. + static ::mlir::omp::TargetExecMode + getKernelExecFlags(Operation *capturedOp, + bool *hostEvalTripCount = nullptr); }] # clausesExtraClassDeclaration; let assemblyFormat = clausesAssemblyFormat # [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 32ebe06e240db..8640c4ba0b757 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2205,8 +2205,9 @@ LogicalResult TargetOp::verifyRegions() { return emitError("target containing multiple 'omp.teams' nested ops"); // Check that host_eval values are only used in legal ways. + bool hostEvalTripCount; Operation *capturedOp = getInnermostCapturedOmpOp(); - TargetRegionFlags execFlags = getKernelExecFlags(capturedOp); + TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount); for (Value hostEvalArg : cast(getOperation()).getHostEvalBlockArgs()) { for (Operation *user : hostEvalArg.getUsers()) { @@ -2221,7 +2222,7 @@ LogicalResult TargetOp::verifyRegions() { "and 'thread_limit' in 'omp.teams'"; } if (auto parallelOp = dyn_cast(user)) { - if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && + if (execMode == TargetExecMode::spmd && parallelOp->isAncestor(capturedOp) && hostEvalArg == parallelOp.getNumThreads()) continue; @@ -2231,8 +2232,7 @@ LogicalResult TargetOp::verifyRegions() { "'omp.parallel' when representing target SPMD"; } if (auto loopNestOp = dyn_cast(user)) { - if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) && - loopNestOp.getOperation() == capturedOp && + if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp && (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) @@ -2362,7 +2362,9 @@ static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, ompFlags.getAssumeThreadsOversubscription(); } -TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { +TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp, + bool *hostEvalTripCount) { + // TODO: Support detection of bare kernel mode. // A non-null captured op is only valid if it resides inside of a TargetOp // and is the result of calling getInnermostCapturedOmpOp() on it. TargetOp targetOp = @@ -2371,9 +2373,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) && "unexpected captured op"); + if (hostEvalTripCount) + *hostEvalTripCount = false; + // If it's not capturing a loop, it's a default target region. if (!isa_and_present(capturedOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; // Get the innermost non-simd loop wrapper. SmallVector loopWrappers; @@ -2386,31 +2391,32 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { auto numWrappers = std::distance(innermostWrapper, loopWrappers.end()); if (numWrappers != 1 && numWrappers != 2) - return TargetRegionFlags::generic; + return TargetExecMode::generic; // Detect target-teams-distribute-parallel-wsloop[-simd]. if (numWrappers == 2) { WsloopOp *wsloopOp = dyn_cast(innermostWrapper); if (!wsloopOp) - return TargetRegionFlags::generic; + return TargetExecMode::generic; innermostWrapper = std::next(innermostWrapper); if (!isa(innermostWrapper)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; TeamsOp teamsOp = dyn_cast(parallelOp->getParentOp()); if (!teamsOp) - return TargetRegionFlags::generic; + return TargetExecMode::generic; if (teamsOp->getParentOp() == targetOp.getOperation()) { - TargetRegionFlags result = - TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + TargetExecMode result = TargetExecMode::spmd; if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp)) - result = result | TargetRegionFlags::no_loop; + result = TargetExecMode::no_loop; + if (hostEvalTripCount) + *hostEvalTripCount = true; return result; } } @@ -2418,53 +2424,30 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { else if (isa(innermostWrapper)) { Operation *teamsOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; if (teamsOp->getParentOp() != targetOp.getOperation()) - return TargetRegionFlags::generic; + return TargetExecMode::generic; - if (isa(innermostWrapper)) - return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; - - // Find single immediately nested captured omp.parallel and add spmd flag - // (generic-spmd case). - // - // TODO: This shouldn't have to be done here, as it is too easy to break. - // The openmp-opt pass should be updated to be able to promote kernels like - // this from "Generic" to "Generic-SPMD". However, the use of the - // `kmpc_distribute_static_loop` family of functions produced by the - // OMPIRBuilder for these kernels prevents that from working. - Dialect *ompDialect = targetOp->getDialect(); - Operation *nestedCapture = findCapturedOmpOp( - capturedOp, /*checkSingleMandatoryExec=*/false, - [&](Operation *sibling) { - return sibling && (ompDialect != sibling->getDialect() || - sibling->hasTrait()); - }); - - TargetRegionFlags result = - TargetRegionFlags::generic | TargetRegionFlags::trip_count; - - if (!nestedCapture) - return result; + if (hostEvalTripCount) + *hostEvalTripCount = true; - while (nestedCapture->getParentOp() != capturedOp) - nestedCapture = nestedCapture->getParentOp(); + if (isa(innermostWrapper)) + return TargetExecMode::spmd; - return isa(nestedCapture) ? result | TargetRegionFlags::spmd - : result; + return TargetExecMode::generic; } // Detect target-parallel-wsloop[-simd]. else if (isa(innermostWrapper)) { Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::generic; + return TargetExecMode::generic; if (parallelOp->getParentOp() == targetOp.getOperation()) - return TargetRegionFlags::spmd; + return TargetExecMode::spmd; } - return TargetRegionFlags::generic; + return TargetExecMode::generic; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 53209a40665ae..172029196905d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2601,13 +2601,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // for every omp.wsloop nested inside a no-loop SPMD target region, even if // that loop is not the top-level SPMD one. if (loopOp == targetCapturedOp) { - omp::TargetRegionFlags kernelFlags = - targetOp.getKernelExecFlags(targetCapturedOp); - if (omp::bitEnumContainsAll(kernelFlags, - omp::TargetRegionFlags::spmd | - omp::TargetRegionFlags::no_loop) && - !omp::bitEnumContainsAny(kernelFlags, - omp::TargetRegionFlags::generic)) + if (targetOp.getKernelExecFlags(targetCapturedOp) == + omp::TargetExecMode::no_loop) noLoopMode = true; } } @@ -5437,23 +5432,21 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, } // Update kernel bounds structure for the `OpenMPIRBuilder` to use. - omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp); - assert( - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic | - omp::TargetRegionFlags::spmd) && - "invalid kernel flags"); - attrs.ExecFlags = - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic) - ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) - ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD - : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC - : llvm::omp::OMP_TGT_EXEC_MODE_SPMD; - if (omp::bitEnumContainsAll(kernelFlags, - omp::TargetRegionFlags::spmd | - omp::TargetRegionFlags::no_loop) && - !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)) + omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp); + switch (execMode) { + case omp::TargetExecMode::bare: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE; + break; + case omp::TargetExecMode::generic: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC; + break; + case omp::TargetExecMode::spmd: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD; + break; + case omp::TargetExecMode::no_loop: attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP; - + break; + } attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; @@ -5503,8 +5496,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, if (numThreads) attrs.MaxThreads = moduleTranslation.lookupValue(numThreads); - if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp), - omp::TargetRegionFlags::trip_count)) { + bool hostEvalTripCount; + targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount); + if (hostEvalTripCount) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); attrs.LoopTripCount = nullptr; diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir index 504d91b1f6198..6084a33fac8aa 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir @@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo } } -// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]] +// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]] // DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata" // DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy { // DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},