3838#include " flang/Support/OpenMP-utils.h"
3939#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4040#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
41+ #include " mlir/Support/StateStack.h"
4142#include " mlir/Transforms/RegionUtils.h"
4243#include " llvm/ADT/STLExtras.h"
4344#include " llvm/Frontend/OpenMP/OMPConstants.h"
@@ -200,9 +201,41 @@ class HostEvalInfo {
200201// / the handling of the outer region by keeping a stack of information
201202// / structures, but it will probably still require some further work to support
202203// / reverse offloading.
203- static llvm::SmallVector<HostEvalInfo, 0 > hostEvalInfo;
204- static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0 >
205- sectionsStack;
204+ class HostEvalInfoStackFrame
205+ : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
206+ public:
207+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (HostEvalInfoStackFrame)
208+
209+ HostEvalInfo info;
210+ };
211+
212+ static HostEvalInfo *
213+ getHostEvalInfoStackTop (lower::AbstractConverter &converter) {
214+ HostEvalInfoStackFrame *frame =
215+ converter.getStateStack ().getStackTop <HostEvalInfoStackFrame>();
216+ return frame ? &frame->info : nullptr ;
217+ }
218+
219+ // / Stack frame for storing the OpenMPSectionsConstruct currently being
220+ // / processed so that it can be refered to when lowering the construct.
221+ class SectionsConstructStackFrame
222+ : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
223+ public:
224+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (SectionsConstructStackFrame)
225+
226+ explicit SectionsConstructStackFrame (
227+ const parser::OpenMPSectionsConstruct §ionsConstruct)
228+ : sectionsConstruct{sectionsConstruct} {}
229+
230+ const parser::OpenMPSectionsConstruct §ionsConstruct;
231+ };
232+
233+ static const parser::OpenMPSectionsConstruct *
234+ getSectionsConstructStackTop (lower::AbstractConverter &converter) {
235+ SectionsConstructStackFrame *frame =
236+ converter.getStateStack ().getStackTop <SectionsConstructStackFrame>();
237+ return frame ? &frame->sectionsConstruct : nullptr ;
238+ }
206239
207240// / Bind symbols to their corresponding entry block arguments.
208241// /
@@ -537,54 +570,55 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
537570 if (!ompEval)
538571 return ;
539572
540- HostEvalInfo &hostInfo = hostEvalInfo.back ();
573+ HostEvalInfo *hostInfo = getHostEvalInfoStackTop (converter);
574+ assert (hostInfo && " expected HOST_EVAL info structure" );
541575
542576 switch (extractOmpDirective (*ompEval)) {
543577 case OMPD_teams_distribute_parallel_do:
544578 case OMPD_teams_distribute_parallel_do_simd:
545- cp.processThreadLimit (stmtCtx, hostInfo. ops );
579+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
546580 [[fallthrough]];
547581 case OMPD_target_teams_distribute_parallel_do:
548582 case OMPD_target_teams_distribute_parallel_do_simd:
549- cp.processNumTeams (stmtCtx, hostInfo. ops );
583+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
550584 [[fallthrough]];
551585 case OMPD_distribute_parallel_do:
552586 case OMPD_distribute_parallel_do_simd:
553- cp.processNumThreads (stmtCtx, hostInfo. ops );
587+ cp.processNumThreads (stmtCtx, hostInfo-> ops );
554588 [[fallthrough]];
555589 case OMPD_distribute:
556590 case OMPD_distribute_simd:
557- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
591+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
558592 break ;
559593
560594 case OMPD_teams:
561- cp.processThreadLimit (stmtCtx, hostInfo. ops );
595+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
562596 [[fallthrough]];
563597 case OMPD_target_teams:
564- cp.processNumTeams (stmtCtx, hostInfo. ops );
598+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
565599 processSingleNestedIf ([](Directive nestedDir) {
566600 return topDistributeSet.test (nestedDir) || topLoopSet.test (nestedDir);
567601 });
568602 break ;
569603
570604 case OMPD_teams_distribute:
571605 case OMPD_teams_distribute_simd:
572- cp.processThreadLimit (stmtCtx, hostInfo. ops );
606+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
573607 [[fallthrough]];
574608 case OMPD_target_teams_distribute:
575609 case OMPD_target_teams_distribute_simd:
576- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
577- cp.processNumTeams (stmtCtx, hostInfo. ops );
610+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
611+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
578612 break ;
579613
580614 case OMPD_teams_loop:
581- cp.processThreadLimit (stmtCtx, hostInfo. ops );
615+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
582616 [[fallthrough]];
583617 case OMPD_target_teams_loop:
584- cp.processNumTeams (stmtCtx, hostInfo. ops );
618+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
585619 [[fallthrough]];
586620 case OMPD_loop:
587- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
621+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
588622 break ;
589623
590624 // Standalone 'target' case.
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
598632 }
599633 };
600634
601- assert (!hostEvalInfo.empty () && " expected HOST_EVAL info structure" );
602-
603635 const auto *ompEval = eval.getIf <parser::OpenMPConstruct>();
604636 assert (ompEval &&
605637 llvm::omp::allTargetSet.test (extractOmpDirective (*ompEval)) &&
@@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp(
14681500 mlir::Region ®ion = targetOp.getRegion ();
14691501 mlir::Block *entryBlock = genEntryBlock (firOpBuilder, args, region);
14701502 bindEntryBlockArgs (converter, targetOp, args);
1471- if (! hostEvalInfo. empty ( ))
1472- hostEvalInfo. back (). bindOperands (argIface.getHostEvalBlockArgs ());
1503+ if (HostEvalInfo * hostEvalInfo = getHostEvalInfoStackTop (converter ))
1504+ hostEvalInfo-> bindOperands (argIface.getHostEvalBlockArgs ());
14731505
14741506 // Check if cloning the bounds introduced any dependency on the outer region.
14751507 // If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
17081740 llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
17091741 ClauseProcessor cp (converter, semaCtx, clauses);
17101742
1711- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps, iv))
1743+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1744+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps, iv))
17121745 cp.processCollapse (loc, eval, clauseOps, iv);
17131746
17141747 clauseOps.loopInclusive = converter.getFirOpBuilder ().getUnitAttr ();
@@ -1753,7 +1786,8 @@ static void genParallelClauses(
17531786 cp.processAllocate (clauseOps);
17541787 cp.processIf (llvm::omp::Directive::OMPD_parallel, clauseOps);
17551788
1756- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps))
1789+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1790+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps))
17571791 cp.processNumThreads (stmtCtx, clauseOps);
17581792
17591793 cp.processProcBind (clauseOps);
@@ -1818,16 +1852,17 @@ static void genTargetClauses(
18181852 llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
18191853 llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
18201854 llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
1855+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
18211856 ClauseProcessor cp (converter, semaCtx, clauses);
18221857 cp.processBare (clauseOps);
18231858 cp.processDefaultMap (stmtCtx, defaultMaps);
18241859 cp.processDepend (symTable, stmtCtx, clauseOps);
18251860 cp.processDevice (stmtCtx, clauseOps);
18261861 cp.processHasDeviceAddr (stmtCtx, clauseOps, hasDeviceAddrSyms);
1827- if (! hostEvalInfo. empty () ) {
1862+ if (hostEvalInfo) {
18281863 // Only process host_eval if compiling for the host device.
18291864 processHostEvalClauses (converter, semaCtx, stmtCtx, eval, loc);
1830- hostEvalInfo. back (). collectValues (clauseOps.hostEvalVars );
1865+ hostEvalInfo-> collectValues (clauseOps.hostEvalVars );
18311866 }
18321867 cp.processIf (llvm::omp::Directive::OMPD_target, clauseOps);
18331868 cp.processIsDevicePtr (clauseOps, isDevicePtrSyms);
@@ -1963,7 +1998,8 @@ static void genTeamsClauses(
19631998 cp.processAllocate (clauseOps);
19641999 cp.processIf (llvm::omp::Directive::OMPD_teams, clauseOps);
19652000
1966- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps)) {
2001+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
2002+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps)) {
19672003 cp.processNumTeams (stmtCtx, clauseOps);
19682004 cp.processThreadLimit (stmtCtx, clauseOps);
19692005 }
@@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
22242260 lower::pft::Evaluation &eval, mlir::Location loc,
22252261 const ConstructQueue &queue,
22262262 ConstructQueue::const_iterator item) {
2227- assert (!sectionsStack.empty ());
2263+ const parser::OpenMPSectionsConstruct *sectionsConstruct =
2264+ getSectionsConstructStackTop (converter);
2265+ assert (sectionsConstruct);
2266+
22282267 const auto §ionBlocks =
2229- std::get<parser::OmpSectionBlocks>(sectionsStack. back () ->t );
2230- sectionsStack. pop_back ();
2268+ std::get<parser::OmpSectionBlocks>(sectionsConstruct ->t );
2269+ converter. getStateStack (). stackPop ();
22312270 mlir::omp::SectionsOperands clauseOps;
22322271 llvm::SmallVector<const semantics::Symbol *> reductionSyms;
22332272 genSectionsClauses (converter, semaCtx, item->clauses , loc, clauseOps,
@@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23812420
23822421 // Introduce a new host_eval information structure for this target region.
23832422 if (!isTargetDevice)
2384- hostEvalInfo. emplace_back ();
2423+ converter. getStateStack (). stackPush <HostEvalInfoStackFrame> ();
23852424
23862425 mlir::omp::TargetOperands clauseOps;
23872426 DefaultMapsTy defaultMaps;
@@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25082547
25092548 // Remove the host_eval information structure created for this target region.
25102549 if (!isTargetDevice)
2511- hostEvalInfo. pop_back ();
2550+ converter. getStateStack (). stackPop ();
25122551 return targetOp;
25132552}
25142553
@@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
42354274 buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
42364275 eval, source, directive, clauses)};
42374276
4238- sectionsStack.push_back (§ionsConstruct);
4277+ converter.getStateStack ().stackPush <SectionsConstructStackFrame>(
4278+ sectionsConstruct);
42394279 genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
42404280 queue.begin ());
42414281}
0 commit comments