@@ -519,6 +519,11 @@ struct OpWithBodyGenInfo {
     return *this;
   }
 
+  OpWithBodyGenInfo &setGenSkeletonOnly(bool value) {
+    genSkeletonOnly = value;
+    return *this;
+  }
+
   /// [inout] converter to use for the clauses.
   lower::AbstractConverter &converter;
   /// [in] Symbol table
@@ -538,6 +543,9 @@ struct OpWithBodyGenInfo {
   /// [in] if provided, emits the op's region entry. Otherwise, an empty block
   /// is created in the region.
   GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
+  /// [in] if set to `true`, skip generating nested evaluations and dispatching
+  /// any further leaf constructs.
+  bool genSkeletonOnly = false;
 };
 
 /// Create the body (block) for an OpenMP Operation.
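For context, the new setter follows the fluent-builder convention OpWithBodyGenInfo already uses: each setter returns *this so configuration calls can be chained. A minimal usage sketch, assuming the usual converter/symTable/semaCtx/loc/eval/item variables are in scope at the call site:

    // Hypothetical chained configuration; setGenSkeletonOnly is the new link.
    auto genInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                                     llvm::omp::Directive::OMPD_parallel)
                       .setClauses(&item->clauses)
                       .setGenSkeletonOnly(/*value=*/true);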
@@ -600,20 +608,22 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
     }
   }
 
-  if (ConstructQueue::const_iterator next = std::next(item);
-      next != queue.end()) {
-    genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
-                   info.loc, queue, next);
-  } else {
-    // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
-    // a lot of complications for our approach if the terminator generation
-    // is delayed past this point. Insert a temporary terminator here, then
-    // delete it.
-    firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
-    auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
-    firOpBuilder.setInsertionPointAfter(marker);
-    genNestedEvaluations(info.converter, info.eval);
-    temp->erase();
+  if (!info.genSkeletonOnly) {
+    if (ConstructQueue::const_iterator next = std::next(item);
+        next != queue.end()) {
+      genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
+                     info.loc, queue, next);
+    } else {
+      // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
+      // a lot of complications for our approach if the terminator generation
+      // is delayed past this point. Insert a temporary terminator here, then
+      // delete it.
+      firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
+      auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
+      firOpBuilder.setInsertionPointAfter(marker);
+      genNestedEvaluations(info.converter, info.eval);
+      temp->erase();
+    }
   }
 
   // Get or create a unique exiting block from the given region, or
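The branch wrapped by the new guard implements the leaf-construct dispatch: each queue item generates one operation and recurses on std::next(item), and only the innermost item lowers the loop body, exactly once. A self-contained sketch of that queue-walk pattern in plain C++ (names are illustrative, not flang APIs):

    #include <iostream>
    #include <iterator>
    #include <list>
    #include <string>

    using Queue = std::list<std::string>;

    // Each level handles one leaf construct, then dispatches the next one;
    // only the innermost level emits the user code.
    void genLevel(const Queue &queue, Queue::const_iterator item) {
      std::cout << "create op for " << *item << '\n';
      if (auto next = std::next(item); next != queue.end())
        genLevel(queue, next);
      else
        std::cout << "lower loop body here\n";
    }

    int main() {
      Queue queue{"distribute", "parallel", "do"};
      genLevel(queue, queue.begin());
    }

Setting genSkeletonOnly opts an op out of this recursion, so a caller building a composite construct can generate the remaining wrapper ops itself.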
@@ -1445,7 +1455,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               const ConstructQueue &queue, ConstructQueue::const_iterator item,
               mlir::omp::ParallelOperands &clauseOps,
               llvm::ArrayRef<const semantics::Symbol *> reductionSyms,
-              llvm::ArrayRef<mlir::Type> reductionTypes) {
+              llvm::ArrayRef<mlir::Type> reductionTypes,
+              DataSharingProcessor *dsp, bool isComposite = false) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   auto reductionCallback = [&](mlir::Operation *op) {
@@ -1457,17 +1468,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_parallel)
           .setClauses(&item->clauses)
-          .setGenRegionEntryCb(reductionCallback);
-
-  if (!enableDelayedPrivatization)
-    return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item,
-                                                clauseOps);
-
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue),
-                           /*useDelayedPrivatization=*/true, &symTable);
-  dsp.processStep1(&clauseOps);
+          .setGenRegionEntryCb(reductionCallback)
+          .setGenSkeletonOnly(isComposite);
+
+  if (!enableDelayedPrivatization) {
+    auto parallelOp =
+        genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+    parallelOp.setComposite(isComposite);
+    return parallelOp;
+  }
 
+  assert(dsp && "expected valid DataSharingProcessor");
   auto genRegionEntryCB = [&](mlir::Operation *op) {
     auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
 
@@ -1491,8 +1502,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                               allRegionArgLocs);
 
     llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
-    allSymbols.append(dsp.getDelayedPrivSymbols().begin(),
-                      dsp.getDelayedPrivSymbols().end());
+    allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
+                      dsp->getDelayedPrivSymbols().end());
 
     unsigned argIdx = 0;
     for (const semantics::Symbol *arg : allSymbols) {
@@ -1519,8 +1530,11 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     return allSymbols;
   };
 
-  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
-  return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(dsp);
+  auto parallelOp =
+      genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+  parallelOp.setComposite(isComposite);
+  return parallelOp;
 }
 
 /// This breaks the normal prototype of the gen*Op functions: adding the
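The extended signature gives genParallelOp an implicit contract: when delayed privatization is enabled, dsp must be non-null (the assert above enforces this); when it is disabled, dsp is unused and may be null. Hedged call sketches, assuming the caller-side variables shown elsewhere in this patch:

    // Standalone parallel with delayed privatization disabled: no DSP.
    genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item,
                  clauseOps, reductionSyms, reductionTypes, /*dsp=*/nullptr);

    // Composite case: caller-owned DSP, and the op is marked composite.
    genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
                  parallelClauseOps, parallelReductionSyms,
                  parallelReductionTypes, &dsp, /*isComposite=*/true);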
@@ -2005,8 +2019,16 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
   genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
                      reductionTypes, reductionSyms);
 
+  std::optional<DataSharingProcessor> dsp;
+  if (enableDelayedPrivatization) {
+    dsp.emplace(converter, semaCtx, item->clauses, eval,
+                lower::omp::isLastItemInQueue(item, queue),
+                /*useDelayedPrivatization=*/true, &symTable);
+    dsp->processStep1(&clauseOps);
+  }
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, clauseOps,
-                reductionSyms, reductionTypes);
+                reductionSyms, reductionTypes,
+                enableDelayedPrivatization ? &dsp.value() : nullptr);
 }
 
 static void genStandaloneSimd(lower::AbstractConverter &converter,
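The std::optional-plus-emplace pattern above defers constructing the DataSharingProcessor until delayed privatization is known to be enabled, avoiding any need for a default-constructed placeholder. A minimal standalone illustration of the same idiom (Processor is a hypothetical stand-in):

    #include <optional>

    struct Processor {
      explicit Processor(bool useDelayed) : useDelayed(useDelayed) {}
      void processStep1() { /* expensive setup */ }
      bool useDelayed;
    };

    void run(bool enableDelayed) {
      std::optional<Processor> p; // declared, not yet constructed
      if (enableDelayed) {
        p.emplace(/*useDelayed=*/true); // construct in place, on demand
        p->processStep1();
      }
      // Hand out a pointer only when the object was actually constructed.
      Processor *ptr = enableDelayed ? &p.value() : nullptr;
      (void)ptr; // would be forwarded, e.g. to genParallelOp
    }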
@@ -2058,8 +2080,69 @@ static void genCompositeDistributeParallelDo(
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
     ConstructQueue::const_iterator item) {
+  lower::StatementContext stmtCtx;
+
   assert(std::distance(item, queue.end()) == 3 && "Invalid leaf constructs");
-  TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
+  ConstructQueue::const_iterator distributeItem = item;
+  ConstructQueue::const_iterator parallelItem = std::next(distributeItem);
+  ConstructQueue::const_iterator doItem = std::next(parallelItem);
+
+  // Create parent omp.parallel first.
+  mlir::omp::ParallelOperands parallelClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
+  llvm::SmallVector<mlir::Type> parallelReductionTypes;
+  genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
+                     parallelClauseOps, parallelReductionTypes,
+                     parallelReductionSyms);
+
+  DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval,
+                           /*shouldCollectPreDeterminedSymbols=*/true,
+                           /*useDelayedPrivatization=*/true, &symTable);
+  dsp.processStep1(&parallelClauseOps);
+
+  genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
+                parallelClauseOps, parallelReductionSyms,
+                parallelReductionTypes, &dsp, /*isComposite=*/true);
+
+  // Clause processing.
+  mlir::omp::DistributeOperands distributeClauseOps;
+  genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
+                       loc, distributeClauseOps);
+
+  mlir::omp::WsloopOperands wsloopClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
+  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
+  genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
+                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc,
+                     loopNestClauseOps, iv);
+
+  // Operation creation.
+  // TODO: Populate entry block arguments with private variables.
+  auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
+      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+  distributeOp.setComposite(/*val=*/true);
+
+  // TODO: Add private variables to entry block arguments.
+  auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
+      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+  wsloopOp.setComposite(/*val=*/true);
+
+  // Construct wrapper entry block list and associated symbols. It is important
+  // that the symbol order and the block argument order match, so that the
+  // symbol-value bindings created are correct.
+  auto &wrapperSyms = wsloopReductionSyms;
+
+  auto wrapperArgs = llvm::to_vector(
+      llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
+                                        wsloopOp.getRegion().getArguments()));
+
+  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, doItem,
+                loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
+                llvm::omp::Directive::OMPD_distribute_parallel_do, dsp);
 }
 
 static void genCompositeDistributeParallelDoSimd(
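For orientation, the lowering above should produce a stack of composite operations for a Fortran !$omp distribute parallel do loop, roughly along these lines (an illustrative sketch of the expected MLIR shape, not verbatim compiler output):

    omp.parallel ... {              // created first; composite, skeleton only
      omp.distribute {              // wrapper op, composite = true
        omp.wsloop reduction(...) { // wrapper op, composite = true
          omp.loop_nest (%i) ... {  // emitted last by genLoopNestOp; the loop
            ...                     // body is lowered exactly once here
            omp.yield
          }
        }
      }
    }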