@@ -1253,21 +1253,20 @@ static void genTaskClauses(
12531253 lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
12541254 lower::StatementContext &stmtCtx, const List<Clause> &clauses,
12551255 mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
1256- llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms ) {
1256+ llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms ) {
12571257 ClauseProcessor cp (converter, semaCtx, clauses);
12581258 cp.processAllocate (clauseOps);
12591259 cp.processDepend (clauseOps);
12601260 cp.processFinal (stmtCtx, clauseOps);
12611261 cp.processIf (llvm::omp::Directive::OMPD_task, clauseOps);
1262- cp.processInReduction (loc, clauseOps, InReductionSyms );
1262+ cp.processInReduction (loc, clauseOps, inReductionSyms );
12631263 cp.processMergeable (clauseOps);
12641264 cp.processPriority (stmtCtx, clauseOps);
12651265 cp.processUntied (clauseOps);
12661266 cp.processDetach (clauseOps);
12671267 // TODO Support delayed privatization.
12681268
1269- cp.processTODO <clause::Affinity>(
1270- loc, llvm::omp::Directive::OMPD_task);
1269+ cp.processTODO <clause::Affinity>(loc, llvm::omp::Directive::OMPD_task);
12711270}
12721271
12731272static void genTaskgroupClauses (
@@ -1888,9 +1887,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
18881887 ConstructQueue::const_iterator item) {
18891888 lower::StatementContext stmtCtx;
18901889 mlir::omp::TaskOperands clauseOps;
1891- llvm::SmallVector<const semantics::Symbol *> InReductionSyms ;
1890+ llvm::SmallVector<const semantics::Symbol *> inReductionSyms ;
18921891 genTaskClauses (converter, semaCtx, stmtCtx, item->clauses , loc, clauseOps,
1893- InReductionSyms );
1892+ inReductionSyms );
18941893
18951894 if (!enableDelayedPrivatization)
18961895 return genOpWithBody<mlir::omp::TaskOp>(
@@ -1907,7 +1906,7 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19071906 EntryBlockArgs taskArgs;
19081907 taskArgs.priv .syms = dsp.getDelayedPrivSymbols ();
19091908 taskArgs.priv .vars = clauseOps.privateVars ;
1910- taskArgs.inReduction .syms = InReductionSyms ;
1909+ taskArgs.inReduction .syms = inReductionSyms ;
19111910 taskArgs.inReduction .vars = clauseOps.inReductionVars ;
19121911
19131912 auto genRegionEntryCB = [&](mlir::Operation *op) {
@@ -1927,14 +1926,6 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19271926
19281927 auto taskOp =
19291928 genOpWithBody<mlir::omp::TaskOp>(genInfo, queue, item, clauseOps);
1930-
1931- llvm::SmallVector<mlir::Type> inReductionTypes;
1932- for (const auto &inreductionVar : clauseOps.inReductionVars )
1933- inReductionTypes.push_back (inreductionVar.getType ());
1934-
1935- // Add reduction variables as entry block arguments to the task region
1936- llvm::SmallVector<mlir::Location> blockArgLocs (InReductionSyms.size (), loc);
1937- taskOp->getRegion (0 ).addArguments (inReductionTypes, blockArgLocs);
19381929 return taskOp;
19391930}
19401931
@@ -1949,21 +1940,23 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19491940 genTaskgroupClauses (converter, semaCtx, item->clauses , loc, clauseOps,
19501941 taskReductionSyms);
19511942
1943+ EntryBlockArgs taskgroupArgs;
1944+ taskgroupArgs.taskReduction .syms = taskReductionSyms;
1945+ taskgroupArgs.taskReduction .vars = clauseOps.taskReductionVars ;
1946+
1947+ auto genRegionEntryCB = [&](mlir::Operation *op) {
1948+ genEntryBlock (converter.getFirOpBuilder (), taskgroupArgs, op->getRegion (0 ));
1949+ return llvm::to_vector (taskgroupArgs.getSyms ());
1950+ };
1951+
19521952 OpWithBodyGenInfo genInfo =
19531953 OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
19541954 llvm::omp::Directive::OMPD_taskgroup)
1955- .setClauses (&item->clauses );
1955+ .setClauses (&item->clauses )
1956+ .setGenRegionEntryCb (genRegionEntryCB);
19561957
19571958 auto taskgroupOp =
19581959 genOpWithBody<mlir::omp::TaskgroupOp>(genInfo, queue, item, clauseOps);
1959-
1960- llvm::SmallVector<mlir::Type> taskReductionTypes;
1961- for (const auto &taskreductionVar : clauseOps.taskReductionVars )
1962- taskReductionTypes.push_back (taskreductionVar.getType ());
1963-
1964- // Add reduction variables as entry block arguments to the taskgroup region
1965- llvm::SmallVector<mlir::Location> blockArgLocs (taskReductionSyms.size (), loc);
1966- taskgroupOp->getRegion (0 ).addArguments (taskReductionTypes, blockArgLocs);
19671960 return taskgroupOp;
19681961}
19691962
0 commit comments