diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h index 8ae68e143cd2f..de3e833f60699 100644 --- a/flang/include/flang/Lower/AbstractConverter.h +++ b/flang/include/flang/Lower/AbstractConverter.h @@ -26,6 +26,7 @@ namespace mlir { class SymbolTable; +class StateStack; } namespace fir { @@ -361,6 +362,8 @@ class AbstractConverter { /// functions in order to be in sync). virtual mlir::SymbolTable *getMLIRSymbolTable() = 0; + virtual mlir::StateStack &getStateStack() = 0; + private: /// Options controlling lowering behavior. const Fortran::lower::LoweringOptions &loweringOptions; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 64b16b3abe991..8506b9a984e58 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -69,6 +69,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Parser/Parser.h" +#include "mlir/Support/StateStack.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" @@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; } + mlir::StateStack &getStateStack() override { return stateStack; } + /// Add the symbol to the local map and return `true`. If the symbol is /// already in the map and \p forced is `false`, the map is not updated. /// Instead the value `false` is returned. @@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter { /// attribute since mlirSymbolTable must pro-actively be maintained when /// new Symbol operations are created. mlir::SymbolTable mlirSymbolTable; + + /// Used to store context while recursing into regions during lowering. + mlir::StateStack stateStack; }; } // namespace diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index ebd1d038716e4..60b6366c184d4 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -39,6 +39,7 @@ #include "flang/Support/OpenMP-utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Support/StateStack.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" @@ -198,9 +199,41 @@ class HostEvalInfo { /// the handling of the outer region by keeping a stack of information /// structures, but it will probably still require some further work to support /// reverse offloading. -static llvm::SmallVector hostEvalInfo; -static llvm::SmallVector - sectionsStack; +class HostEvalInfoStackFrame + : public mlir::StateStackFrameBase { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame) + + HostEvalInfo info; +}; + +static HostEvalInfo * +getHostEvalInfoStackTop(lower::AbstractConverter &converter) { + HostEvalInfoStackFrame *frame = + converter.getStateStack().getStackTop(); + return frame ? &frame->info : nullptr; +} + +/// Stack frame for storing the OpenMPSectionsConstruct currently being +/// processed so that it can be referred to when lowering the construct. +class SectionsConstructStackFrame + : public mlir::StateStackFrameBase { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame) + + explicit SectionsConstructStackFrame( + const parser::OpenMPSectionsConstruct §ionsConstruct) + : sectionsConstruct{sectionsConstruct} {} + + const parser::OpenMPSectionsConstruct §ionsConstruct; +}; + +static const parser::OpenMPSectionsConstruct * +getSectionsConstructStackTop(lower::AbstractConverter &converter) { + SectionsConstructStackFrame *frame = + converter.getStateStack().getStackTop(); + return frame ? &frame->sectionsConstruct : nullptr; +} /// Bind symbols to their corresponding entry block arguments. /// @@ -535,31 +568,32 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, if (!ompEval) return; - HostEvalInfo &hostInfo = hostEvalInfo.back(); + HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter); + assert(hostInfo && "expected HOST_EVAL info structure"); switch (extractOmpDirective(*ompEval)) { case OMPD_teams_distribute_parallel_do: case OMPD_teams_distribute_parallel_do_simd: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_distribute_parallel_do: case OMPD_target_teams_distribute_parallel_do_simd: - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_distribute_parallel_do: case OMPD_distribute_parallel_do_simd: - cp.processNumThreads(stmtCtx, hostInfo.ops); + cp.processNumThreads(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_distribute: case OMPD_distribute_simd: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); break; case OMPD_teams: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams: - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); processSingleNestedIf([](Directive nestedDir) { return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir); }); @@ -567,22 +601,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, case OMPD_teams_distribute: case OMPD_teams_distribute_simd: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_distribute: case OMPD_target_teams_distribute_simd: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); + cp.processNumTeams(stmtCtx, hostInfo->ops); break; case OMPD_teams_loop: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_loop: - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_loop: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); break; // Standalone 'target' case. @@ -596,8 +630,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, } }; - assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure"); - const auto *ompEval = eval.getIf(); assert(ompEval && llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && @@ -1456,8 +1488,8 @@ static void genBodyOfTargetOp( mlir::Region ®ion = targetOp.getRegion(); mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region); bindEntryBlockArgs(converter, targetOp, args); - if (!hostEvalInfo.empty()) - hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs()); + if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) + hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs()); // Check if cloning the bounds introduced any dependency on the outer region. // If so, then either clone them as well if they are MemoryEffectFree, or else @@ -1696,7 +1728,8 @@ genLoopNestClauses(lower::AbstractConverter &converter, llvm::SmallVectorImpl &iv) { ClauseProcessor cp(converter, semaCtx, clauses); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv)) + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv)) cp.processCollapse(loc, eval, clauseOps, iv); clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr(); @@ -1741,7 +1774,8 @@ static void genParallelClauses( cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) cp.processNumThreads(stmtCtx, clauseOps); cp.processProcBind(clauseOps); @@ -1812,10 +1846,10 @@ static void genTargetClauses( cp.processDepend(symTable, stmtCtx, clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms); - if (!hostEvalInfo.empty()) { + if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) { // Only process host_eval if compiling for the host device. processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc); - hostEvalInfo.back().collectValues(clauseOps.hostEvalVars); + hostEvalInfo->collectValues(clauseOps.hostEvalVars); } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); @@ -1952,7 +1986,8 @@ static void genTeamsClauses( cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) { + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) { cp.processNumTeams(stmtCtx, clauseOps); cp.processThreadLimit(stmtCtx, clauseOps); } @@ -2204,19 +2239,18 @@ genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable, converter.getCurrentLocation(), clauseOps); } -/// This breaks the normal prototype of the gen*Op functions: adding the -/// sectionBlocks argument so that the enclosed section constructs can be -/// lowered here with correct reduction symbol remapping. static mlir::omp::SectionsOp genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - assert(!sectionsStack.empty()); + const parser::OpenMPSectionsConstruct *sectionsConstruct = + getSectionsConstructStackTop(converter); + assert(sectionsConstruct && "Missing additional parsing information"); + const auto §ionBlocks = - std::get(sectionsStack.back()->t); - sectionsStack.pop_back(); + std::get(sectionsConstruct->t); mlir::omp::SectionsOperands clauseOps; llvm::SmallVector reductionSyms; genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps, @@ -2370,7 +2404,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // Introduce a new host_eval information structure for this target region. if (!isTargetDevice) - hostEvalInfo.emplace_back(); + converter.getStateStack().stackPush(); mlir::omp::TargetOperands clauseOps; DefaultMapsTy defaultMaps; @@ -2497,7 +2531,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // Remove the host_eval information structure created for this target region. if (!isTargetDevice) - hostEvalInfo.pop_back(); + converter.getStateStack().stackPop(); return targetOp; } @@ -3771,7 +3805,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx, eval, source, directive, clauses)}; - sectionsStack.push_back(§ionsConstruct); + mlir::SaveStateStack saveStateStack{ + converter.getStateStack(), sectionsConstruct}; genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue, queue.begin()); } diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h index ac70d05a3020a..44972fafe7fed 100644 --- a/mlir/include/mlir/Support/StateStack.h +++ b/mlir/include/mlir/Support/StateStack.h @@ -84,6 +84,17 @@ class StateStack { return WalkResult::advance(); } + /// Get the top instance of frame type `T` or nullptr if none are found + template + T *getStackTop() { + T *top = nullptr; + stackWalk([&](T &frame) -> mlir::WalkResult { + top = &frame; + return mlir::WalkResult::interrupt(); + }); + return top; + } + private: SmallVector> stack; };