@@ -1047,6 +1047,11 @@ struct OpWithBodyGenInfo {
10471047 return *this ;
10481048 }
10491049
1050+ OpWithBodyGenInfo &setEntryBlockArgs (const EntryBlockArgs *value) {
1051+ blockArgs = value;
1052+ return *this ;
1053+ }
1054+
10501055 OpWithBodyGenInfo &setGenRegionEntryCb (GenOMPRegionEntryCBFn value) {
10511056 genRegionEntryCB = value;
10521057 return *this ;
@@ -1073,8 +1078,12 @@ struct OpWithBodyGenInfo {
10731078 const List<Clause> *clauses = nullptr ;
10741079 // / [in] if provided, processes the construct's data-sharing attributes.
10751080 DataSharingProcessor *dsp = nullptr ;
1076- // / [in] if provided, emits the op's region entry. Otherwise, an emtpy block
1077- // / is created in the region.
1081+ // / [in] if provided, it is used to create the op's region entry block. It is
1082+ // / overriden when a \see genRegionEntryCB is provided. This is only valid for
1083+ // / operations implementing the \see mlir::omp::BlockArgOpenMPOpInterface.
1084+ const EntryBlockArgs *blockArgs = nullptr ;
1085+ // / [in] if provided, it overrides the default op's region entry block
1086+ // / creation.
10781087 GenOMPRegionEntryCBFn genRegionEntryCB = nullptr ;
10791088 // / [in] if set to `true`, skip generating nested evaluations and dispatching
10801089 // / any further leaf constructs.
@@ -1098,18 +1107,33 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
10981107 return undef.getDefiningOp ();
10991108 };
11001109
1101- // If an argument for the region is provided then create the block with that
1102- // argument. Also update the symbol's address with the mlir argument value.
1103- // e.g. For loops the argument is the induction variable. And all further
1104- // uses of the induction variable should use this mlir value.
1110+ // Create the entry block for the region and collect its arguments for use
1111+ // within the region. The entry block will be created as follows:
1112+ // - By default, it will be empty and have no arguments.
1113+ // - Operations implementing the omp::BlockArgOpenMPOpInterface can set the
1114+ // `info.blockArgs` pointer so that block arguments will be those
1115+ // corresponding to entry block argument-generating clauses. Binding of
1116+ // Fortran symbols to the new MLIR values is done automatically.
1117+ // - If the `info.genRegionEntryCB` callback is set, it takes precedence and
1118+ // allows callers to manually create the entry block with its intended
1119+ // list of arguments and to bind these arguments to their corresponding
1120+ // Fortran symbols. This is used for e.g. loop induction variables.
11051121 auto regionArgs = [&]() -> llvm::SmallVector<const semantics::Symbol *> {
1106- if (info.genRegionEntryCB != nullptr ) {
1122+ if (info.genRegionEntryCB )
11071123 return info.genRegionEntryCB (&op);
1124+
1125+ if (info.blockArgs ) {
1126+ genEntryBlock (firOpBuilder, *info.blockArgs , op.getRegion (0 ));
1127+ bindEntryBlockArgs (info.converter ,
1128+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
1129+ *info.blockArgs );
1130+ return llvm::to_vector (info.blockArgs ->getSyms ());
11081131 }
11091132
11101133 firOpBuilder.createBlock (&op.getRegion (0 ));
11111134 return {};
11121135 }();
1136+
11131137 // Mark the earliest insertion point.
11141138 mlir::Operation *marker = insertMarker (firOpBuilder);
11151139
@@ -1977,20 +2001,14 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19772001 mlir::omp::ParallelOperands &clauseOps,
19782002 const EntryBlockArgs &args, DataSharingProcessor *dsp,
19792003 bool isComposite = false ) {
1980- auto genRegionEntryCB = [&](mlir::Operation *op) {
1981- genEntryBlock (converter.getFirOpBuilder (), args, op->getRegion (0 ));
1982- bindEntryBlockArgs (
1983- converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
1984- return llvm::to_vector (args.getSyms ());
1985- };
1986-
19872004 assert ((!enableDelayedPrivatization || dsp) &&
19882005 " expected valid DataSharingProcessor" );
2006+
19892007 OpWithBodyGenInfo genInfo =
19902008 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
19912009 llvm::omp::Directive::OMPD_parallel)
19922010 .setClauses (&item->clauses )
1993- .setGenRegionEntryCb (genRegionEntryCB )
2011+ .setEntryBlockArgs (&args )
19942012 .setGenSkeletonOnly (isComposite)
19952013 .setDataSharingProcessor (dsp);
19962014
@@ -2066,13 +2084,6 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20662084 mlir::Operation *terminator =
20672085 lower::genOpenMPTerminator (builder, sectionsOp, loc);
20682086
2069- auto genRegionEntryCB = [&](mlir::Operation *op) {
2070- genEntryBlock (builder, args, op->getRegion (0 ));
2071- bindEntryBlockArgs (
2072- converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
2073- return llvm::to_vector (args.getSyms ());
2074- };
2075-
20762087 // Generate nested SECTION constructs.
20772088 // This is done here rather than in genOMP([...], OpenMPSectionConstruct )
20782089 // because we need to run genReductionVars on each omp.section so that the
@@ -2096,7 +2107,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20962107 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, nestedEval,
20972108 llvm::omp::Directive::OMPD_section)
20982109 .setClauses (§ionQueue.begin ()->clauses )
2099- .setGenRegionEntryCb (genRegionEntryCB ),
2110+ .setEntryBlockArgs (&args ),
21002111 sectionQueue, sectionQueue.begin ());
21012112 }
21022113
@@ -2435,20 +2446,12 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24352446 taskArgs.priv .syms = dsp.getDelayedPrivSymbols ();
24362447 taskArgs.priv .vars = clauseOps.privateVars ;
24372448
2438- auto genRegionEntryCB = [&](mlir::Operation *op) {
2439- genEntryBlock (converter.getFirOpBuilder (), taskArgs, op->getRegion (0 ));
2440- bindEntryBlockArgs (converter,
2441- llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
2442- taskArgs);
2443- return llvm::to_vector (taskArgs.priv .syms );
2444- };
2445-
24462449 return genOpWithBody<mlir::omp::TaskOp>(
24472450 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
24482451 llvm::omp::Directive::OMPD_task)
24492452 .setClauses (&item->clauses )
24502453 .setDataSharingProcessor (&dsp)
2451- .setGenRegionEntryCb (genRegionEntryCB ),
2454+ .setEntryBlockArgs (&taskArgs ),
24522455 queue, item, clauseOps);
24532456}
24542457
@@ -2524,18 +2527,11 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25242527 args.reduction .syms = reductionSyms;
25252528 args.reduction .vars = clauseOps.reductionVars ;
25262529
2527- auto genRegionEntryCB = [&](mlir::Operation *op) {
2528- genEntryBlock (converter.getFirOpBuilder (), args, op->getRegion (0 ));
2529- bindEntryBlockArgs (
2530- converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
2531- return llvm::to_vector (args.getSyms ());
2532- };
2533-
25342530 return genOpWithBody<mlir::omp::TeamsOp>(
25352531 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
25362532 llvm::omp::Directive::OMPD_teams)
25372533 .setClauses (&item->clauses )
2538- .setGenRegionEntryCb (genRegionEntryCB ),
2534+ .setEntryBlockArgs (&args ),
25392535 queue, item, clauseOps);
25402536}
25412537
0 commit comments