@@ -4297,33 +4297,42 @@ void IndexSwitchOp::getRegionInvocationBounds(
42974297 bounds.emplace_back (/* lb=*/ 0 , /* ub=*/ i == liveIndex);
42984298}
42994299
4300- LogicalResult IndexSwitchOp::fold (FoldAdaptor adaptor,
4301- SmallVectorImpl<OpFoldResult> &results) {
4302- std::optional<int64_t > maybeCst = getConstantIntValue (getArg ());
4303- if (!maybeCst.has_value ())
4304- return failure ();
4305- int64_t cst = *maybeCst;
4306- int64_t caseIdx, e = getNumCases ();
4307- for (caseIdx = 0 ; caseIdx < e; ++caseIdx) {
4308- if (cst == getCases ()[caseIdx])
4309- break ;
4310- }
4300+ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4301+ using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
43114302
4312- Region &r = (caseIdx < getNumCases ()) ? getCaseRegions ()[caseIdx]
4313- : getDefaultRegion ();
4314- Block &source = r.front ();
4315- results.assign (source.getTerminator ()->getOperands ().begin (),
4316- source.getTerminator ()->getOperands ().end ());
4303+ LogicalResult matchAndRewrite (scf::IndexSwitchOp op,
4304+ PatternRewriter &rewriter) const override {
4305+ // If `op.getArg()` is a constant, select the region that matches with
4306+ // the constant value. Use the default region if no matche is found.
4307+ std::optional<int64_t > maybeCst = getConstantIntValue (op.getArg ());
4308+ if (!maybeCst.has_value ())
4309+ return failure ();
4310+ int64_t cst = *maybeCst;
4311+ int64_t caseIdx, e = op.getNumCases ();
4312+ for (caseIdx = 0 ; caseIdx < e; ++caseIdx) {
4313+ if (cst == op.getCases ()[caseIdx])
4314+ break ;
4315+ }
43174316
4318- Block *pDestination = (*this )->getBlock ();
4319- if (!pDestination)
4320- return failure ();
4321- Block::iterator insertionPoint = (*this )->getIterator ();
4322- pDestination->getOperations ().splice (insertionPoint, source.getOperations (),
4323- source.getOperations ().begin (),
4324- std::prev (source.getOperations ().end ()));
4317+ Region &r = (caseIdx < op.getNumCases ()) ? op.getCaseRegions ()[caseIdx]
4318+ : op.getDefaultRegion ();
4319+ Block &source = r.front ();
4320+ Operation *terminator = source.getTerminator ();
4321+ SmallVector<Value> results = terminator->getOperands ();
43254322
4326- return success ();
4323+ rewriter.inlineBlockBefore (&source, op);
4324+ rewriter.eraseOp (terminator);
4325+ // Repalce the operation with a potentially empty list of results.
4326+ // Fold mechanism doesn't support the case where the result list is empty.
4327+ rewriter.replaceOp (op, results);
4328+
4329+ return success ();
4330+ }
4331+ };
4332+
4333+ void IndexSwitchOp::getCanonicalizationPatterns (RewritePatternSet &results,
4334+ MLIRContext *context) {
4335+ results.add <FoldConstantCase>(context);
43274336}
43284337
43294338// ===----------------------------------------------------------------------===//
0 commit comments