@@ -7629,62 +7629,64 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76297629 const int64_t stepSize = stepOp.getResult ().getType ().getNumElements ();
76307630
76317631 for (auto &use : stepOp.getResult ().getUses ()) {
7632- if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner ())) {
7633- const unsigned stepOperandNumber = use.getOperandNumber ();
7634-
7635- // arith.cmpi canonicalizer makes constants final operands.
7636- if (stepOperandNumber != 0 )
7637- continue ;
7638-
7639- // Check that operand 1 is a constant.
7640- unsigned constOperandNumber = 1 ;
7641- Value otherOperand = cmpiOp.getOperand (constOperandNumber);
7642- auto maybeConstValue = getConstantIntValue (otherOperand);
7643- if (!maybeConstValue.has_value ())
7644- continue ;
7632+ auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner ());
7633+ if (!cmpiOp)
7634+ continue ;
76457635
7646- int64_t constValue = maybeConstValue.value ();
7647- arith::CmpIPredicate pred = cmpiOp.getPredicate ();
7636+ // arith.cmpi canonicalizer makes constants final operands.
7637+ const unsigned stepOperandNumber = use.getOperandNumber ();
7638+ if (stepOperandNumber != 0 )
7639+ continue ;
76487640
7649- auto maybeSplat = [&]() -> std::optional< bool > {
7650- // Handle ult ( unsigned less than) and uge (unsigned greater equal).
7651- if ((pred == arith::CmpIPredicate::ult ||
7652- pred == arith::CmpIPredicate::uge) &&
7653- stepSize <= constValue )
7654- return pred == arith::CmpIPredicate::ult ;
7641+ // Check that operand 1 is a constant.
7642+ unsigned constOperandNumber = 1 ;
7643+ Value otherOperand = cmpiOp. getOperand (constOperandNumber);
7644+ auto maybeConstValue = getConstantIntValue (otherOperand);
7645+ if (!maybeConstValue. has_value () )
7646+ continue ;
76557647
7656- // Handle ule and ugt.
7657- if ((pred == arith::CmpIPredicate::ule ||
7658- pred == arith::CmpIPredicate::ugt) &&
7659- stepSize <= constValue + 1 )
7660- return pred == arith::CmpIPredicate::ule;
7648+ int64_t constValue = maybeConstValue.value ();
7649+ arith::CmpIPredicate pred = cmpiOp.getPredicate ();
7650+
7651+ auto maybeSplat = [&]() -> std::optional<bool > {
7652+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
7653+ if ((pred == arith::CmpIPredicate::ult ||
7654+ pred == arith::CmpIPredicate::uge) &&
7655+ stepSize <= constValue)
7656+ return pred == arith::CmpIPredicate::ult;
7657+
7658+ // Handle ule and ugt.
7659+ if ((pred == arith::CmpIPredicate::ule ||
7660+ pred == arith::CmpIPredicate::ugt) &&
7661+ stepSize - 1 <= constValue) {
7662+ return pred == arith::CmpIPredicate::ule;
7663+ }
76617664
7662- // Handle eq and ne.
7663- if ((pred == arith::CmpIPredicate::eq ||
7664- pred == arith::CmpIPredicate::ne) &&
7665- stepSize <= constValue)
7666- return pred == arith::CmpIPredicate::ne;
7665+ // Handle eq and ne.
7666+ if ((pred == arith::CmpIPredicate::eq ||
7667+ pred == arith::CmpIPredicate::ne) &&
7668+ stepSize <= constValue)
7669+ return pred == arith::CmpIPredicate::ne;
76677670
7668- return std::optional< bool >() ;
7669- }();
7671+ return std::nullopt ;
7672+ }();
76707673
7671- if (!maybeSplat.has_value ())
7672- continue ;
7674+ if (!maybeSplat.has_value ())
7675+ continue ;
76737676
7674- rewriter.setInsertionPointAfter (cmpiOp);
7677+ rewriter.setInsertionPointAfter (cmpiOp);
76757678
7676- auto type = dyn_cast<VectorType>(cmpiOp.getResult ().getType ());
7677- if (!type)
7678- continue ;
7679+ auto type = dyn_cast<VectorType>(cmpiOp.getResult ().getType ());
7680+ if (!type)
7681+ continue ;
76797682
7680- DenseElementsAttr boolAttr =
7681- DenseElementsAttr::get (type, maybeSplat.value ());
7682- Value splat = mlir::arith::ConstantOp::create (rewriter, cmpiOp.getLoc (),
7683- type, boolAttr);
7683+ DenseElementsAttr boolAttr =
7684+ DenseElementsAttr::get (type, maybeSplat.value ());
7685+ Value splat = mlir::arith::ConstantOp::create (rewriter, cmpiOp.getLoc (),
7686+ type, boolAttr);
76847687
7685- rewriter.replaceOp (cmpiOp, splat);
7686- return success ();
7687- }
7688+ rewriter.replaceOp (cmpiOp, splat);
7689+ return success ();
76887690 }
76897691
76907692 return failure ();
0 commit comments