Skip to content

Commit b7f146c

Browse files
committed
add folder
Signed-off-by: James Newling <[email protected]>
1 parent 0aef9eb commit b7f146c

File tree

3 files changed

+494
-0
lines changed

3 files changed

+494
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2999,6 +2999,7 @@ def Vector_StepOp : Vector_Op<"step", [
29992999
}];
30003000
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
30013001
let assemblyFormat = "attr-dict `:` type($result)";
3002+
let hasCanonicalizer = 1;
30023003
}
30033004

30043005
def Vector_YieldOp : Vector_Op<"yield", [

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7602,6 +7602,120 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76027602
setResultRanges(getResult(), result);
76037603
}
76047604

7605+
namespace {
7606+
7607+
/// Constant fold vector.step when it is compared to constant with arith.cmpi
7608+
/// and the result is the same at all indices. For example, rewrite:
7609+
///
7610+
/// %cst = arith.constant dense<7> : vector<3xindex>
7611+
/// %0 = vector.step : vector<3xindex>
7612+
/// %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
7613+
///
7614+
/// as
7615+
///
7616+
/// %out = arith.constant dense<false> : vector<3xi1>
7617+
///
7618+
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
7619+
/// false at ALL indices we fold to the constant. false. If the constant was 1,
7620+
/// then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant
7621+
/// fold, preferring the more 'compact' vector.step representation.
7622+
struct StepCompareFolder : public OpRewritePattern<StepOp> {
7623+
using OpRewritePattern::OpRewritePattern;
7624+
7625+
LogicalResult matchAndRewrite(StepOp stepOp,
7626+
PatternRewriter &rewriter) const override {
7627+
7628+
int64_t stepSize = stepOp.getResult().getType().getNumElements();
7629+
7630+
for (auto &use : stepOp.getResult().getUses()) {
7631+
if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
7632+
unsigned stepOperandNumber = use.getOperandNumber();
7633+
7634+
// arith.cmpi has a canonicalizer to put constants on operand 1. Let it
7635+
// run first.
7636+
if (stepOperandNumber != 0) {
7637+
continue;
7638+
}
7639+
7640+
// Check that operand 1 is a constant.
7641+
unsigned otherOperandNumber = 1;
7642+
Value otherOperand = cmpiOp.getOperand(otherOperandNumber);
7643+
auto maybeConstValue = getConstantIntValue(otherOperand);
7644+
if (!maybeConstValue.has_value())
7645+
continue;
7646+
int64_t constValue = maybeConstValue.value();
7647+
7648+
arith::CmpIPredicate pred = cmpiOp.getPredicate();
7649+
7650+
auto maybeSplat = [&]() -> std::optional<bool> {
7651+
// Handle ult (unsigned less than) and uge (unsigned greater equal).
7652+
// Examples where stepSize = constValue = 3, for the 4
7653+
// cases of [ult, uge] x [stepOperandNumber = 0, 1]:
7654+
//
7655+
// pred stepOperandNumber
7656+
// ==== =================
7657+
// ult 0 [0, 1, 2] < 3 ==> true.
7658+
// ult 1 3 < [0, 1, 2] ==> false.
7659+
// uge 0 [0, 1, 2] >= 3 ==> true.
7660+
// uge 1 3 >= [0, 1, 2] ==> false.
7661+
//
7662+
// If constValue is any smaller, the comparison is not constant.
7663+
if (pred == arith::CmpIPredicate::ult ||
7664+
pred == arith::CmpIPredicate::uge) {
7665+
if (stepSize <= constValue) {
7666+
return pred == arith::CmpIPredicate::ult;
7667+
}
7668+
}
7669+
7670+
// Handle ule and ugt.
7671+
//
7672+
// pred stepOperandNumber
7673+
// ==== =================
7674+
// ule 0 [0, 1, 2] <= 2 ==> true
7675+
// (stepSize = 3, constValue = 2).
7676+
if (pred == arith::CmpIPredicate::ule ||
7677+
pred == arith::CmpIPredicate::ugt) {
7678+
if (stepSize <= constValue + 1) {
7679+
return pred == arith::CmpIPredicate::ule;
7680+
}
7681+
}
7682+
7683+
// Handle eq and ne
7684+
if (pred == arith::CmpIPredicate::eq ||
7685+
pred == arith::CmpIPredicate::ne) {
7686+
if (stepSize <= constValue) {
7687+
return pred == arith::CmpIPredicate::ne;
7688+
}
7689+
}
7690+
7691+
return std::optional<bool>();
7692+
}();
7693+
7694+
if (!maybeSplat.has_value())
7695+
continue;
7696+
7697+
rewriter.setInsertionPointAfter(cmpiOp);
7698+
auto boolConst = mlir::arith::ConstantOp::create(
7699+
rewriter, cmpiOp.getLoc(),
7700+
rewriter.getBoolAttr(maybeSplat.value()));
7701+
auto splat = vector::BroadcastOp::create(
7702+
rewriter, cmpiOp.getLoc(), cmpiOp.getResult().getType(), boolConst);
7703+
7704+
rewriter.replaceOp(cmpiOp, splat.getResult());
7705+
return success();
7706+
}
7707+
}
7708+
7709+
return failure();
7710+
}
7711+
};
7712+
} // namespace
7713+
7714+
/// Adds the canonicalization patterns of vector.step to `results`.
void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // StepCompareFolder is the only pattern registered for this op here.
  results.add<StepCompareFolder>(context);
}
7718+
76057719
//===----------------------------------------------------------------------===//
76067720
// Vector Masking Utilities
76077721
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)