@@ -1253,13 +1253,13 @@ static void genTaskClauses(
12531253 lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
12541254 lower::StatementContext &stmtCtx, const List<Clause> &clauses,
12551255 mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
1256- llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms ) {
1256+ llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms ) {
12571257 ClauseProcessor cp (converter, semaCtx, clauses);
12581258 cp.processAllocate (clauseOps);
12591259 cp.processDepend (clauseOps);
12601260 cp.processFinal (stmtCtx, clauseOps);
12611261 cp.processIf (llvm::omp::Directive::OMPD_task, clauseOps);
1262- cp.processInReduction (loc, clauseOps, InReductionSyms );
1262+ cp.processInReduction (loc, clauseOps, inReductionSyms );
12631263 cp.processMergeable (clauseOps);
12641264 cp.processPriority (stmtCtx, clauseOps);
12651265 cp.processUntied (clauseOps);
@@ -1888,9 +1888,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
18881888 ConstructQueue::const_iterator item) {
18891889 lower::StatementContext stmtCtx;
18901890 mlir::omp::TaskOperands clauseOps;
1891- llvm::SmallVector<const semantics::Symbol *> InReductionSyms ;
1891+ llvm::SmallVector<const semantics::Symbol *> inReductionSyms ;
18921892 genTaskClauses (converter, semaCtx, stmtCtx, item->clauses , loc, clauseOps,
1893- InReductionSyms );
1893+ inReductionSyms );
18941894
18951895 if (!enableDelayedPrivatization)
18961896 return genOpWithBody<mlir::omp::TaskOp>(
@@ -1907,7 +1907,7 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19071907 EntryBlockArgs taskArgs;
19081908 taskArgs.priv .syms = dsp.getDelayedPrivSymbols ();
19091909 taskArgs.priv .vars = clauseOps.privateVars ;
1910- taskArgs.inReduction .syms = InReductionSyms ;
1910+ taskArgs.inReduction .syms = inReductionSyms ;
19111911 taskArgs.inReduction .vars = clauseOps.inReductionVars ;
19121912
19131913 auto genRegionEntryCB = [&](mlir::Operation *op) {
@@ -1927,14 +1927,6 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19271927
19281928 auto taskOp =
19291929 genOpWithBody<mlir::omp::TaskOp>(genInfo, queue, item, clauseOps);
1930-
1931- llvm::SmallVector<mlir::Type> inReductionTypes;
1932- for (const auto &inreductionVar : clauseOps.inReductionVars )
1933- inReductionTypes.push_back (inreductionVar.getType ());
1934-
1935- // Add reduction variables as entry block arguments to the task region
1936- llvm::SmallVector<mlir::Location> blockArgLocs (InReductionSyms.size (), loc);
1937- taskOp->getRegion (0 ).addArguments (inReductionTypes, blockArgLocs);
19381930 return taskOp;
19391931}
19401932
@@ -1949,21 +1941,23 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19491941 genTaskgroupClauses (converter, semaCtx, item->clauses , loc, clauseOps,
19501942 taskReductionSyms);
19511943
1944+ EntryBlockArgs taskgroupArgs;
1945+ taskgroupArgs.taskReduction .syms = taskReductionSyms;
1946+ taskgroupArgs.taskReduction .vars = clauseOps.taskReductionVars ;
1947+
1948+ auto genRegionEntryCB = [&](mlir::Operation *op) {
1949+ genEntryBlock (converter.getFirOpBuilder (), taskgroupArgs, op->getRegion (0 ));
1950+ return llvm::to_vector (taskgroupArgs.getSyms ());
1951+ };
1952+
19521953 OpWithBodyGenInfo genInfo =
19531954 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
19541955 llvm::omp::Directive::OMPD_taskgroup)
1955- .setClauses (&item->clauses );
1956+ .setClauses (&item->clauses )
1957+ .setGenRegionEntryCb (genRegionEntryCB);
19561958
19571959 auto taskgroupOp =
19581960 genOpWithBody<mlir::omp::TaskgroupOp>(genInfo, queue, item, clauseOps);
1959-
1960- llvm::SmallVector<mlir::Type> taskReductionTypes;
1961- for (const auto &taskreductionVar : clauseOps.taskReductionVars )
1962- taskReductionTypes.push_back (taskreductionVar.getType ());
1963-
1964- // Add reduction variables as entry block arguments to the taskgroup region
1965- llvm::SmallVector<mlir::Location> blockArgLocs (taskReductionSyms.size (), loc);
1966- taskgroupOp->getRegion (0 ).addArguments (taskReductionTypes, blockArgLocs);
19671961 return taskgroupOp;
19681962}
19691963
0 commit comments