1313#include " PassDetail.h"
1414#include " mlir/Dialect/Func/IR/FuncOps.h"
1515#include " mlir/IR/PatternMatch.h"
16+ #include " mlir/IR/ValueRange.h"
1617#include " mlir/Support/LogicalResult.h"
1718#include " mlir/Transforms/DialectConversion.h"
1819#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -868,14 +869,6 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
868869 auto *condBlock = rewriter.getInsertionBlock ();
869870 auto opPosition = rewriter.getInsertionPoint ();
870871 auto *remainingOpsBlock = rewriter.splitBlock (condBlock, opPosition);
871- llvm::SmallVector<mlir::Location, 2 > locs;
872- // Ternary result is optional, make sure to populate the location only
873- // when relevant.
874- if (op->getResultTypes ().size ())
875- locs.push_back (loc);
876- auto *continueBlock =
877- rewriter.createBlock (remainingOpsBlock, op->getResultTypes (), locs);
878- rewriter.create <cir::BrOp>(loc, remainingOpsBlock);
879872
880873 auto &trueRegion = op.getTrueRegion ();
881874 auto *trueBlock = &trueRegion.front ();
@@ -884,24 +877,29 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
884877 auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
885878
886879 rewriter.replaceOpWithNewOp <cir::BrOp>(trueYieldOp, trueYieldOp.getArgs (),
887- continueBlock );
888- rewriter.inlineRegionBefore (trueRegion, continueBlock );
880+ remainingOpsBlock );
881+ rewriter.inlineRegionBefore (trueRegion, remainingOpsBlock );
889882
890- auto *falseBlock = continueBlock;
891883 auto &falseRegion = op.getFalseRegion ();
884+ auto *falseBlock = &falseRegion.front ();
892885
893- falseBlock = &falseRegion.front ();
894886 mlir::Operation *falseTerminator = falseRegion.back ().getTerminator ();
895887 rewriter.setInsertionPointToEnd (&falseRegion.back ());
896888 auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
897889 rewriter.replaceOpWithNewOp <cir::BrOp>(falseYieldOp, falseYieldOp.getArgs (),
898- continueBlock );
899- rewriter.inlineRegionBefore (falseRegion, continueBlock );
890+ remainingOpsBlock );
891+ rewriter.inlineRegionBefore (falseRegion, remainingOpsBlock );
900892
901893 rewriter.setInsertionPointToEnd (condBlock);
902894 rewriter.create <cir::BrCondOp>(loc, op.getCond (), trueBlock, falseBlock);
903895
904- rewriter.replaceOp (op, continueBlock->getArguments ());
896+ if (auto rt = op.getResultTypes (); rt.size ()) {
897+ auto args = remainingOpsBlock->addArguments (rt, op.getLoc ());
898+ SmallVector<mlir::Value, 2 > values;
899+ llvm::copy (args, std::back_inserter (values));
900+ rewriter.replaceOpUsesWithinBlock (op, values, remainingOpsBlock);
901+ }
902+ rewriter.eraseOp (op);
905903
906904 // Ok, we're done!
907905 return mlir::success ();
0 commit comments