From 4033b128d6388c8c3badee3def88478a48e356a2 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Tue, 25 Mar 2025 10:40:47 +0000 Subject: [PATCH] [MLIR][OpenMP] Fix standalone distribute on the device This patch updates the handling of target regions to set trip counts and kernel execution modes properly, based on clang's behavior. This fixes a race condition on `target teams distribute` constructs with no `parallel do` loop inside. This is how kernels are classified after the changes introduced in this patch: ```f90 ! Exec mode: SPMD. ! Trip count: Set. !$omp target teams distribute parallel do do i=... end do ! Exec mode: Generic-SPMD. ! Trip count: Set (outer loop). !$omp target teams distribute do i=... !$omp parallel do private(idx, y) do j=... end do end do ! Exec mode: Generic-SPMD. ! Trip count: Set (outer loop). !$omp target teams distribute do i=... !$omp parallel ... !$omp end parallel end do ! Exec mode: Generic. ! Trip count: Set. !$omp target teams distribute do i=... end do ! Exec mode: SPMD. ! Trip count: Not set. !$omp target parallel do do i=... end do ! Exec mode: Generic. ! Trip count: Not set. !$omp target ... !$omp end target ``` For the split `target teams distribute + parallel do` case, clang produces a Generic kernel, which gets promoted to Generic-SPMD by the openmp-opt pass. We can't currently replicate that behavior in flang because our codegen for these constructs introduces calls to the `kmpc_distribute_static_loop` family of functions instead of `kmpc_distribute_static_init`, and those calls currently prevent promotion of the kernel to Generic-SPMD. For the time being, instead of relying on the openmp-opt pass, we look at the MLIR representation to find the Generic-SPMD pattern and directly tag the kernel as such during codegen. This is what we were already doing, but we were incorrectly matching other kinds of kernels as Generic-SPMD in the process. 
--- .../mlir/Dialect/OpenMP/OpenMPEnums.td | 18 ++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 2 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 204 +++++++++++------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 +- mlir/test/Dialect/OpenMP/invalid.mlir | 4 +- mlir/test/Dialect/OpenMP/ops.mlir | 17 ++ .../LLVMIR/openmp-target-generic-spmd.mlir | 111 ++++++++++ 7 files changed, 285 insertions(+), 87 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index 690e3df1f685e..9dbe6897a3304 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -222,6 +222,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr< def ScheduleModifierAttr : OpenMP_EnumAttr; +//===----------------------------------------------------------------------===// +// target_region_flags enum. +//===----------------------------------------------------------------------===// + +def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">; +def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>; +def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>; +def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>; + +def TargetRegionFlags : OpenMP_BitEnumAttr< + "TargetRegionFlags", + "target region property flags", [ + TargetRegionFlagsNone, + TargetRegionFlagsGeneric, + TargetRegionFlagsSpmd, + TargetRegionFlagsTripCount + ]>; + //===----------------------------------------------------------------------===// // variable_capture_kind enum. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 65095932be627..11530c0fa3620 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [ /// /// \param capturedOp result of a still valid (no modifications made to any /// nested operations) previous call to `getInnermostCapturedOmpOp()`. - static llvm::omp::OMPTgtExecModeFlags + static ::mlir::omp::TargetRegionFlags getKernelExecFlags(Operation *capturedOp); }] # clausesExtraClassDeclaration; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 4ac9f49f12161..ecadf16e1e9f6 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() { return emitError("target containing multiple 'omp.teams' nested ops"); // Check that host_eval values are only used in legal ways. 
- llvm::omp::OMPTgtExecModeFlags execFlags = - getKernelExecFlags(getInnermostCapturedOmpOp()); + Operation *capturedOp = getInnermostCapturedOmpOp(); + TargetRegionFlags execFlags = getKernelExecFlags(capturedOp); for (Value hostEvalArg : cast(getOperation()).getHostEvalBlockArgs()) { for (Operation *user : hostEvalArg.getUsers()) { @@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() { "and 'thread_limit' in 'omp.teams'"; } if (auto parallelOp = dyn_cast(user)) { - if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD && + if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && + parallelOp->isAncestor(capturedOp) && hostEvalArg == parallelOp.getNumThreads()) continue; @@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() { "'omp.parallel' when representing target SPMD"; } if (auto loopNestOp = dyn_cast(user)) { - if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC && + if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) && + loopNestOp.getOperation() == capturedOp && (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) continue; return emitOpError() << "host_eval argument only legal as loop bounds " - "and steps in 'omp.loop_nest' when " - "representing target SPMD or Generic-SPMD"; + "and steps in 'omp.loop_nest' when trip count " + "must be evaluated in the host"; } return emitOpError() << "host_eval argument illegal use in '" @@ -1951,33 +1953,12 @@ LogicalResult TargetOp::verifyRegions() { return success(); } -/// Only allow OpenMP terminators and non-OpenMP ops that have known memory -/// effects, but don't include a memory write effect. 
-static bool siblingAllowedInCapture(Operation *op) { - if (!op) - return false; +static Operation * +findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, + llvm::function_ref siblingAllowedFn) { + assert(rootOp && "expected valid operation"); - bool isOmpDialect = - op->getContext()->getLoadedDialect() == - op->getDialect(); - - if (isOmpDialect) - return op->hasTrait(); - - if (auto memOp = dyn_cast(op)) { - SmallVector, 4> effects; - memOp.getEffects(effects); - return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) { - return isa(effect.getEffect()) && - isa( - effect.getResource()); - }); - } - return true; -} - -Operation *TargetOp::getInnermostCapturedOmpOp() { - Dialect *ompDialect = (*this)->getDialect(); + Dialect *ompDialect = rootOp->getDialect(); Operation *capturedOp = nullptr; DominanceInfo domInfo; @@ -1985,8 +1966,8 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { // ensuring we only enter the region of an operation if it meets the criteria // for being captured. We stop the exploration of nested operations as soon as // we process a region holding no operations to be captured. - walk([&](Operation *op) { - if (op == *this) + rootOp->walk([&](Operation *op) { + if (op == rootOp) return WalkResult::advance(); // Ignore operations of other dialects or omp operations with no regions, @@ -2001,22 +1982,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { // (i.e. its block's successors can reach it) or if it's not guaranteed to // be executed before all exits of the region (i.e. it doesn't dominate all // blocks with no successors reachable from the entry block). 
- Region *parentRegion = op->getParentRegion(); - Block *parentBlock = op->getBlock(); - - for (Block *successor : parentBlock->getSuccessors()) - if (successor->isReachable(parentBlock)) - return WalkResult::interrupt(); - - for (Block &block : *parentRegion) - if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() && - !domInfo.dominates(parentBlock, &block)) - return WalkResult::interrupt(); + if (checkSingleMandatoryExec) { + Region *parentRegion = op->getParentRegion(); + Block *parentBlock = op->getBlock(); + + for (Block *successor : parentBlock->getSuccessors()) + if (successor->isReachable(parentBlock)) + return WalkResult::interrupt(); + + for (Block &block : *parentRegion) + if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() && + !domInfo.dominates(parentBlock, &block)) + return WalkResult::interrupt(); + } // Don't capture this op if it has a not-allowed sibling, and stop recursing // into nested operations. for (Operation &sibling : op->getParentRegion()->getOps()) - if (&sibling != op && !siblingAllowedInCapture(&sibling)) + if (&sibling != op && !siblingAllowedFn(&sibling)) return WalkResult::interrupt(); // Don't continue capturing nested operations if we reach an omp.loop_nest. @@ -2029,10 +2012,35 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { return capturedOp; } -llvm::omp::OMPTgtExecModeFlags -TargetOp::getKernelExecFlags(Operation *capturedOp) { - using namespace llvm::omp; +Operation *TargetOp::getInnermostCapturedOmpOp() { + auto *ompDialect = getContext()->getLoadedDialect(); + + // Only allow OpenMP terminators and non-OpenMP ops that have known memory + // effects, but don't include a memory write effect. 
+ return findCapturedOmpOp( + *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) { + if (!sibling) + return false; + + if (ompDialect == sibling->getDialect()) + return sibling->hasTrait(); + + if (auto memOp = dyn_cast(sibling)) { + SmallVector, 4> + effects; + memOp.getEffects(effects); + return !llvm::any_of( + effects, [&](MemoryEffects::EffectInstance &effect) { + return isa(effect.getEffect()) && + isa( + effect.getResource()); + }); + } + return true; + }); +} +TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { // A non-null captured op is only valid if it resides inside of a TargetOp // and is the result of calling getInnermostCapturedOmpOp() on it. TargetOp targetOp = @@ -2041,60 +2049,94 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) { (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) && "unexpected captured op"); - // Make sure this region is capturing a loop. Otherwise, it's a generic - // kernel. + // If it's not capturing a loop, it's a default target region. if (!isa_and_present(capturedOp)) - return OMP_TGT_EXEC_MODE_GENERIC; + return TargetRegionFlags::generic; - SmallVector wrappers; - cast(capturedOp).gatherWrappers(wrappers); - assert(!wrappers.empty()); + // Get the innermost non-simd loop wrapper. + SmallVector loopWrappers; + cast(capturedOp).gatherWrappers(loopWrappers); + assert(!loopWrappers.empty()); - // Ignore optional SIMD leaf construct. - auto *innermostWrapper = wrappers.begin(); + LoopWrapperInterface *innermostWrapper = loopWrappers.begin(); if (isa(innermostWrapper)) innermostWrapper = std::next(innermostWrapper); - long numWrappers = std::distance(innermostWrapper, wrappers.end()); - - // Detect Generic-SPMD: target-teams-distribute[-simd]. - // Detect SPMD: target-teams-loop. 
- if (numWrappers == 1) { - if (!isa(innermostWrapper)) - return OMP_TGT_EXEC_MODE_GENERIC; - - Operation *teamsOp = (*innermostWrapper)->getParentOp(); - if (!isa_and_present(teamsOp)) - return OMP_TGT_EXEC_MODE_GENERIC; + auto numWrappers = std::distance(innermostWrapper, loopWrappers.end()); + if (numWrappers != 1 && numWrappers != 2) + return TargetRegionFlags::generic; - if (teamsOp->getParentOp() == targetOp.getOperation()) - return isa(innermostWrapper) - ? OMP_TGT_EXEC_MODE_GENERIC_SPMD - : OMP_TGT_EXEC_MODE_SPMD; - } - - // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd]. + // Detect target-teams-distribute-parallel-wsloop[-simd]. if (numWrappers == 2) { if (!isa(innermostWrapper)) - return OMP_TGT_EXEC_MODE_GENERIC; + return TargetRegionFlags::generic; innermostWrapper = std::next(innermostWrapper); if (!isa(innermostWrapper)) - return OMP_TGT_EXEC_MODE_GENERIC; + return TargetRegionFlags::generic; Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return OMP_TGT_EXEC_MODE_GENERIC; + return TargetRegionFlags::generic; Operation *teamsOp = parallelOp->getParentOp(); if (!isa_and_present(teamsOp)) - return OMP_TGT_EXEC_MODE_GENERIC; + return TargetRegionFlags::generic; if (teamsOp->getParentOp() == targetOp.getOperation()) - return OMP_TGT_EXEC_MODE_SPMD; + return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + } + // Detect target-teams-distribute[-simd] and target-teams-loop. + else if (isa(innermostWrapper)) { + Operation *teamsOp = (*innermostWrapper)->getParentOp(); + if (!isa_and_present(teamsOp)) + return TargetRegionFlags::generic; + + if (teamsOp->getParentOp() != targetOp.getOperation()) + return TargetRegionFlags::generic; + + if (isa(innermostWrapper)) + return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + + // Find single immediately nested captured omp.parallel and add spmd flag + // (generic-spmd case). 
+ // + // TODO: This shouldn't have to be done here, as it is too easy to break. + // The openmp-opt pass should be updated to be able to promote kernels like + // this from "Generic" to "Generic-SPMD". However, the use of the + // `kmpc_distribute_static_loop` family of functions produced by the + // OMPIRBuilder for these kernels prevents that from working. + Dialect *ompDialect = targetOp->getDialect(); + Operation *nestedCapture = findCapturedOmpOp( + capturedOp, /*checkSingleMandatoryExec=*/false, + [&](Operation *sibling) { + return sibling && (ompDialect != sibling->getDialect() || + sibling->hasTrait()); + }); + + TargetRegionFlags result = + TargetRegionFlags::generic | TargetRegionFlags::trip_count; + + if (!nestedCapture) + return result; + + while (nestedCapture->getParentOp() != capturedOp) + nestedCapture = nestedCapture->getParentOp(); + + return isa(nestedCapture) ? result | TargetRegionFlags::spmd + : result; + } + // Detect target-parallel-wsloop[-simd]. + else if (isa(innermostWrapper)) { + Operation *parallelOp = (*innermostWrapper)->getParentOp(); + if (!isa_and_present(parallelOp)) + return TargetRegionFlags::generic; + + if (parallelOp->getParentOp() == targetOp.getOperation()) + return TargetRegionFlags::spmd; } - return OMP_TGT_EXEC_MODE_GENERIC; + return TargetRegionFlags::generic; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index d41489921bd13..4d610d6e2656d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, combinedMaxThreadsVal = maxThreadsVal; // Update kernel bounds structure for the `OpenMPIRBuilder` to use. 
- attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp); + omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp); + assert( + omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic | + omp::TargetRegionFlags::spmd) && + "invalid kernel flags"); + attrs.ExecFlags = + omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic) + ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) + ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD + : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC + : llvm::omp::OMP_TGT_EXEC_MODE_SPMD; attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; @@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, if (numThreads) attrs.MaxThreads = moduleTranslation.lookupValue(numThreads); - if (targetOp.getKernelExecFlags(capturedOp) != - llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) { + if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp), + omp::TargetRegionFlags::trip_count)) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); attrs.LoopTripCount = nullptr; diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 403128bb2300e..bd0541987339a 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) { // ----- func.func @omp_target_host_eval_loop1(%x : i32) { - // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}} + // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}} omp.target host_eval(%x -> %arg0 : i32) { omp.wsloop { omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) { @@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) { // ----- func.func 
@omp_target_host_eval_loop2(%x : i32) { - // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}} + // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}} omp.target host_eval(%x -> %arg0 : i32) { omp.teams { ^bb0: diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 0a10626cd4877..6bc2500471997 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) { omp.terminator } + // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) { + // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) { + // CHECK: omp.wsloop { + // CHECK: omp.loop_nest + omp.target host_eval(%x -> %arg0 : i32) { + %y = arith.constant 2 : i32 + omp.parallel num_threads(%arg0 : i32) { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) { + omp.yield + } + } + omp.terminator + } + omp.terminator + } + // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) { // CHECK: omp.teams { // CHECK: omp.distribute { diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir new file mode 100644 index 0000000000000..8101660e571e4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir @@ -0,0 +1,111 @@ +// RUN: split-file %s %t +// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST +// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE + +//--- host.mlir + +module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @main(%arg0 : !llvm.ptr) { + %x = llvm.load %arg0 : !llvm.ptr -> i32 + %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) 
capture(ByCopy) -> !llvm.ptr + omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) { + %x.map = llvm.load %ptr : !llvm.ptr -> i32 + omp.teams { + omp.distribute { + omp.loop_nest (%iv1) : i32 = (%lb) to (%ub) step (%step) { + omp.parallel { + omp.wsloop { + omp.loop_nest (%iv2) : i32 = (%x.map) to (%x.map) step (%x.map) { + omp.yield + } + } + omp.terminator + } + omp.yield + } + } + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// HOST-LABEL: define void @main +// HOST: %omp_loop.tripcount = {{.*}} +// HOST-NEXT: br label %[[ENTRY:.*]] +// HOST: [[ENTRY]]: +// HOST: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64 +// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8 +// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]] +// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]]) +// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0 +// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}} +// HOST: [[OFFLOAD_FAILED]]: +// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}}) + +// HOST: define internal void @[[TARGET_OUTLINE]] +// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}}) + +// HOST: define internal void @[[TEAMS_OUTLINE]] +// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}}) + +// HOST: define internal void @[[DISTRIBUTE_OUTLINE]] +// HOST: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}}) +// HOST: call void (ptr, i32, ptr, ...) 
@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}}) + +// HOST: define internal void @[[PARALLEL_OUTLINE]] +// HOST: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}}) + +//--- device.mlir + +module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} { + llvm.func @main(%arg0 : !llvm.ptr) { + %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr + omp.target map_entries(%0 -> %ptr : !llvm.ptr) { + %x = llvm.load %ptr : !llvm.ptr -> i32 + omp.teams { + omp.distribute { + omp.loop_nest (%iv1) : i32 = (%x) to (%x) step (%x) { + omp.parallel { + omp.wsloop { + omp.loop_nest (%iv2) : i32 = (%x) to (%x) step (%x) { + omp.yield + } + } + omp.terminator + } + omp.yield + } + } + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]] +// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata" +// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy { +// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}}, +// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} } + +// DEVICE: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}}) +// DEVICE: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}}) +// DEVICE: call void @[[TARGET_OUTLINE:.*]]({{.*}}) +// DEVICE: call void @__kmpc_target_deinit() + +// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}}) +// DEVICE: call void @[[TEAMS_OUTLINE:.*]]({{.*}}) + +// DEVICE: define internal void @[[TEAMS_OUTLINE]]({{.*}}) +// DEVICE: call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}}) + +// DEVICE: define internal void 
@[[DISTRIBUTE_OUTLINE]]({{.*}}) +// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}) + +// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}}) +// DEVICE: call void @__kmpc_for_static_loop{{.*}}({{.*}})