@@ -7628,62 +7628,64 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76287628 const int64_t stepSize = stepOp.getResult ().getType ().getNumElements ();
76297629
76307630 for (auto &use : stepOp.getResult ().getUses ()) {
7631- if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner ())) {
7632- const unsigned stepOperandNumber = use.getOperandNumber ();
7633-
7634- // arith.cmpi canonicalizer makes constants final operands.
7635- if (stepOperandNumber != 0 )
7636- continue ;
7637-
7638- // Check that operand 1 is a constant.
7639- unsigned constOperandNumber = 1 ;
7640- Value otherOperand = cmpiOp.getOperand (constOperandNumber);
7641- auto maybeConstValue = getConstantIntValue (otherOperand);
7642- if (!maybeConstValue.has_value ())
7643- continue ;
7631+ auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner ());
7632+ if (!cmpiOp)
7633+ continue ;
76447634
7645- int64_t constValue = maybeConstValue.value ();
7646- arith::CmpIPredicate pred = cmpiOp.getPredicate ();
7635+ // arith.cmpi canonicalizer makes constants final operands.
7636+ const unsigned stepOperandNumber = use.getOperandNumber ();
7637+ if (stepOperandNumber != 0 )
7638+ continue ;
76477639
7648- auto maybeSplat = [&]() -> std::optional< bool > {
7649- // Handle ult ( unsigned less than) and uge (unsigned greater equal).
7650- if ((pred == arith::CmpIPredicate::ult ||
7651- pred == arith::CmpIPredicate::uge) &&
7652- stepSize <= constValue )
7653- return pred == arith::CmpIPredicate::ult ;
7640+ // Check that operand 1 is a constant.
7641+ unsigned constOperandNumber = 1 ;
7642+ Value otherOperand = cmpiOp. getOperand (constOperandNumber);
7643+ auto maybeConstValue = getConstantIntValue (otherOperand);
7644+ if (!maybeConstValue. has_value () )
7645+ continue ;
76547646
7655- // Handle ule and ugt.
7656- if ((pred == arith::CmpIPredicate::ule ||
7657- pred == arith::CmpIPredicate::ugt) &&
7658- stepSize <= constValue + 1 )
7659- return pred == arith::CmpIPredicate::ule;
7647+ int64_t constValue = maybeConstValue.value ();
7648+ arith::CmpIPredicate pred = cmpiOp.getPredicate ();
7649+
7650+ auto maybeSplat = [&]() -> std::optional<bool > {
7651+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
7652+ if ((pred == arith::CmpIPredicate::ult ||
7653+ pred == arith::CmpIPredicate::uge) &&
7654+ stepSize <= constValue)
7655+ return pred == arith::CmpIPredicate::ult;
7656+
7657+ // Handle ule and ugt.
7658+ if ((pred == arith::CmpIPredicate::ule ||
7659+ pred == arith::CmpIPredicate::ugt) &&
7660+ stepSize - 1 <= constValue) {
7661+ return pred == arith::CmpIPredicate::ule;
7662+ }
76607663
7661- // Handle eq and ne.
7662- if ((pred == arith::CmpIPredicate::eq ||
7663- pred == arith::CmpIPredicate::ne) &&
7664- stepSize <= constValue)
7665- return pred == arith::CmpIPredicate::ne;
7664+ // Handle eq and ne.
7665+ if ((pred == arith::CmpIPredicate::eq ||
7666+ pred == arith::CmpIPredicate::ne) &&
7667+ stepSize <= constValue)
7668+ return pred == arith::CmpIPredicate::ne;
76667669
7667- return std::optional< bool >() ;
7668- }();
7670+ return std::nullopt ;
7671+ }();
76697672
7670- if (!maybeSplat.has_value ())
7671- continue ;
7673+ if (!maybeSplat.has_value ())
7674+ continue ;
76727675
7673- rewriter.setInsertionPointAfter (cmpiOp);
7676+ rewriter.setInsertionPointAfter (cmpiOp);
76747677
7675- auto type = dyn_cast<VectorType>(cmpiOp.getResult ().getType ());
7676- if (!type)
7677- continue ;
7678+ auto type = dyn_cast<VectorType>(cmpiOp.getResult ().getType ());
7679+ if (!type)
7680+ continue ;
76787681
7679- DenseElementsAttr boolAttr =
7680- DenseElementsAttr::get (type, maybeSplat.value ());
7681- Value splat = mlir::arith::ConstantOp::create (rewriter, cmpiOp.getLoc (),
7682- type, boolAttr);
7682+ DenseElementsAttr boolAttr =
7683+ DenseElementsAttr::get (type, maybeSplat.value ());
7684+ Value splat = mlir::arith::ConstantOp::create (rewriter, cmpiOp.getLoc (),
7685+ type, boolAttr);
76837686
7684- rewriter.replaceOp (cmpiOp, splat);
7685- return success ();
7686- }
7687+ rewriter.replaceOp (cmpiOp, splat);
7688+ return success ();
76877689 }
76887690
76897691 return failure ();
0 commit comments