Skip to content

Commit ac9e4e9

Browse files
authored
[Flang][OpenMP] Simplify entry block creation for BlockArgOpenMPOpInterface ops, NFC (#132036)
This patch adds the `OpWithBodyGenInfo::blockArgs` field and updates `createBodyOfOp()` to prevent the need for `BlockArgOpenMPOpInterface` operations to pass the same callback, minimizing chances of introducing inconsistent behavior.
1 parent 03adb0e commit ac9e4e9

File tree

1 file changed

+36
-40
lines changed

1 file changed

+36
-40
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,11 @@ struct OpWithBodyGenInfo {
10471047
return *this;
10481048
}
10491049

1050+
OpWithBodyGenInfo &setEntryBlockArgs(const EntryBlockArgs *value) {
1051+
blockArgs = value;
1052+
return *this;
1053+
}
1054+
10501055
OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) {
10511056
genRegionEntryCB = value;
10521057
return *this;
@@ -1073,8 +1078,12 @@ struct OpWithBodyGenInfo {
10731078
const List<Clause> *clauses = nullptr;
10741079
/// [in] if provided, processes the construct's data-sharing attributes.
10751080
DataSharingProcessor *dsp = nullptr;
1076-
/// [in] if provided, emits the op's region entry. Otherwise, an emtpy block
1077-
/// is created in the region.
1081+
/// [in] if provided, it is used to create the op's region entry block. It is
1082+
/// overriden when a \see genRegionEntryCB is provided. This is only valid for
1083+
/// operations implementing the \see mlir::omp::BlockArgOpenMPOpInterface.
1084+
const EntryBlockArgs *blockArgs = nullptr;
1085+
/// [in] if provided, it overrides the default op's region entry block
1086+
/// creation.
10781087
GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
10791088
/// [in] if set to `true`, skip generating nested evaluations and dispatching
10801089
/// any further leaf constructs.
@@ -1098,18 +1107,33 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
10981107
return undef.getDefiningOp();
10991108
};
11001109

1101-
// If an argument for the region is provided then create the block with that
1102-
// argument. Also update the symbol's address with the mlir argument value.
1103-
// e.g. For loops the argument is the induction variable. And all further
1104-
// uses of the induction variable should use this mlir value.
1110+
// Create the entry block for the region and collect its arguments for use
1111+
// within the region. The entry block will be created as follows:
1112+
// - By default, it will be empty and have no arguments.
1113+
// - Operations implementing the omp::BlockArgOpenMPOpInterface can set the
1114+
// `info.blockArgs` pointer so that block arguments will be those
1115+
// corresponding to entry block argument-generating clauses. Binding of
1116+
// Fortran symbols to the new MLIR values is done automatically.
1117+
// - If the `info.genRegionEntryCB` callback is set, it takes precedence and
1118+
// allows callers to manually create the entry block with its intended
1119+
// list of arguments and to bind these arguments to their corresponding
1120+
// Fortran symbols. This is used for e.g. loop induction variables.
11051121
auto regionArgs = [&]() -> llvm::SmallVector<const semantics::Symbol *> {
1106-
if (info.genRegionEntryCB != nullptr) {
1122+
if (info.genRegionEntryCB)
11071123
return info.genRegionEntryCB(&op);
1124+
1125+
if (info.blockArgs) {
1126+
genEntryBlock(firOpBuilder, *info.blockArgs, op.getRegion(0));
1127+
bindEntryBlockArgs(info.converter,
1128+
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
1129+
*info.blockArgs);
1130+
return llvm::to_vector(info.blockArgs->getSyms());
11081131
}
11091132

11101133
firOpBuilder.createBlock(&op.getRegion(0));
11111134
return {};
11121135
}();
1136+
11131137
// Mark the earliest insertion point.
11141138
mlir::Operation *marker = insertMarker(firOpBuilder);
11151139

@@ -1977,20 +2001,14 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19772001
mlir::omp::ParallelOperands &clauseOps,
19782002
const EntryBlockArgs &args, DataSharingProcessor *dsp,
19792003
bool isComposite = false) {
1980-
auto genRegionEntryCB = [&](mlir::Operation *op) {
1981-
genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
1982-
bindEntryBlockArgs(
1983-
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
1984-
return llvm::to_vector(args.getSyms());
1985-
};
1986-
19872004
assert((!enableDelayedPrivatization || dsp) &&
19882005
"expected valid DataSharingProcessor");
2006+
19892007
OpWithBodyGenInfo genInfo =
19902008
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
19912009
llvm::omp::Directive::OMPD_parallel)
19922010
.setClauses(&item->clauses)
1993-
.setGenRegionEntryCb(genRegionEntryCB)
2011+
.setEntryBlockArgs(&args)
19942012
.setGenSkeletonOnly(isComposite)
19952013
.setDataSharingProcessor(dsp);
19962014

@@ -2066,13 +2084,6 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20662084
mlir::Operation *terminator =
20672085
lower::genOpenMPTerminator(builder, sectionsOp, loc);
20682086

2069-
auto genRegionEntryCB = [&](mlir::Operation *op) {
2070-
genEntryBlock(builder, args, op->getRegion(0));
2071-
bindEntryBlockArgs(
2072-
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
2073-
return llvm::to_vector(args.getSyms());
2074-
};
2075-
20762087
// Generate nested SECTION constructs.
20772088
// This is done here rather than in genOMP([...], OpenMPSectionConstruct )
20782089
// because we need to run genReductionVars on each omp.section so that the
@@ -2096,7 +2107,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20962107
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
20972108
llvm::omp::Directive::OMPD_section)
20982109
.setClauses(&sectionQueue.begin()->clauses)
2099-
.setGenRegionEntryCb(genRegionEntryCB),
2110+
.setEntryBlockArgs(&args),
21002111
sectionQueue, sectionQueue.begin());
21012112
}
21022113

@@ -2435,20 +2446,12 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24352446
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
24362447
taskArgs.priv.vars = clauseOps.privateVars;
24372448

2438-
auto genRegionEntryCB = [&](mlir::Operation *op) {
2439-
genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
2440-
bindEntryBlockArgs(converter,
2441-
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
2442-
taskArgs);
2443-
return llvm::to_vector(taskArgs.priv.syms);
2444-
};
2445-
24462449
return genOpWithBody<mlir::omp::TaskOp>(
24472450
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
24482451
llvm::omp::Directive::OMPD_task)
24492452
.setClauses(&item->clauses)
24502453
.setDataSharingProcessor(&dsp)
2451-
.setGenRegionEntryCb(genRegionEntryCB),
2454+
.setEntryBlockArgs(&taskArgs),
24522455
queue, item, clauseOps);
24532456
}
24542457

@@ -2524,18 +2527,11 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25242527
args.reduction.syms = reductionSyms;
25252528
args.reduction.vars = clauseOps.reductionVars;
25262529

2527-
auto genRegionEntryCB = [&](mlir::Operation *op) {
2528-
genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
2529-
bindEntryBlockArgs(
2530-
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
2531-
return llvm::to_vector(args.getSyms());
2532-
};
2533-
25342530
return genOpWithBody<mlir::omp::TeamsOp>(
25352531
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
25362532
llvm::omp::Directive::OMPD_teams)
25372533
.setClauses(&item->clauses)
2538-
.setGenRegionEntryCb(genRegionEntryCB),
2534+
.setEntryBlockArgs(&args),
25392535
queue, item, clauseOps);
25402536
}
25412537

0 commit comments

Comments
 (0)