@@ -1048,6 +1048,11 @@ struct OpWithBodyGenInfo {
10481048 return *this ;
10491049 }
10501050
1051+ OpWithBodyGenInfo &setEntryBlockArgs (const EntryBlockArgs *value) {
1052+ blockArgs = value;
1053+ return *this ;
1054+ }
1055+
10511056 OpWithBodyGenInfo &setGenRegionEntryCb (GenOMPRegionEntryCBFn value) {
10521057 genRegionEntryCB = value;
10531058 return *this ;
@@ -1074,8 +1079,12 @@ struct OpWithBodyGenInfo {
10741079 const List<Clause> *clauses = nullptr ;
10751080 // / [in] if provided, processes the construct's data-sharing attributes.
10761081 DataSharingProcessor *dsp = nullptr ;
1077- // / [in] if provided, emits the op's region entry. Otherwise, an emtpy block
1078- // / is created in the region.
1082+ // / [in] if provided, it is used to create the op's region entry block. It is
1083+ // / overriden when a \see genRegionEntryCB is provided. This is only valid for
1084+ // / operations implementing the \see mlir::omp::BlockArgOpenMPOpInterface.
1085+ const EntryBlockArgs *blockArgs = nullptr ;
1086+ // / [in] if provided, it overrides the default op's region entry block
1087+ // / creation.
10791088 GenOMPRegionEntryCBFn genRegionEntryCB = nullptr ;
10801089 // / [in] if set to `true`, skip generating nested evaluations and dispatching
10811090 // / any further leaf constructs.
@@ -1099,18 +1108,33 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
10991108 return undef.getDefiningOp ();
11001109 };
11011110
1102- // If an argument for the region is provided then create the block with that
1103- // argument. Also update the symbol's address with the mlir argument value.
1104- // e.g. For loops the argument is the induction variable. And all further
1105- // uses of the induction variable should use this mlir value.
1111+ // Create the entry block for the region and collect its arguments for use
1112+ // within the region. The entry block will be created as follows:
1113+ // - By default, it will be empty and have no arguments.
1114+ // - Operations implementing the omp::BlockArgOpenMPOpInterface can set the
1115+ // `info.blockArgs` pointer so that block arguments will be those
1116+ // corresponding to entry block argument-generating clauses. Binding of
1117+ // Fortran symbols to the new MLIR values is done automatically.
1118+ // - If the `info.genRegionEntryCB` callback is set, it takes precedence and
1119+ // allows callers to manually create the entry block with its intended
1120+ // list of arguments and to bind these arguments to their corresponding
1121+ // Fortran symbols. This is used for e.g. loop induction variables.
11061122 auto regionArgs = [&]() -> llvm::SmallVector<const semantics::Symbol *> {
1107- if (info.genRegionEntryCB != nullptr ) {
1123+ if (info.genRegionEntryCB )
11081124 return info.genRegionEntryCB (&op);
1125+
1126+ if (info.blockArgs ) {
1127+ genEntryBlock (firOpBuilder, *info.blockArgs , op.getRegion (0 ));
1128+ bindEntryBlockArgs (info.converter ,
1129+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
1130+ *info.blockArgs );
1131+ return llvm::to_vector (info.blockArgs ->getSyms ());
11091132 }
11101133
11111134 firOpBuilder.createBlock (&op.getRegion (0 ));
11121135 return {};
11131136 }();
1137+
11141138 // Mark the earliest insertion point.
11151139 mlir::Operation *marker = insertMarker (firOpBuilder);
11161140
@@ -1978,20 +2002,14 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19782002 mlir::omp::ParallelOperands &clauseOps,
19792003 const EntryBlockArgs &args, DataSharingProcessor *dsp,
19802004 bool isComposite = false ) {
1981- auto genRegionEntryCB = [&](mlir::Operation *op) {
1982- genEntryBlock (converter.getFirOpBuilder (), args, op->getRegion (0 ));
1983- bindEntryBlockArgs (
1984- converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
1985- return llvm::to_vector (args.getSyms ());
1986- };
1987-
19882005 assert ((!enableDelayedPrivatization || dsp) &&
19892006 " expected valid DataSharingProcessor" );
2007+
19902008 OpWithBodyGenInfo genInfo =
19912009 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
19922010 llvm::omp::Directive::OMPD_parallel)
19932011 .setClauses (&item->clauses )
1994- .setGenRegionEntryCb (genRegionEntryCB )
2012+ .setEntryBlockArgs (&args )
19952013 .setGenSkeletonOnly (isComposite)
19962014 .setDataSharingProcessor (dsp);
19972015
@@ -2067,13 +2085,6 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20672085 mlir::Operation *terminator =
20682086 lower::genOpenMPTerminator (builder, sectionsOp, loc);
20692087
2070- auto genRegionEntryCB = [&](mlir::Operation *op) {
2071- genEntryBlock (builder, args, op->getRegion (0 ));
2072- bindEntryBlockArgs (
2073- converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
2074- return llvm::to_vector (args.getSyms ());
2075- };
2076-
20772088 // Generate nested SECTION constructs.
20782089 // This is done here rather than in genOMP([...], OpenMPSectionConstruct )
20792090 // because we need to run genReductionVars on each omp.section so that the
@@ -2097,7 +2108,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20972108 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, nestedEval,
20982109 llvm::omp::Directive::OMPD_section)
20992110 .setClauses (§ionQueue.begin ()->clauses )
2100- .setGenRegionEntryCb (genRegionEntryCB ),
2111+ .setEntryBlockArgs (&args ),
21012112 sectionQueue, sectionQueue.begin ());
21022113 }
21032114
@@ -2436,20 +2447,12 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24362447 taskArgs.priv .syms = dsp.getDelayedPrivSymbols ();
24372448 taskArgs.priv .vars = clauseOps.privateVars ;
24382449
2439- auto genRegionEntryCB = [&](mlir::Operation *op) {
2440- genEntryBlock (converter.getFirOpBuilder (), taskArgs, op->getRegion (0 ));
2441- bindEntryBlockArgs (converter,
2442- llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
2443- taskArgs);
2444- return llvm::to_vector (taskArgs.priv .syms );
2445- };
2446-
24472450 return genOpWithBody<mlir::omp::TaskOp>(
24482451 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
24492452 llvm::omp::Directive::OMPD_task)
24502453 .setClauses (&item->clauses )
24512454 .setDataSharingProcessor (&dsp)
2452- .setGenRegionEntryCb (genRegionEntryCB ),
2455+ .setEntryBlockArgs (&taskArgs ),
24532456 queue, item, clauseOps);
24542457}
24552458
@@ -2525,18 +2528,11 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25252528 args.reduction .syms = reductionSyms;
25262529 args.reduction .vars = clauseOps.reductionVars ;
25272530
2528- auto genRegionEntryCB = [&](mlir::Operation *op) {
2529- genEntryBlock (converter.getFirOpBuilder (), args, op->getRegion (0 ));
2530- bindEntryBlockArgs (
2531- converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
2532- return llvm::to_vector (args.getSyms ());
2533- };
2534-
25352531 return genOpWithBody<mlir::omp::TeamsOp>(
25362532 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
25372533 llvm::omp::Directive::OMPD_teams)
25382534 .setClauses (&item->clauses )
2539- .setGenRegionEntryCb (genRegionEntryCB ),
2535+ .setEntryBlockArgs (&args ),
25402536 queue, item, clauseOps);
25412537}
25422538
0 commit comments