@@ -161,8 +161,18 @@ static LogicalResult checkImplementationStatus(Operation &op) {
161161 auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162162 omp::ClauseCancellationConstructType cancelledDirective =
163163 op.getCancelDirective ();
164- if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
165- result = todo (" cancel directive construct type not yet supported" );
164+ // Cancelling a taskloop is not yet supported because we don't yet have LLVM
165+ // IR conversion for taskloop
166+ if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
167+ Operation *parent = op->getParentOp ();
168+ while (parent) {
169+ if (parent->getDialect () == op->getDialect ())
170+ break ;
171+ parent = parent->getParentOp ();
172+ }
173+ if (isa_and_nonnull<omp::TaskloopOp>(parent))
174+ result = todo (" cancel directive inside of taskloop" );
175+ }
166176 };
167177 auto checkDepend = [&todo](auto op, LogicalResult &result) {
168178 if (!op.getDependVars ().empty () || op.getDependKinds ())
@@ -1889,6 +1899,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
18891899 }
18901900}
18911901
1902+ // / Shared implementation of a callback which adds a termiator for the new block
1903+ // / created for the branch taken when an openmp construct is cancelled. The
1904+ // / terminator is saved in \p cancelTerminators. This callback is invoked only
1905+ // / if there is cancellation inside of the taskgroup body.
1906+ // / The terminator will need to be fixed to branch to the correct block to
1907+ // / cleanup the construct.
1908+ static void
1909+ pushCancelFinalizationCB (SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
1910+ llvm::IRBuilderBase &llvmBuilder,
1911+ llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
1912+ llvm::omp::Directive cancelDirective) {
1913+ auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
1914+ llvm::IRBuilderBase::InsertPointGuard guard (llvmBuilder);
1915+
1916+ // ip is currently in the block branched to if cancellation occured.
1917+ // We need to create a branch to terminate that block.
1918+ llvmBuilder.restoreIP (ip);
1919+
1920+ // We must still clean up the construct after cancelling it, so we need to
1921+ // branch to the block that finalizes the taskgroup.
1922+ // That block has not been created yet so use this block as a dummy for now
1923+ // and fix this after creating the operation.
1924+ cancelTerminators.push_back (llvmBuilder.CreateBr (ip.getBlock ()));
1925+ return llvm::Error::success ();
1926+ };
1927+ // We have to add the cleanup to the OpenMPIRBuilder before the body gets
1928+ // created in case the body contains omp.cancel (which will then expect to be
1929+ // able to find this cleanup callback).
1930+ ompBuilder.pushFinalizationCB (
1931+ {finiCB, cancelDirective, constructIsCancellable (op)});
1932+ }
1933+
1934+ // / If we cancelled the construct, we should branch to the finalization block of
1935+ // / that construct. OMPIRBuilder structures the CFG such that the cleanup block
1936+ // / is immediately before the continuation block. Now this finalization has
1937+ // / been created we can fix the branch.
1938+ static void
1939+ popCancelFinalizationCB (const ArrayRef<llvm::BranchInst *> cancelTerminators,
1940+ llvm::OpenMPIRBuilder &ompBuilder,
1941+ const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
1942+ ompBuilder.popFinalizationCB ();
1943+ llvm::BasicBlock *constructFini = afterIP.getBlock ()->getSinglePredecessor ();
1944+ for (llvm::BranchInst *cancelBranch : cancelTerminators) {
1945+ assert (cancelBranch->getNumSuccessors () == 1 &&
1946+ " cancel branch should have one target" );
1947+ cancelBranch->setSuccessor (0 , constructFini);
1948+ }
1949+ }
1950+
18921951namespace {
18931952// / TaskContextStructManager takes care of creating and freeing a structure
18941953// / containing information needed by the task body to execute.
@@ -2202,6 +2261,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
22022261 return llvm::Error::success ();
22032262 };
22042263
2264+ llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder ();
2265+ SmallVector<llvm::BranchInst *> cancelTerminators;
2266+ // The directive to match here is OMPD_taskgroup because it is the taskgroup
2267+ // which is canceled. This is handled here because it is the task's cleanup
2268+ // block which should be branched to.
2269+ pushCancelFinalizationCB (cancelTerminators, builder, ompBuilder, taskOp,
2270+ llvm::omp::Directive::OMPD_taskgroup);
2271+
22052272 SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
22062273 buildDependData (taskOp.getDependKinds (), taskOp.getDependVars (),
22072274 moduleTranslation, dds);
@@ -2219,6 +2286,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
22192286 if (failed (handleError (afterIP, *taskOp)))
22202287 return failure ();
22212288
2289+ // Set the correct branch target for task cancellation
2290+ popCancelFinalizationCB (cancelTerminators, ompBuilder, afterIP.get ());
2291+
22222292 builder.restoreIP (*afterIP);
22232293 return success ();
22242294}
@@ -2349,28 +2419,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23492419 : llvm::omp::WorksharingLoopType::ForStaticLoop;
23502420
23512421 SmallVector<llvm::BranchInst *> cancelTerminators;
2352- // This callback is invoked only if there is cancellation inside of the wsloop
2353- // body.
2354- auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2355- llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder ;
2356- llvm::IRBuilderBase::InsertPointGuard guard (llvmBuilder);
2357-
2358- // ip is currently in the block branched to if cancellation occured.
2359- // We need to create a branch to terminate that block.
2360- llvmBuilder.restoreIP (ip);
2361-
2362- // We must still clean up the wsloop after cancelling it, so we need to
2363- // branch to the block that finalizes the wsloop.
2364- // That block has not been created yet so use this block as a dummy for now
2365- // and fix this after creating the wsloop.
2366- cancelTerminators.push_back (llvmBuilder.CreateBr (ip.getBlock ()));
2367- return llvm::Error::success ();
2368- };
2369- // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2370- // created in case the body contains omp.cancel (which will then expect to be
2371- // able to find this cleanup callback).
2372- ompBuilder->pushFinalizationCB ({finiCB, llvm::omp::Directive::OMPD_for,
2373- constructIsCancellable (wsloopOp)});
2422+ pushCancelFinalizationCB (cancelTerminators, builder, *ompBuilder, wsloopOp,
2423+ llvm::omp::Directive::OMPD_for);
23742424
23752425 llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
23762426 llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions (
@@ -2393,18 +2443,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23932443 if (failed (handleError (wsloopIP, opInst)))
23942444 return failure ();
23952445
2396- ompBuilder->popFinalizationCB ();
2397- if (!cancelTerminators.empty ()) {
2398- // If we cancelled the loop, we should branch to the finalization block of
2399- // the wsloop (which is always immediately before the loop continuation
2400- // block). Now the finalization has been created, we can fix the branch.
2401- llvm::BasicBlock *wsloopFini = wsloopIP->getBlock ()->getSinglePredecessor ();
2402- for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2403- assert (cancelBranch->getNumSuccessors () == 1 &&
2404- " cancel branch should have one target" );
2405- cancelBranch->setSuccessor (0 , wsloopFini);
2406- }
2407- }
2446+ // Set the correct branch target for task cancellation
2447+ popCancelFinalizationCB (cancelTerminators, *ompBuilder, wsloopIP.get ());
24082448
24092449 // Process the reductions if required.
24102450 if (failed (createReductionsAndCleanup (
@@ -3060,12 +3100,12 @@ static llvm::omp::Directive convertCancellationConstructType(
30603100static LogicalResult
30613101convertOmpCancel (omp::CancelOp op, llvm::IRBuilderBase &builder,
30623102 LLVM::ModuleTranslation &moduleTranslation) {
3063- llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3064- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3065-
30663103 if (failed (checkImplementationStatus (*op.getOperation ())))
30673104 return failure ();
30683105
3106+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3107+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3108+
30693109 llvm::Value *ifCond = nullptr ;
30703110 if (Value ifVar = op.getIfExpr ())
30713111 ifCond = moduleTranslation.lookupValue (ifVar);
@@ -3088,12 +3128,12 @@ static LogicalResult
30883128convertOmpCancellationPoint (omp::CancellationPointOp op,
30893129 llvm::IRBuilderBase &builder,
30903130 LLVM::ModuleTranslation &moduleTranslation) {
3091- llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3092- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3093-
30943131 if (failed (checkImplementationStatus (*op.getOperation ())))
30953132 return failure ();
30963133
3134+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3135+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3136+
30973137 llvm::omp::Directive cancelledDirective =
30983138 convertCancellationConstructType (op.getCancelDirective ());
30993139
0 commit comments