3939#include " flang/Support/OpenMP-utils.h"
4040#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4141#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
42+ #include " mlir/Support/StateStack.h"
4243#include " mlir/Transforms/RegionUtils.h"
4344#include " llvm/ADT/STLExtras.h"
4445#include " llvm/Frontend/OpenMP/OMPConstants.h"
@@ -198,9 +199,41 @@ class HostEvalInfo {
198199// / the handling of the outer region by keeping a stack of information
199200// / structures, but it will probably still require some further work to support
200201// / reverse offloading.
201- static llvm::SmallVector<HostEvalInfo, 0 > hostEvalInfo;
202- static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0 >
203- sectionsStack;
// Frame type for the converter's state stack that carries one HostEvalInfo.
// In the visible hunks, genTargetOp pushes a frame of this type when it is not
// compiling for the target device and pops it again before returning.
202+ class HostEvalInfoStackFrame
203+ : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
204+ public:
205+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (HostEvalInfoStackFrame)
206+
207+ HostEvalInfo info; // host_eval information structure held by this frame
208+ };
209+
// Returns a pointer to the HostEvalInfo stored in the frame obtained from
// `converter.getStateStack().getStackTop<HostEvalInfoStackFrame>()`, or nullptr
// when that lookup yields no frame. Callers in this file either assert on the
// result or treat nullptr as "no host_eval information available".
210+ static HostEvalInfo *
211+ getHostEvalInfoStackTop (lower::AbstractConverter &converter) {
212+ HostEvalInfoStackFrame *frame =
213+ converter.getStateStack ().getStackTop <HostEvalInfoStackFrame>();
214+ return frame ? &frame->info : nullptr ; // null frame -> null result
215+ }
216+
217+ // / Stack frame for storing the OpenMPSectionsConstruct currently being
218+ // / processed so that it can be referred to when lowering the construct.
219+ class SectionsConstructStackFrame
220+ : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
221+ public:
222+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (SectionsConstructStackFrame)
223+
// Binds the frame to the sections construct passed in. Only a reference is
// stored, so the referenced construct must outlive the frame.
224+ explicit SectionsConstructStackFrame (
225+ const parser::OpenMPSectionsConstruct &sectionsConstruct)
226+ : sectionsConstruct{sectionsConstruct} {}
227+
228+ const parser::OpenMPSectionsConstruct &sectionsConstruct; // non-owning reference
229+ };
230+
// Returns the address of the OpenMPSectionsConstruct referenced by the frame
// obtained from
// `converter.getStateStack().getStackTop<SectionsConstructStackFrame>()`, or
// nullptr when that lookup yields no frame. genSectionsOp asserts that the
// result is non-null before using it.
231+ static const parser::OpenMPSectionsConstruct *
232+ getSectionsConstructStackTop (lower::AbstractConverter &converter) {
233+ SectionsConstructStackFrame *frame =
234+ converter.getStateStack ().getStackTop <SectionsConstructStackFrame>();
235+ return frame ? &frame->sectionsConstruct : nullptr ; // null frame -> null result
236+ }
204237
205238// / Bind symbols to their corresponding entry block arguments.
206239// /
@@ -535,54 +568,55 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
535568 if (!ompEval)
536569 return ;
537570
538- HostEvalInfo &hostInfo = hostEvalInfo.back ();
571+ HostEvalInfo *hostInfo = getHostEvalInfoStackTop (converter);
572+ assert (hostInfo && " expected HOST_EVAL info structure" );
539573
540574 switch (extractOmpDirective (*ompEval)) {
541575 case OMPD_teams_distribute_parallel_do:
542576 case OMPD_teams_distribute_parallel_do_simd:
543- cp.processThreadLimit (stmtCtx, hostInfo. ops );
577+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
544578 [[fallthrough]];
545579 case OMPD_target_teams_distribute_parallel_do:
546580 case OMPD_target_teams_distribute_parallel_do_simd:
547- cp.processNumTeams (stmtCtx, hostInfo. ops );
581+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
548582 [[fallthrough]];
549583 case OMPD_distribute_parallel_do:
550584 case OMPD_distribute_parallel_do_simd:
551- cp.processNumThreads (stmtCtx, hostInfo. ops );
585+ cp.processNumThreads (stmtCtx, hostInfo-> ops );
552586 [[fallthrough]];
553587 case OMPD_distribute:
554588 case OMPD_distribute_simd:
555- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
589+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
556590 break ;
557591
558592 case OMPD_teams:
559- cp.processThreadLimit (stmtCtx, hostInfo. ops );
593+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
560594 [[fallthrough]];
561595 case OMPD_target_teams:
562- cp.processNumTeams (stmtCtx, hostInfo. ops );
596+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
563597 processSingleNestedIf ([](Directive nestedDir) {
564598 return topDistributeSet.test (nestedDir) || topLoopSet.test (nestedDir);
565599 });
566600 break ;
567601
568602 case OMPD_teams_distribute:
569603 case OMPD_teams_distribute_simd:
570- cp.processThreadLimit (stmtCtx, hostInfo. ops );
604+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
571605 [[fallthrough]];
572606 case OMPD_target_teams_distribute:
573607 case OMPD_target_teams_distribute_simd:
574- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
575- cp.processNumTeams (stmtCtx, hostInfo. ops );
608+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
609+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
576610 break ;
577611
578612 case OMPD_teams_loop:
579- cp.processThreadLimit (stmtCtx, hostInfo. ops );
613+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
580614 [[fallthrough]];
581615 case OMPD_target_teams_loop:
582- cp.processNumTeams (stmtCtx, hostInfo. ops );
616+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
583617 [[fallthrough]];
584618 case OMPD_loop:
585- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
619+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
586620 break ;
587621
588622 // Standalone 'target' case.
@@ -596,8 +630,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
596630 }
597631 };
598632
599- assert (!hostEvalInfo.empty () && " expected HOST_EVAL info structure" );
600-
601633 const auto *ompEval = eval.getIf <parser::OpenMPConstruct>();
602634 assert (ompEval &&
603635 llvm::omp::allTargetSet.test (extractOmpDirective (*ompEval)) &&
@@ -1456,8 +1488,8 @@ static void genBodyOfTargetOp(
14561488 mlir::Region &region = targetOp.getRegion ();
14571489 mlir::Block *entryBlock = genEntryBlock (firOpBuilder, args, region);
14581490 bindEntryBlockArgs (converter, targetOp, args);
1459- if (! hostEvalInfo. empty ( ))
1460- hostEvalInfo. back (). bindOperands (argIface.getHostEvalBlockArgs ());
1491+ if (HostEvalInfo * hostEvalInfo = getHostEvalInfoStackTop (converter ))
1492+ hostEvalInfo-> bindOperands (argIface.getHostEvalBlockArgs ());
14611493
14621494 // Check if cloning the bounds introduced any dependency on the outer region.
14631495 // If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1696,7 +1728,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16961728 llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
16971729 ClauseProcessor cp (converter, semaCtx, clauses);
16981730
1699- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps, iv))
1731+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1732+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps, iv))
17001733 cp.processCollapse (loc, eval, clauseOps, iv);
17011734
17021735 clauseOps.loopInclusive = converter.getFirOpBuilder ().getUnitAttr ();
@@ -1741,7 +1774,8 @@ static void genParallelClauses(
17411774 cp.processAllocate (clauseOps);
17421775 cp.processIf (llvm::omp::Directive::OMPD_parallel, clauseOps);
17431776
1744- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps))
1777+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1778+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps))
17451779 cp.processNumThreads (stmtCtx, clauseOps);
17461780
17471781 cp.processProcBind (clauseOps);
@@ -1806,16 +1840,17 @@ static void genTargetClauses(
18061840 llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
18071841 llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
18081842 llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
1843+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
18091844 ClauseProcessor cp (converter, semaCtx, clauses);
18101845 cp.processBare (clauseOps);
18111846 cp.processDefaultMap (stmtCtx, defaultMaps);
18121847 cp.processDepend (symTable, stmtCtx, clauseOps);
18131848 cp.processDevice (stmtCtx, clauseOps);
18141849 cp.processHasDeviceAddr (stmtCtx, clauseOps, hasDeviceAddrSyms);
1815- if (! hostEvalInfo. empty () ) {
1850+ if (hostEvalInfo) {
18161851 // Only process host_eval if compiling for the host device.
18171852 processHostEvalClauses (converter, semaCtx, stmtCtx, eval, loc);
1818- hostEvalInfo. back (). collectValues (clauseOps.hostEvalVars );
1853+ hostEvalInfo-> collectValues (clauseOps.hostEvalVars );
18191854 }
18201855 cp.processIf (llvm::omp::Directive::OMPD_target, clauseOps);
18211856 cp.processIsDevicePtr (clauseOps, isDevicePtrSyms);
@@ -1952,7 +1987,8 @@ static void genTeamsClauses(
19521987 cp.processAllocate (clauseOps);
19531988 cp.processIf (llvm::omp::Directive::OMPD_teams, clauseOps);
19541989
1955- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps)) {
1990+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1991+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps)) {
19561992 cp.processNumTeams (stmtCtx, clauseOps);
19571993 cp.processThreadLimit (stmtCtx, clauseOps);
19581994 }
@@ -2213,10 +2249,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
22132249 lower::pft::Evaluation &eval, mlir::Location loc,
22142250 const ConstructQueue &queue,
22152251 ConstructQueue::const_iterator item) {
2216- assert (!sectionsStack.empty ());
2252+ const parser::OpenMPSectionsConstruct *sectionsConstruct =
2253+ getSectionsConstructStackTop (converter);
2254+ assert (sectionsConstruct);
2255+
22172256 const auto &sectionBlocks =
2218- std::get<parser::OmpSectionBlocks>(sectionsStack. back () ->t );
2219- sectionsStack. pop_back ();
2257+ std::get<parser::OmpSectionBlocks>(sectionsConstruct ->t );
2258+ converter. getStateStack (). stackPop ();
22202259 mlir::omp::SectionsOperands clauseOps;
22212260 llvm::SmallVector<const semantics::Symbol *> reductionSyms;
22222261 genSectionsClauses (converter, semaCtx, item->clauses , loc, clauseOps,
@@ -2370,7 +2409,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23702409
23712410 // Introduce a new host_eval information structure for this target region.
23722411 if (!isTargetDevice)
2373- hostEvalInfo. emplace_back ();
2412+ converter. getStateStack (). stackPush <HostEvalInfoStackFrame> ();
23742413
23752414 mlir::omp::TargetOperands clauseOps;
23762415 DefaultMapsTy defaultMaps;
@@ -2497,7 +2536,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24972536
24982537 // Remove the host_eval information structure created for this target region.
24992538 if (!isTargetDevice)
2500- hostEvalInfo. pop_back ();
2539+ converter. getStateStack (). stackPop ();
25012540 return targetOp;
25022541}
25032542
@@ -3771,7 +3810,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
37713810 buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
37723811 eval, source, directive, clauses)};
37733812
3774- sectionsStack.push_back (&sectionsConstruct);
3813+ converter.getStateStack ().stackPush <SectionsConstructStackFrame>(
3814+ sectionsConstruct);
37753815 genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
37763816 queue.begin ());
37773817}
0 commit comments