@@ -7603,89 +7603,66 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76037603
76047604namespace {
76057605
7606- // / Constant fold vector.step when it is compared to constant with arith.cmpi
7607- // / and the result is the same at all indices. For example, rewrite:
7606+ // / Fold `vector.step -> arith.cmpi` when the step value is compared to a
7607+ // / constant large enough such that the result is the same at all indices.
7608+ // /
7609+ // / For example, rewrite the 'greater than' comparison below,
76087610// /
76097611// / %cst = arith.constant dense<7> : vector<3xindex>
7610- // / %0 = vector.step : vector<3xindex>
7611- // / %1 = arith.cmpi ugt, %0 , %cst : vector<3xindex>
7612+ // / %stp = vector.step : vector<3xindex>
7613+ // / %out = arith.cmpi ugt, %stp , %cst : vector<3xindex>
76127614// /
7613- // / as
7615+ // / as,
76147616// /
7615- // / %out = arith.constant dense<false> : vector<3xi1>
7617+ // / %out = arith.constant dense<false> : vector<3xi1>.
76167618// /
76177619// / Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
7618- // / false at ALL indices we fold to the constant. false. If the constant was 1,
7619- // / then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant
7620- // / fold, preferring the more 'compact' vector.step representation.
7620+ // / false at ALL indices we fold. If the constant was 1, then
7621+ // / [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not fold,
7622+ // / conservatively preferring the 'compact' vector.step representation.
76217623struct StepCompareFolder : public OpRewritePattern <StepOp> {
76227624 using OpRewritePattern::OpRewritePattern;
76237625
76247626 LogicalResult matchAndRewrite (StepOp stepOp,
76257627 PatternRewriter &rewriter) const override {
7626-
7627- int64_t stepSize = stepOp.getResult ().getType ().getNumElements ();
7628+ const int64_t stepSize = stepOp.getResult ().getType ().getNumElements ();
76287629
76297630 for (auto &use : stepOp.getResult ().getUses ()) {
76307631 if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner ())) {
7631- unsigned stepOperandNumber = use.getOperandNumber ();
7632+ const unsigned stepOperandNumber = use.getOperandNumber ();
76327633
7633- // arith.cmpi has a canonicalizer to put constants on operand 1. Let it
7634- // run first.
7635- if (stepOperandNumber != 0 ) {
7634+ // arith.cmpi canonicalizer makes constants final operands.
7635+ if (stepOperandNumber != 0 )
76367636 continue ;
7637- }
76387637
76397638 // Check that operand 1 is a constant.
7640- unsigned otherOperandNumber = 1 ;
7641- Value otherOperand = cmpiOp.getOperand (otherOperandNumber );
7639+ unsigned constOperandNumber = 1 ;
7640+ Value otherOperand = cmpiOp.getOperand (constOperandNumber );
76427641 auto maybeConstValue = getConstantIntValue (otherOperand);
76437642 if (!maybeConstValue.has_value ())
76447643 continue ;
7645- int64_t constValue = maybeConstValue.value ();
76467644
7645+ int64_t constValue = maybeConstValue.value ();
76477646 arith::CmpIPredicate pred = cmpiOp.getPredicate ();
76487647
76497648 auto maybeSplat = [&]() -> std::optional<bool > {
76507649 // Handle ult (unsigned less than) and uge (unsigned greater equal).
7651- // Examples where stepSize = constValue = 3, for the 4
7652- // cases of [ult, uge] x [stepOperandNumber = 0, 1]:
7653- //
7654- // pred stepOperandNumber
7655- // ==== =================
7656- // ult 0 [0, 1, 2] < 3 ==> true.
7657- // ult 1 3 < [0, 1, 2] ==> false.
7658- // uge 0 [0, 1, 2] >= 3 ==> true.
7659- // uge 1 3 >= [0, 1, 2] ==> false.
7660- //
7661- // If constValue is any smaller, the comparison is not constant.
7662- if (pred == arith::CmpIPredicate::ult ||
7663- pred == arith::CmpIPredicate::uge) {
7664- if (stepSize <= constValue) {
7665- return pred == arith::CmpIPredicate::ult;
7666- }
7667- }
7650+ if ((pred == arith::CmpIPredicate::ult ||
7651+ pred == arith::CmpIPredicate::uge) &&
7652+ stepSize <= constValue)
7653+ return pred == arith::CmpIPredicate::ult;
76687654
76697655 // Handle ule and ugt.
7670- //
7671- // pred stepOperandNumber
7672- // ==== =================
7673- // ule 0 [0, 1, 2] <= 2 ==> true
7674- // (stepSize = 3, constValue = 2).
7675- if (pred == arith::CmpIPredicate::ule ||
7676- pred == arith::CmpIPredicate::ugt) {
7677- if (stepSize <= constValue + 1 ) {
7678- return pred == arith::CmpIPredicate::ule;
7679- }
7680- }
7656+ if ((pred == arith::CmpIPredicate::ule ||
7657+ pred == arith::CmpIPredicate::ugt) &&
7658+ stepSize <= constValue + 1 )
7659+ return pred == arith::CmpIPredicate::ule;
76817660
7682- // Handle eq and ne
7683- if (pred == arith::CmpIPredicate::eq ||
7684- pred == arith::CmpIPredicate::ne) {
7685- if (stepSize <= constValue) {
7686- return pred == arith::CmpIPredicate::ne;
7687- }
7688- }
7661+ // Handle eq and ne.
7662+ if ((pred == arith::CmpIPredicate::eq ||
7663+ pred == arith::CmpIPredicate::ne) &&
7664+ stepSize <= constValue)
7665+ return pred == arith::CmpIPredicate::ne;
76897666
76907667 return std::optional<bool >();
76917668 }();
@@ -7694,13 +7671,17 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76947671 continue ;
76957672
76967673 rewriter.setInsertionPointAfter (cmpiOp);
7697- auto boolConst = mlir::arith::ConstantOp::create (
7698- rewriter, cmpiOp.getLoc (),
7699- rewriter.getBoolAttr (maybeSplat.value ()));
7700- auto splat = vector::BroadcastOp::create (
7701- rewriter, cmpiOp.getLoc (), cmpiOp.getResult ().getType (), boolConst);
77027674
7703- rewriter.replaceOp (cmpiOp, splat.getResult ());
7675+ auto type = dyn_cast<VectorType>(cmpiOp.getResult ().getType ());
7676+ if (!type)
7677+ continue ;
7678+
7679+ DenseElementsAttr boolAttr =
7680+ DenseElementsAttr::get (type, maybeSplat.value ());
7681+ Value splat = mlir::arith::ConstantOp::create (rewriter, cmpiOp.getLoc (),
7682+ type, boolAttr);
7683+
7684+ rewriter.replaceOp (cmpiOp, splat);
77047685 return success ();
77057686 }
77067687 }
0 commit comments