Skip to content

Commit 201a1fa

Browse files
committed
[MLIR][OpenMP][OMPIRBuilder] Error propagation across callbacks
This is a small proof of concept showing an approach to communicate errors between MLIR to LLVM IR translation of the OpenMP dialect and the OMPIRBuilder. It only implements the approach for a single case, so it doesn't compile or run, since it's only intended to show how it could look like and discuss it before investing too much effort on a full implementation. The main idea is to use `llvm::Error` objects returned by callbacks passed to `OMPIRBuilder` codegen functions that they can then check and forward back to the caller to avoid continuing after an error has been hit. The caller then emits an MLIR error diagnostic based on that and stops the translation process. This should prevent encountering any unsupported operations or arguments, or any other unexpected error from resulting in a compiler crash. Instead, a descriptive error message is presented to users.
1 parent 15d8576 commit 201a1fa

File tree

3 files changed

+36
-25
lines changed

3 files changed

+36
-25
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -589,15 +589,19 @@ class OpenMPIRBuilder {
589589
/// not be split.
590590
/// \param CodeGenIP is the insertion point at which the body code should be
591591
/// placed.
592+
///
593+
/// \return an error, if any were triggered during execution.
592594
using BodyGenCallbackTy =
593-
function_ref<void(InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
595+
function_ref<Error(InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
594596

595597
// This is created primarily for sections construct as llvm::function_ref
596598
// (BodyGenCallbackTy) is not storable (as described in the comments of
597599
// function_ref class - function_ref contains non-ownable reference
598600
// to the callable.
601+
///
602+
/// \return an error, if any were triggered during execution.
599603
using StorableBodyGenCallbackTy =
600-
std::function<void(InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
604+
std::function<Error(InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
601605

602606
/// Callback type for loop body code generation.
603607
///
@@ -607,8 +611,10 @@ class OpenMPIRBuilder {
607611
/// terminated with an unconditional branch to the loop
608612
/// latch.
609613
/// \param IndVar is the induction variable usable at the insertion point.
614+
///
615+
/// \return an error, if any were triggered during execution.
610616
using LoopBodyGenCallbackTy =
611-
function_ref<void(InsertPointTy CodeGenIP, Value *IndVar)>;
617+
function_ref<Error(InsertPointTy CodeGenIP, Value *IndVar)>;
612618

613619
/// Callback type for variable privatization (think copy & default
614620
/// constructor).
@@ -626,9 +632,9 @@ class OpenMPIRBuilder {
626632
/// \param ReplVal The replacement value, thus a copy or new created version
627633
/// of \p Inner.
628634
///
629-
/// \returns The new insertion point where code generation continues and
630-
/// \p ReplVal the replacement value.
631-
using PrivatizeCallbackTy = function_ref<InsertPointTy(
635+
/// \returns The new insertion point where code generation continues or an
636+
/// error, and \p ReplVal the replacement value.
637+
using PrivatizeCallbackTy = function_ref<Expected<InsertPointTy>(
632638
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &Original,
633639
Value &Inner, Value *&ReplVal)>;
634640

@@ -1262,9 +1268,9 @@ class OpenMPIRBuilder {
12621268
/// \param Loc The location where the taskgroup construct was encountered.
12631269
/// \param AllocaIP The insertion point to be used for alloca instructions.
12641270
/// \param BodyGenCB Callback that will generate the region code.
1265-
InsertPointTy createTaskgroup(const LocationDescription &Loc,
1266-
InsertPointTy AllocaIP,
1267-
BodyGenCallbackTy BodyGenCB);
1271+
Expected<InsertPointTy> createTaskgroup(const LocationDescription &Loc,
1272+
InsertPointTy AllocaIP,
1273+
BodyGenCallbackTy BodyGenCB);
12681274

12691275
using FileIdentifierInfoCallbackTy =
12701276
std::function<std::tuple<std::string, uint64_t>()>;

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,7 +2048,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
20482048
return Builder.saveIP();
20492049
}
20502050

2051-
OpenMPIRBuilder::InsertPointTy
2051+
Expected<OpenMPIRBuilder::InsertPointTy>
20522052
OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
20532053
InsertPointTy AllocaIP,
20542054
BodyGenCallbackTy BodyGenCB) {
@@ -2066,7 +2066,8 @@ OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
20662066
Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
20672067

20682068
BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
2069-
BodyGenCB(AllocaIP, Builder.saveIP());
2069+
if (auto Err = BodyGenCB(AllocaIP, Builder.saveIP()))
2070+
return std::move(Err);
20702071

20712072
Builder.SetInsertPoint(TaskgroupExitBB);
20722073
// Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
137137
/// region, and a branch from any block with an successor-less OpenMP terminator
138138
/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
139139
/// of the continuation block if provided.
140-
static llvm::BasicBlock *convertOmpOpRegions(
140+
static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
141141
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
142-
LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus,
142+
LLVM::ModuleTranslation &moduleTranslation,
143143
SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
144144
llvm::BasicBlock *continuationBlock =
145145
splitBB(builder, true, "omp.region.cont");
@@ -215,10 +215,8 @@ static llvm::BasicBlock *convertOmpOpRegions(
215215

216216
llvm::IRBuilderBase::InsertPointGuard guard(builder);
217217
if (failed(
218-
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
219-
bodyGenStatus = failure();
220-
return continuationBlock;
221-
}
218+
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
219+
return llvm::createStringError("failed region translation");
222220

223221
// Special handling for `omp.yield` and `omp.terminator` (we may have more
224222
// than one): they return the control to the parent OpenMP dialect operation
@@ -1145,20 +1143,26 @@ static LogicalResult
11451143
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
11461144
LLVM::ModuleTranslation &moduleTranslation) {
11471145
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1148-
LogicalResult bodyGenStatus = success();
1149-
if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty()) {
1146+
if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty())
11501147
return tgOp.emitError("unhandled clauses for translation to LLVM IR");
1151-
}
1148+
11521149
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
11531150
builder.restoreIP(codegenIP);
1154-
convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region", builder,
1155-
moduleTranslation, bodyGenStatus);
1151+
return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
1152+
builder, moduleTranslation)
1153+
.takeError();
11561154
};
1155+
11571156
InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
11581157
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1159-
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTaskgroup(
1160-
ompLoc, allocaIP, bodyCB));
1161-
return bodyGenStatus;
1158+
auto result = moduleTranslation.getOpenMPBuilder()->createTaskgroup(
1159+
ompLoc, allocaIP, bodyCB);
1160+
1161+
if (!result)
1162+
return tgOp.emitError(llvm::toString(result.takeError()));
1163+
1164+
builder.restoreIP(*result);
1165+
return success();
11621166
}
11631167

11641168
static LogicalResult

0 commit comments

Comments
 (0)