@@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
254254 }
255255};
256256
257+ class CIRTernaryOpFlattening : public mlir ::OpRewritePattern<cir::TernaryOp> {
258+ public:
259+ using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
260+
261+ mlir::LogicalResult
262+ matchAndRewrite (cir::TernaryOp op,
263+ mlir::PatternRewriter &rewriter) const override {
264+ Location loc = op->getLoc ();
265+ Block *condBlock = rewriter.getInsertionBlock ();
266+ Block::iterator opPosition = rewriter.getInsertionPoint ();
267+ Block *remainingOpsBlock = rewriter.splitBlock (condBlock, opPosition);
268+ llvm::SmallVector<mlir::Location, 2 > locs;
269+ // Ternary result is optional, make sure to populate the location only
270+ // when relevant.
271+ if (op->getResultTypes ().size ())
272+ locs.push_back (loc);
273+ auto *continueBlock =
274+ rewriter.createBlock (remainingOpsBlock, op->getResultTypes (), locs);
275+ rewriter.create <cir::BrOp>(loc, remainingOpsBlock);
276+
277+ Region &trueRegion = op.getTrueRegion ();
278+ Block *trueBlock = &trueRegion.front ();
279+ mlir::Operation *trueTerminator = trueRegion.back ().getTerminator ();
280+ rewriter.setInsertionPointToEnd (&trueRegion.back ());
281+ auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
282+
283+ rewriter.replaceOpWithNewOp <cir::BrOp>(trueYieldOp, trueYieldOp.getArgs (),
284+ continueBlock);
285+ rewriter.inlineRegionBefore (trueRegion, continueBlock);
286+
287+ Block *falseBlock = continueBlock;
288+ Region &falseRegion = op.getFalseRegion ();
289+
290+ falseBlock = &falseRegion.front ();
291+ mlir::Operation *falseTerminator = falseRegion.back ().getTerminator ();
292+ rewriter.setInsertionPointToEnd (&falseRegion.back ());
293+ cir::YieldOp falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
294+ rewriter.replaceOpWithNewOp <cir::BrOp>(falseYieldOp, falseYieldOp.getArgs (),
295+ continueBlock);
296+ rewriter.inlineRegionBefore (falseRegion, continueBlock);
297+
298+ rewriter.setInsertionPointToEnd (condBlock);
299+ rewriter.create <cir::BrCondOp>(loc, op.getCond (), trueBlock, falseBlock);
300+
301+ rewriter.replaceOp (op, continueBlock->getArguments ());
302+
303+ // Ok, we're done!
304+ return mlir::success ();
305+ }
306+ };
307+
257308void populateFlattenCFGPatterns (RewritePatternSet &patterns) {
258- patterns
259- . add <CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
260- patterns.getContext ());
309+ patterns. add <CIRIfFlattening, CIRLoopOpInterfaceFlattening,
310+ CIRScopeOpFlattening, CIRTernaryOpFlattening >(
311+ patterns.getContext ());
261312}
262313
263314void CIRFlattenCFGPass::runOnOperation () {
@@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
269320 getOperation ()->walk <mlir::WalkOrder::PostOrder>([&](Operation *op) {
270321 assert (!cir::MissingFeatures::ifOp ());
271322 assert (!cir::MissingFeatures::switchOp ());
272- assert (!cir::MissingFeatures::ternaryOp ());
273323 assert (!cir::MissingFeatures::tryOp ());
274- if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
324+ if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp >(op))
275325 ops.push_back (op);
276326 });
277327
0 commit comments