Skip to content

Commit 50619d2

Browse files
committed
[MLIR][OpenMP] Reduce overhead of target compilation
This patch avoids calling `TargetOp::getInnermostCapturedOmpOp` multiple times while initializing the default and runtime target attributes during MLIR to LLVM IR translation of `omp.target` operations. Because that function is potentially expensive, its result is now computed once per `omp.target` operation and passed along to the helpers that need it, which should help keep compile times lower.
1 parent cbeae3e commit 50619d2

File tree

3 files changed

+32
-14
lines changed

3 files changed

+32
-14
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1343,11 +1343,20 @@ def TargetOp : OpenMP_Op<"target", traits = [
13431343
///
13441344
/// If there are omp.loop_nest operations in the sequence of nested
13451345
/// operations, the top level one will be the one captured.
1346+
///
1347+
/// This is a relatively expensive operation, so if it is needed at the same
1348+
/// time as the result of `getKernelExecFlags` it is preferable to cache the
1349+
/// result of this function and pass it along.
13461350
Operation *getInnermostCapturedOmpOp();
13471351

13481352
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
13491353
/// contents of the target region.
1350-
llvm::omp::OMPTgtExecModeFlags getKernelExecFlags();
1354+
///
1355+
/// \param capturedOp result of a still valid (no modifications made to any
1356+
/// nested operations) previous call to `getInnermostCapturedOmpOp()`. If
1357+
/// not specified, this will call that function itself instead.
1358+
llvm::omp::OMPTgtExecModeFlags
1359+
getKernelExecFlags(std::optional<Operation *> capturedOp = std::nullopt);
13511360
}] # clausesExtraClassDeclaration;
13521361

13531362
let assemblyFormat = clausesAssemblyFormat # [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2025,17 +2025,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20252025
return capturedOp;
20262026
}
20272027

2028-
llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
2028+
llvm::omp::OMPTgtExecModeFlags
2029+
TargetOp::getKernelExecFlags(std::optional<Operation *> capturedOp) {
20292030
using namespace llvm::omp;
20302031

2032+
// Use a cached operation, if passed in. Otherwise, find the innermost
2033+
// captured operation.
2034+
if (!capturedOp)
2035+
capturedOp = getInnermostCapturedOmpOp();
2036+
assert(*capturedOp == getInnermostCapturedOmpOp() &&
2037+
"unexpected captured op");
2038+
20312039
// Make sure this region is capturing a loop. Otherwise, it's a generic
20322040
// kernel.
2033-
Operation *capturedOp = getInnermostCapturedOmpOp();
2034-
if (!isa_and_present<LoopNestOp>(capturedOp))
2041+
if (!isa_and_present<LoopNestOp>(*capturedOp))
20352042
return OMP_TGT_EXEC_MODE_GENERIC;
20362043

20372044
SmallVector<LoopWrapperInterface> wrappers;
2038-
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
2045+
cast<LoopNestOp>(*capturedOp).gatherWrappers(wrappers);
20392046
assert(!wrappers.empty());
20402047

20412048
// Ignore optional SIMD leaf construct.

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4558,11 +4558,10 @@ static std::optional<int64_t> extractConstInteger(Value value) {
45584558
/// function for the target region, so that they can be used to initialize the
45594559
/// corresponding global `ConfigurationEnvironmentTy` structure.
45604560
static void
4561-
initTargetDefaultAttrs(omp::TargetOp targetOp,
4561+
initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
45624562
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
45634563
bool isTargetDevice) {
45644564
// TODO: Handle constant 'if' clauses.
4565-
Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();
45664565

45674566
Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
45684567
if (!isTargetDevice) {
@@ -4644,7 +4643,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
46444643
combinedMaxThreadsVal = maxThreadsVal;
46454644

46464645
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4647-
attrs.ExecFlags = targetOp.getKernelExecFlags();
4646+
attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
46484647
attrs.MinTeams = minTeamsVal;
46494648
attrs.MaxTeams.front() = maxTeamsVal;
46504649
attrs.MinThreads = 1;
@@ -4660,10 +4659,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
46604659
static void
46614660
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
46624661
LLVM::ModuleTranslation &moduleTranslation,
4663-
omp::TargetOp targetOp,
4662+
omp::TargetOp targetOp, Operation *capturedOp,
46644663
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4665-
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4666-
targetOp.getInnermostCapturedOmpOp());
4664+
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
46674665
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
46684666

46694667
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
@@ -4690,7 +4688,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
46904688
if (numThreads)
46914689
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
46924690

4693-
if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4691+
if (targetOp.getKernelExecFlags(capturedOp) !=
4692+
llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
46944693
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
46954694
attrs.LoopTripCount = nullptr;
46964695

@@ -4940,12 +4939,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
49404939

49414940
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
49424941
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4943-
initTargetDefaultAttrs(targetOp, defaultAttrs, isTargetDevice);
4942+
Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
4943+
initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
4944+
isTargetDevice);
49444945

49454946
// Collect host-evaluated values needed to properly launch the kernel from the
49464947
// host.
49474948
if (!isTargetDevice)
4948-
initTargetRuntimeAttrs(builder, moduleTranslation, targetOp, runtimeAttrs);
4949+
initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
4950+
targetCapturedOp, runtimeAttrs);
49494951

49504952
// Pass host-evaluated values as parameters to the kernel / host fallback,
49514953
// except if they are constants. In any case, map the MLIR block argument to

0 commit comments

Comments
 (0)