Skip to content

Commit a18ac65

Browse files
committed
[MLIR][OpenMP] Prevent loop wrapper translation crashes
This patch updates the `convertOmpOpRegions` translation function to prevent calling it for a loop wrapper region from causing a compiler crash due to a lack of terminator operations. This problem is currently not triggered because there are no cases for which the region of a loop wrapper is passed to that function. This might have to change in order to support composite construct translation to LLVM IR.
1 parent 85eec89 commit a18ac65

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
391391
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
392392
LLVM::ModuleTranslation &moduleTranslation,
393393
SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
394+
bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
395+
394396
llvm::BasicBlock *continuationBlock =
395397
splitBB(builder, true, "omp.region.cont");
396398
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -407,30 +409,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
407409

408410
// Terminators (namely YieldOp) may be forwarding values to the region that
409411
// need to be available in the continuation block. Collect the types of these
410-
// operands in preparation of creating PHI nodes.
412+
// operands in preparation of creating PHI nodes. This is skipped for loop
413+
// wrapper operations, for which we know in advance they have no terminators.
411414
SmallVector<llvm::Type *> continuationBlockPHITypes;
412-
bool operandsProcessed = false;
413415
unsigned numYields = 0;
414-
for (Block &bb : region.getBlocks()) {
415-
if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
416-
if (!operandsProcessed) {
417-
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
418-
continuationBlockPHITypes.push_back(
419-
moduleTranslation.convertType(yield->getOperand(i).getType()));
420-
}
421-
operandsProcessed = true;
422-
} else {
423-
assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
424-
"mismatching number of values yielded from the region");
425-
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
426-
llvm::Type *operandType =
427-
moduleTranslation.convertType(yield->getOperand(i).getType());
428-
(void)operandType;
429-
assert(continuationBlockPHITypes[i] == operandType &&
430-
"values of mismatching types yielded from the region");
416+
417+
if (!isLoopWrapper) {
418+
bool operandsProcessed = false;
419+
for (Block &bb : region.getBlocks()) {
420+
if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
421+
if (!operandsProcessed) {
422+
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
423+
continuationBlockPHITypes.push_back(
424+
moduleTranslation.convertType(yield->getOperand(i).getType()));
425+
}
426+
operandsProcessed = true;
427+
} else {
428+
assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
429+
"mismatching number of values yielded from the region");
430+
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
431+
llvm::Type *operandType =
432+
moduleTranslation.convertType(yield->getOperand(i).getType());
433+
(void)operandType;
434+
assert(continuationBlockPHITypes[i] == operandType &&
435+
"values of mismatching types yielded from the region");
436+
}
431437
}
438+
numYields++;
432439
}
433-
numYields++;
434440
}
435441
}
436442

@@ -468,6 +474,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
468474
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
469475
return llvm::make_error<PreviouslyReportedError>();
470476

477+
// Create a direct branch here for loop wrappers to prevent their lack of a
478+
// terminator from causing a crash below.
479+
if (isLoopWrapper) {
480+
builder.CreateBr(continuationBlock);
481+
continue;
482+
}
483+
471484
// Special handling for `omp.yield` and `omp.terminator` (we may have more
472485
// than one): they return the control to the parent OpenMP dialect operation
473486
// so replace them with the branch to the continuation block. We handle this

0 commit comments

Comments
 (0)