Skip to content

Commit d32e05d

Browse files
committed
[Flang][OpenMP] Simplify entry block creation for BlockArgOpenMPOpInterface ops, NFC
This patch adds the `OpWithBodyGenInfo::blockArgs` field and updates `createBodyOfOp()` to prevent the need for `BlockArgOpenMPOpInterface` operations to pass the same callback, minimizing chances of introducing inconsistent behavior.
1 parent 3adf2b0 commit d32e05d

File tree

1 file changed

+36
-40
lines changed

1 file changed

+36
-40
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,11 @@ struct OpWithBodyGenInfo {
10481048
return *this;
10491049
}
10501050

1051+
OpWithBodyGenInfo &setEntryBlockArgs(const EntryBlockArgs *value) {
1052+
blockArgs = value;
1053+
return *this;
1054+
}
1055+
10511056
OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) {
10521057
genRegionEntryCB = value;
10531058
return *this;
@@ -1074,8 +1079,12 @@ struct OpWithBodyGenInfo {
10741079
const List<Clause> *clauses = nullptr;
10751080
/// [in] if provided, processes the construct's data-sharing attributes.
10761081
DataSharingProcessor *dsp = nullptr;
1077-
/// [in] if provided, emits the op's region entry. Otherwise, an emtpy block
1078-
/// is created in the region.
1082+
/// [in] if provided, it is used to create the op's region entry block. It is
1083+
/// overriden when a \see genRegionEntryCB is provided. This is only valid for
1084+
/// operations implementing the \see mlir::omp::BlockArgOpenMPOpInterface.
1085+
const EntryBlockArgs *blockArgs = nullptr;
1086+
/// [in] if provided, it overrides the default op's region entry block
1087+
/// creation.
10791088
GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
10801089
/// [in] if set to `true`, skip generating nested evaluations and dispatching
10811090
/// any further leaf constructs.
@@ -1099,18 +1108,33 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
10991108
return undef.getDefiningOp();
11001109
};
11011110

1102-
// If an argument for the region is provided then create the block with that
1103-
// argument. Also update the symbol's address with the mlir argument value.
1104-
// e.g. For loops the argument is the induction variable. And all further
1105-
// uses of the induction variable should use this mlir value.
1111+
// Create the entry block for the region and collect its arguments for use
1112+
// within the region. The entry block will be created as follows:
1113+
// - By default, it will be empty and have no arguments.
1114+
// - Operations implementing the omp::BlockArgOpenMPOpInterface can set the
1115+
// `info.blockArgs` pointer so that block arguments will be those
1116+
// corresponding to entry block argument-generating clauses. Binding of
1117+
// Fortran symbols to the new MLIR values is done automatically.
1118+
// - If the `info.genRegionEntryCB` callback is set, it takes precedence and
1119+
// allows callers to manually create the entry block with its intended
1120+
// list of arguments and to bind these arguments to their corresponding
1121+
// Fortran symbols. This is used for e.g. loop induction variables.
11061122
auto regionArgs = [&]() -> llvm::SmallVector<const semantics::Symbol *> {
1107-
if (info.genRegionEntryCB != nullptr) {
1123+
if (info.genRegionEntryCB)
11081124
return info.genRegionEntryCB(&op);
1125+
1126+
if (info.blockArgs) {
1127+
genEntryBlock(firOpBuilder, *info.blockArgs, op.getRegion(0));
1128+
bindEntryBlockArgs(info.converter,
1129+
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
1130+
*info.blockArgs);
1131+
return llvm::to_vector(info.blockArgs->getSyms());
11091132
}
11101133

11111134
firOpBuilder.createBlock(&op.getRegion(0));
11121135
return {};
11131136
}();
1137+
11141138
// Mark the earliest insertion point.
11151139
mlir::Operation *marker = insertMarker(firOpBuilder);
11161140

@@ -1978,20 +2002,14 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19782002
mlir::omp::ParallelOperands &clauseOps,
19792003
const EntryBlockArgs &args, DataSharingProcessor *dsp,
19802004
bool isComposite = false) {
1981-
auto genRegionEntryCB = [&](mlir::Operation *op) {
1982-
genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
1983-
bindEntryBlockArgs(
1984-
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
1985-
return llvm::to_vector(args.getSyms());
1986-
};
1987-
19882005
assert((!enableDelayedPrivatization || dsp) &&
19892006
"expected valid DataSharingProcessor");
2007+
19902008
OpWithBodyGenInfo genInfo =
19912009
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
19922010
llvm::omp::Directive::OMPD_parallel)
19932011
.setClauses(&item->clauses)
1994-
.setGenRegionEntryCb(genRegionEntryCB)
2012+
.setEntryBlockArgs(&args)
19952013
.setGenSkeletonOnly(isComposite)
19962014
.setDataSharingProcessor(dsp);
19972015

@@ -2067,13 +2085,6 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20672085
mlir::Operation *terminator =
20682086
lower::genOpenMPTerminator(builder, sectionsOp, loc);
20692087

2070-
auto genRegionEntryCB = [&](mlir::Operation *op) {
2071-
genEntryBlock(builder, args, op->getRegion(0));
2072-
bindEntryBlockArgs(
2073-
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
2074-
return llvm::to_vector(args.getSyms());
2075-
};
2076-
20772088
// Generate nested SECTION constructs.
20782089
// This is done here rather than in genOMP([...], OpenMPSectionConstruct )
20792090
// because we need to run genReductionVars on each omp.section so that the
@@ -2097,7 +2108,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20972108
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
20982109
llvm::omp::Directive::OMPD_section)
20992110
.setClauses(&sectionQueue.begin()->clauses)
2100-
.setGenRegionEntryCb(genRegionEntryCB),
2111+
.setEntryBlockArgs(&args),
21012112
sectionQueue, sectionQueue.begin());
21022113
}
21032114

@@ -2436,20 +2447,12 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24362447
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
24372448
taskArgs.priv.vars = clauseOps.privateVars;
24382449

2439-
auto genRegionEntryCB = [&](mlir::Operation *op) {
2440-
genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
2441-
bindEntryBlockArgs(converter,
2442-
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
2443-
taskArgs);
2444-
return llvm::to_vector(taskArgs.priv.syms);
2445-
};
2446-
24472450
return genOpWithBody<mlir::omp::TaskOp>(
24482451
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
24492452
llvm::omp::Directive::OMPD_task)
24502453
.setClauses(&item->clauses)
24512454
.setDataSharingProcessor(&dsp)
2452-
.setGenRegionEntryCb(genRegionEntryCB),
2455+
.setEntryBlockArgs(&taskArgs),
24532456
queue, item, clauseOps);
24542457
}
24552458

@@ -2525,18 +2528,11 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25252528
args.reduction.syms = reductionSyms;
25262529
args.reduction.vars = clauseOps.reductionVars;
25272530

2528-
auto genRegionEntryCB = [&](mlir::Operation *op) {
2529-
genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
2530-
bindEntryBlockArgs(
2531-
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
2532-
return llvm::to_vector(args.getSyms());
2533-
};
2534-
25352531
return genOpWithBody<mlir::omp::TeamsOp>(
25362532
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
25372533
llvm::omp::Directive::OMPD_teams)
25382534
.setClauses(&item->clauses)
2539-
.setGenRegionEntryCb(genRegionEntryCB),
2535+
.setEntryBlockArgs(&args),
25402536
queue, item, clauseOps);
25412537
}
25422538

0 commit comments

Comments
 (0)