diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index ac688a69d7fb6..812359ff8052e 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1048,6 +1048,11 @@ struct OpWithBodyGenInfo { return *this; } + OpWithBodyGenInfo &setEntryBlockArgs(const EntryBlockArgs *value) { + blockArgs = value; + return *this; + } + OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) { genRegionEntryCB = value; return *this; @@ -1074,8 +1079,12 @@ struct OpWithBodyGenInfo { const List *clauses = nullptr; /// [in] if provided, processes the construct's data-sharing attributes. DataSharingProcessor *dsp = nullptr; - /// [in] if provided, emits the op's region entry. Otherwise, an emtpy block - /// is created in the region. + /// [in] if provided, it is used to create the op's region entry block. It is + /// overriden when a \see genRegionEntryCB is provided. This is only valid for + /// operations implementing the \see mlir::omp::BlockArgOpenMPOpInterface. + const EntryBlockArgs *blockArgs = nullptr; + /// [in] if provided, it overrides the default op's region entry block + /// creation. GenOMPRegionEntryCBFn genRegionEntryCB = nullptr; /// [in] if set to `true`, skip generating nested evaluations and dispatching /// any further leaf constructs. @@ -1099,18 +1108,33 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info, return undef.getDefiningOp(); }; - // If an argument for the region is provided then create the block with that - // argument. Also update the symbol's address with the mlir argument value. - // e.g. For loops the argument is the induction variable. And all further - // uses of the induction variable should use this mlir value. + // Create the entry block for the region and collect its arguments for use + // within the region. The entry block will be created as follows: + // - By default, it will be empty and have no arguments. + // - Operations implementing the omp::BlockArgOpenMPOpInterface can set the + // `info.blockArgs` pointer so that block arguments will be those + // corresponding to entry block argument-generating clauses. Binding of + // Fortran symbols to the new MLIR values is done automatically. + // - If the `info.genRegionEntryCB` callback is set, it takes precedence and + // allows callers to manually create the entry block with its intended + // list of arguments and to bind these arguments to their corresponding + // Fortran symbols. This is used for e.g. loop induction variables. auto regionArgs = [&]() -> llvm::SmallVector { - if (info.genRegionEntryCB != nullptr) { + if (info.genRegionEntryCB) return info.genRegionEntryCB(&op); + + if (info.blockArgs) { + genEntryBlock(firOpBuilder, *info.blockArgs, op.getRegion(0)); + bindEntryBlockArgs(info.converter, + llvm::cast(op), + *info.blockArgs); + return llvm::to_vector(info.blockArgs->getSyms()); } firOpBuilder.createBlock(&op.getRegion(0)); return {}; }(); + // Mark the earliest insertion point. mlir::Operation *marker = insertMarker(firOpBuilder); @@ -1978,20 +2002,14 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::omp::ParallelOperands &clauseOps, const EntryBlockArgs &args, DataSharingProcessor *dsp, bool isComposite = false) { - auto genRegionEntryCB = [&](mlir::Operation *op) { - genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0)); - bindEntryBlockArgs( - converter, llvm::cast(op), args); - return llvm::to_vector(args.getSyms()); - }; - assert((!enableDelayedPrivatization || dsp) && "expected valid DataSharingProcessor"); + OpWithBodyGenInfo genInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_parallel) .setClauses(&item->clauses) - .setGenRegionEntryCb(genRegionEntryCB) + .setEntryBlockArgs(&args) .setGenSkeletonOnly(isComposite) .setDataSharingProcessor(dsp); @@ -2067,13 +2085,6 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::Operation *terminator = lower::genOpenMPTerminator(builder, sectionsOp, loc); - auto genRegionEntryCB = [&](mlir::Operation *op) { - genEntryBlock(builder, args, op->getRegion(0)); - bindEntryBlockArgs( - converter, llvm::cast(op), args); - return llvm::to_vector(args.getSyms()); - }; - // Generate nested SECTION constructs. // This is done here rather than in genOMP([...], OpenMPSectionConstruct ) // because we need to run genReductionVars on each omp.section so that the @@ -2097,7 +2108,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval, llvm::omp::Directive::OMPD_section) .setClauses(§ionQueue.begin()->clauses) - .setGenRegionEntryCb(genRegionEntryCB), + .setEntryBlockArgs(&args), sectionQueue, sectionQueue.begin()); } @@ -2436,20 +2447,12 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable, taskArgs.priv.syms = dsp.getDelayedPrivSymbols(); taskArgs.priv.vars = clauseOps.privateVars; - auto genRegionEntryCB = [&](mlir::Operation *op) { - genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0)); - bindEntryBlockArgs(converter, - llvm::cast(op), - taskArgs); - return llvm::to_vector(taskArgs.priv.syms); - }; - return genOpWithBody( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_task) .setClauses(&item->clauses) .setDataSharingProcessor(&dsp) - .setGenRegionEntryCb(genRegionEntryCB), + .setEntryBlockArgs(&taskArgs), queue, item, clauseOps); } @@ -2525,18 +2528,11 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, args.reduction.syms = reductionSyms; args.reduction.vars = clauseOps.reductionVars; - auto genRegionEntryCB = [&](mlir::Operation *op) { - genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0)); - bindEntryBlockArgs( - converter, llvm::cast(op), args); - return llvm::to_vector(args.getSyms()); - }; - return genOpWithBody( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_teams) .setClauses(&item->clauses) - .setGenRegionEntryCb(genRegionEntryCB), + .setEntryBlockArgs(&args), queue, item, clauseOps); }