@@ -7602,6 +7602,120 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76027602 setResultRanges (getResult (), result);
76037603}
76047604
7605+ namespace {
7606+
7607+ // / Constant fold vector.step when it is compared to constant with arith.cmpi
7608+ // / and the result is the same at all indices. For example, rewrite:
7609+ // /
7610+ // / %cst = arith.constant dense<7> : vector<3xindex>
7611+ // / %0 = vector.step : vector<3xindex>
7612+ // / %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
7613+ // /
7614+ // / as
7615+ // /
7616+ // / %out = arith.constant dense<false> : vector<3xi1>
7617+ // /
7618+ // / Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
7619+ // / false at ALL indices we fold to the constant. false. If the constant was 1,
7620+ // / then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant
7621+ // / fold, preferring the more 'compact' vector.step representation.
7622+ struct StepCompareFolder : public OpRewritePattern <StepOp> {
7623+ using OpRewritePattern::OpRewritePattern;
7624+
7625+ LogicalResult matchAndRewrite (StepOp stepOp,
7626+ PatternRewriter &rewriter) const override {
7627+
7628+ int64_t stepSize = stepOp.getResult ().getType ().getNumElements ();
7629+
7630+ for (auto &use : stepOp.getResult ().getUses ()) {
7631+ if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner ())) {
7632+ unsigned stepOperandNumber = use.getOperandNumber ();
7633+
7634+ // arith.cmpi has a canonicalizer to put constants on operand 1. Let it
7635+ // run first.
7636+ if (stepOperandNumber != 0 ) {
7637+ continue ;
7638+ }
7639+
7640+ // Check that operand 1 is a constant.
7641+ unsigned otherOperandNumber = 1 ;
7642+ Value otherOperand = cmpiOp.getOperand (otherOperandNumber);
7643+ auto maybeConstValue = getConstantIntValue (otherOperand);
7644+ if (!maybeConstValue.has_value ())
7645+ continue ;
7646+ int64_t constValue = maybeConstValue.value ();
7647+
7648+ arith::CmpIPredicate pred = cmpiOp.getPredicate ();
7649+
7650+ auto maybeSplat = [&]() -> std::optional<bool > {
7651+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
7652+ // Examples where stepSize = constValue = 3, for the 4
7653+ // cases of [ult, uge] x [stepOperandNumber = 0, 1]:
7654+ //
7655+ // pred stepOperandNumber
7656+ // ==== =================
7657+ // ult 0 [0, 1, 2] < 3 ==> true.
7658+ // ult 1 3 < [0, 1, 2] ==> false.
7659+ // uge 0 [0, 1, 2] >= 3 ==> true.
7660+ // uge 1 3 >= [0, 1, 2] ==> false.
7661+ //
7662+ // If constValue is any smaller, the comparison is not constant.
7663+ if (pred == arith::CmpIPredicate::ult ||
7664+ pred == arith::CmpIPredicate::uge) {
7665+ if (stepSize <= constValue) {
7666+ return pred == arith::CmpIPredicate::ult;
7667+ }
7668+ }
7669+
7670+ // Handle ule and ugt.
7671+ //
7672+ // pred stepOperandNumber
7673+ // ==== =================
7674+ // ule 0 [0, 1, 2] <= 2 ==> true
7675+ // (stepSize = 3, constValue = 2).
7676+ if (pred == arith::CmpIPredicate::ule ||
7677+ pred == arith::CmpIPredicate::ugt) {
7678+ if (stepSize <= constValue + 1 ) {
7679+ return pred == arith::CmpIPredicate::ule;
7680+ }
7681+ }
7682+
7683+ // Handle eq and ne
7684+ if (pred == arith::CmpIPredicate::eq ||
7685+ pred == arith::CmpIPredicate::ne) {
7686+ if (stepSize <= constValue) {
7687+ return pred == arith::CmpIPredicate::ne;
7688+ }
7689+ }
7690+
7691+ return std::optional<bool >();
7692+ }();
7693+
7694+ if (!maybeSplat.has_value ())
7695+ continue ;
7696+
7697+ rewriter.setInsertionPointAfter (cmpiOp);
7698+ auto boolConst = mlir::arith::ConstantOp::create (
7699+ rewriter, cmpiOp.getLoc (),
7700+ rewriter.getBoolAttr (maybeSplat.value ()));
7701+ auto splat = vector::BroadcastOp::create (
7702+ rewriter, cmpiOp.getLoc (), cmpiOp.getResult ().getType (), boolConst);
7703+
7704+ rewriter.replaceOp (cmpiOp, splat.getResult ());
7705+ return success ();
7706+ }
7707+ }
7708+
7709+ return failure ();
7710+ }
7711+ };
7712+ } // namespace
7713+
7714+ void StepOp::getCanonicalizationPatterns (RewritePatternSet &results,
7715+ MLIRContext *context) {
7716+ results.add <StepCompareFolder>(context);
7717+ }
7718+
76057719// ===----------------------------------------------------------------------===//
76067720// Vector Masking Utilities
76077721// ===----------------------------------------------------------------------===//
0 commit comments