@@ -7604,89 +7604,66 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76047604
76057605namespace {
76067606
7607- // / Constant fold vector.step when it is compared to constant with arith.cmpi
7608- // / and the result is the same at all indices. For example, rewrite:
7607+ // / Fold `vector.step -> arith.cmpi` when the step value is compared to a
7608+ // / constant large enough such that the result is the same at all indices.
7609+ // /
7610+ // / For example, rewrite the 'greater than' comparison below,
76097611// /
76107612// / %cst = arith.constant dense<7> : vector<3xindex>
7611- // / %0 = vector.step : vector<3xindex>
7612- // / %1 = arith.cmpi ugt, %0 , %cst : vector<3xindex>
7613+ // / %stp = vector.step : vector<3xindex>
7614+ // / %out = arith.cmpi ugt, %stp , %cst : vector<3xindex>
76137615// /
7614- // / as
7616+ // / as,
76157617// /
7616- // / %out = arith.constant dense<false> : vector<3xi1>
7618+ // / %out = arith.constant dense<false> : vector<3xi1>.
76177619// /
76187620// / Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
7619- // / false at ALL indices we fold to the constant. false. If the constant was 1,
7620- // / then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant
7621- // / fold, preferring the more 'compact' vector.step representation.
7621+ // / false at ALL indices we fold. If the constant was 1, then
7622+ // / [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively
7623+ // / preferring the 'compact' vector.step representation.
76227624struct StepCompareFolder : public OpRewritePattern <StepOp> {
76237625 using OpRewritePattern::OpRewritePattern;
76247626
76257627 LogicalResult matchAndRewrite (StepOp stepOp,
76267628 PatternRewriter &rewriter) const override {
7627-
7628- int64_t stepSize = stepOp.getResult ().getType ().getNumElements ();
7629+ const int64_t stepSize = stepOp.getResult ().getType ().getNumElements ();
76297630
76307631 for (auto &use : stepOp.getResult ().getUses ()) {
76317632 if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner ())) {
7632- unsigned stepOperandNumber = use.getOperandNumber ();
7633+ const unsigned stepOperandNumber = use.getOperandNumber ();
76337634
7634- // arith.cmpi has a canonicalizer to put constants on operand 1. Let it
7635- // run first.
7636- if (stepOperandNumber != 0 ) {
7635+ // arith.cmpi canonicalizer makes constants final operands.
7636+ if (stepOperandNumber != 0 )
76377637 continue ;
7638- }
76397638
76407639 // Check that operand 1 is a constant.
7641- unsigned otherOperandNumber = 1 ;
7642- Value otherOperand = cmpiOp.getOperand (otherOperandNumber );
7640+ unsigned constOperandNumber = 1 ;
7641+ Value otherOperand = cmpiOp.getOperand (constOperandNumber );
76437642 auto maybeConstValue = getConstantIntValue (otherOperand);
76447643 if (!maybeConstValue.has_value ())
76457644 continue ;
7646- int64_t constValue = maybeConstValue.value ();
76477645
7646+ int64_t constValue = maybeConstValue.value ();
76487647 arith::CmpIPredicate pred = cmpiOp.getPredicate ();
76497648
76507649 auto maybeSplat = [&]() -> std::optional<bool > {
76517650 // Handle ult (unsigned less than) and uge (unsigned greater equal).
7652- // Examples where stepSize = constValue = 3, for the 4
7653- // cases of [ult, uge] x [stepOperandNumber = 0, 1]:
7654- //
7655- // pred stepOperandNumber
7656- // ==== =================
7657- // ult 0 [0, 1, 2] < 3 ==> true.
7658- // ult 1 3 < [0, 1, 2] ==> false.
7659- // uge 0 [0, 1, 2] >= 3 ==> true.
7660- // uge 1 3 >= [0, 1, 2] ==> false.
7661- //
7662- // If constValue is any smaller, the comparison is not constant.
7663- if (pred == arith::CmpIPredicate::ult ||
7664- pred == arith::CmpIPredicate::uge) {
7665- if (stepSize <= constValue) {
7666- return pred == arith::CmpIPredicate::ult;
7667- }
7668- }
7651+ if ((pred == arith::CmpIPredicate::ult ||
7652+ pred == arith::CmpIPredicate::uge) &&
7653+ stepSize <= constValue)
7654+ return pred == arith::CmpIPredicate::ult;
76697655
76707656 // Handle ule and ugt.
7671- //
7672- // pred stepOperandNumber
7673- // ==== =================
7674- // ule 0 [0, 1, 2] <= 2 ==> true
7675- // (stepSize = 3, constValue = 2).
7676- if (pred == arith::CmpIPredicate::ule ||
7677- pred == arith::CmpIPredicate::ugt) {
7678- if (stepSize <= constValue + 1 ) {
7679- return pred == arith::CmpIPredicate::ule;
7680- }
7681- }
7657+ if ((pred == arith::CmpIPredicate::ule ||
7658+ pred == arith::CmpIPredicate::ugt) &&
7659+ stepSize <= constValue + 1 )
7660+ return pred == arith::CmpIPredicate::ule;
76827661
7683- // Handle eq and ne
7684- if (pred == arith::CmpIPredicate::eq ||
7685- pred == arith::CmpIPredicate::ne) {
7686- if (stepSize <= constValue) {
7687- return pred == arith::CmpIPredicate::ne;
7688- }
7689- }
7662+ // Handle eq and ne.
7663+ if ((pred == arith::CmpIPredicate::eq ||
7664+ pred == arith::CmpIPredicate::ne) &&
7665+ stepSize <= constValue)
7666+ return pred == arith::CmpIPredicate::ne;
76907667
76917668 return std::optional<bool >();
76927669 }();
@@ -7695,13 +7672,17 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76957672 continue ;
76967673
76977674 rewriter.setInsertionPointAfter (cmpiOp);
7698- auto boolConst = mlir::arith::ConstantOp::create (
7699- rewriter, cmpiOp.getLoc (),
7700- rewriter.getBoolAttr (maybeSplat.value ()));
7701- auto splat = vector::BroadcastOp::create (
7702- rewriter, cmpiOp.getLoc (), cmpiOp.getResult ().getType (), boolConst);
77037675
7704- rewriter.replaceOp (cmpiOp, splat.getResult ());
7676+ auto type = dyn_cast<VectorType>(cmpiOp.getResult ().getType ());
7677+ if (!type)
7678+ continue ;
7679+
7680+ DenseElementsAttr boolAttr =
7681+ DenseElementsAttr::get (type, maybeSplat.value ());
7682+ Value splat = mlir::arith::ConstantOp::create (rewriter, cmpiOp.getLoc (),
7683+ type, boolAttr);
7684+
7685+ rewriter.replaceOp (cmpiOp, splat);
77057686 return success ();
77067687 }
77077688 }
0 commit comments