[mlir][Vector] Fold vector.step compared to constant #161615
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

This PR adds a canonicalizer to `vector.step` that folds an `arith.cmpi` of `vector.step` against a constant iff the result of the fold is a splat value. An alternative would be to always constant fold `vector.step`, but that might result in some very large, cumbersome constants.

I do wonder if `vector.step` might be better represented as some sort of attribute in the arith dialect, like `%step = arith.constant iota<32> : vector<32xindex>`.

Full diff: https://github.com/llvm/llvm-project/pull/161615.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..dbb5d0f659159 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3045,6 +3045,7 @@ def Vector_StepOp : Vector_Op<"step", [
}];
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
let assemblyFormat = "attr-dict `:` type($result)";
+ let hasCanonicalizer = 1;
}
def Vector_YieldOp : Vector_Op<"yield", [
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eb4686997c1b9..306be186308b0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7524,6 +7524,101 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+///
+/// as,
+///
+/// %out = arith.constant dense<false> : vector<3xi1>.
+///
+/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
+/// false at ALL indices we fold. If the constant was 1, then
+/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not fold,
+/// conservatively preferring the 'compact' vector.step representation.
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+ for (auto &use : stepOp.getResult().getUses()) {
+ if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
+ const unsigned stepOperandNumber = use.getOperandNumber();
+
+ // arith.cmpi canonicalizer makes constants final operands.
+ if (stepOperandNumber != 0)
+ continue;
+
+ // Check that operand 1 is a constant.
+ unsigned constOperandNumber = 1;
+ Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+ auto maybeConstValue = getConstantIntValue(otherOperand);
+ if (!maybeConstValue.has_value())
+ continue;
+
+ int64_t constValue = maybeConstValue.value();
+ arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+ auto maybeSplat = [&]() -> std::optional<bool> {
+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
+ if ((pred == arith::CmpIPredicate::ult ||
+ pred == arith::CmpIPredicate::uge) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ult;
+
+ // Handle ule and ugt.
+ if ((pred == arith::CmpIPredicate::ule ||
+ pred == arith::CmpIPredicate::ugt) &&
+ stepSize <= constValue + 1)
+ return pred == arith::CmpIPredicate::ule;
+
+ // Handle eq and ne.
+ if ((pred == arith::CmpIPredicate::eq ||
+ pred == arith::CmpIPredicate::ne) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ne;
+
+ return std::optional<bool>();
+ }();
+
+ if (!maybeSplat.has_value())
+ continue;
+
+ rewriter.setInsertionPointAfter(cmpiOp);
+
+ auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+ if (!type)
+ continue;
+
+ DenseElementsAttr boolAttr =
+ DenseElementsAttr::get(type, maybeSplat.value());
+ Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+ type, boolAttr);
+
+ rewriter.replaceOp(cmpiOp, splat);
+ return success();
+ }
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<StepCompareFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
new file mode 100644
index 0000000000000..effeb3d9c093a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
@@ -0,0 +1,379 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+///===----------------------------------------------===//
+/// Tests of `StepCompareFolder`
+///===----------------------------------------------===//
+
+
+///===------------------------------------===//
+/// Tests of `ugt` (unsigned greater than)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ugt_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 3 > [0, 1, 2] => true
+ %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_2_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ugt_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 2 > [0, 1, 2] => not constant
+ %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ugt_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 1 > [0, 1, 2] => not constant
+ %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 3 => false
+ %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_2_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 2 => false
+ %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ugt_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 1 => not constant
+ %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+/// Tests of `uge` (unsigned greater than or equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_uge_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 3 >= [0, 1, 2] => true
+ %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_2_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 2 >= [0, 1, 2] => true
+ %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_uge_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 1 >= [0, 1, 2] => not constant
+ %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] >= 3 => false
+ %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_2_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_uge_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] >= 2 => not constant
+ %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_uge_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] >= 1 => not constant
+ %1 = arith.cmpi uge, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+
+
+///===------------------------------------===//
+/// Tests of `ult` (unsigned less than)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ult_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_2_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ult_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_2_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ult_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ult_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+/// Tests of `ule` (unsigned less than or equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ule_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_2_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ule_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ule_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_2_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ule_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+/// Tests of `eq` (equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_eq_constant_3
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_eq_constant_3() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_eq_constant_2
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_eq_constant_2() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+/// Tests of `ne` (not equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ne_constant_3
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ne_constant_3() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ne_constant_2
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ne_constant_2() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
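The fold conditions in `StepCompareFolder` can be sanity-checked outside of MLIR. The following is a minimal standalone sketch, not part of the PR: `Pred`, `splatOfStepCompare`, `evalLane`, and `foldIsSound` are hypothetical names, the model only covers non-negative constants, and it assumes the step vector is on the left-hand side of the comparison, as the pattern requires.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for arith::CmpIPredicate, restricted to the
// predicates that the pattern handles.
enum class Pred { ult, ule, ugt, uge, eq, ne };

// Mirrors the lambda in StepCompareFolder: returns the value of
// `step <pred> constValue` if it is the same at every index, else nullopt.
std::optional<bool> splatOfStepCompare(Pred pred, int64_t stepSize,
                                       int64_t constValue) {
  if ((pred == Pred::ult || pred == Pred::uge) && stepSize <= constValue)
    return pred == Pred::ult;
  if ((pred == Pred::ule || pred == Pred::ugt) && stepSize <= constValue + 1)
    return pred == Pred::ule;
  if ((pred == Pred::eq || pred == Pred::ne) && stepSize <= constValue)
    return pred == Pred::ne;
  return std::nullopt;
}

// Reference evaluation of a single lane of the comparison.
bool evalLane(Pred pred, int64_t lane, int64_t c) {
  switch (pred) {
  case Pred::ult: return lane < c;
  case Pred::ule: return lane <= c;
  case Pred::ugt: return lane > c;
  case Pred::uge: return lane >= c;
  case Pred::eq:  return lane == c;
  case Pred::ne:  return lane != c;
  }
  return false;
}

// Returns true if, for every vector size in [1, maxSize] and every constant
// in [0, maxConst], a reported splat agrees with all lanes 0..size-1.
bool foldIsSound(int64_t maxSize, int64_t maxConst) {
  const Pred preds[] = {Pred::ult, Pred::ule, Pred::ugt,
                        Pred::uge, Pred::eq,  Pred::ne};
  for (Pred pred : preds)
    for (int64_t size = 1; size <= maxSize; ++size)
      for (int64_t c = 0; c <= maxConst; ++c) {
        std::optional<bool> splat = splatOfStepCompare(pred, size, c);
        if (!splat)
          continue; // No fold claimed, nothing to check.
        for (int64_t lane = 0; lane < size; ++lane)
          if (evalLane(pred, lane, c) != *splat)
            return false;
      }
  return true;
}
```

Calling `foldIsSound(8, 12)` exercises every handled predicate on small vectors. Note that this only checks soundness (no wrong folds); like the pattern itself, the model is conservative and does not try to catch every comparison that happens to be a splat.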
> An alternative would be to always constant fold it, but that might result in some very large/cumbersome constants.

FYI, we have some max vector sizes defined for insert/extract folders for this exact reason. I think it would work here too.
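To make the size-capped alternative concrete, here is a rough standalone sketch of what "constant fold only below a maximum size" could look like. It is not MLIR code: `kMaxFoldedStepSize` and `foldStepIfSmall` are hypothetical names, and the cap of 256 is an arbitrary placeholder rather than the limit the insert/extract folders actually use.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical cap on the number of elements worth materializing. The actual
// limit used by the insert/extract folders is not quoted in this thread.
constexpr int64_t kMaxFoldedStepSize = 256;

// Returns the values [0, 1, ..., numElements - 1] that `vector.step` would
// produce, or std::nullopt if the vector is too large to fold.
std::optional<std::vector<int64_t>>
foldStepIfSmall(int64_t numElements, int64_t maxSize = kMaxFoldedStepSize) {
  if (numElements < 0 || numElements > maxSize)
    return std::nullopt;
  std::vector<int64_t> values(static_cast<std::size_t>(numElements));
  std::iota(values.begin(), values.end(), int64_t{0});
  return values;
}
```

With such a cap, small step vectors would become ordinary dense constants that existing `arith.cmpi` folders can consume, while large ones would keep the compact `vector.step` form.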
PatternRewriter &rewriter) const override {
const int64_t stepSize = stepOp.getResult().getType().getNumElements();

for (auto &use : stepOp.getResult().getUses()) {
nit: spell out the type since it's not immediately obvious based on the RHS
// Check that operand 1 is a constant.
unsigned constOperandNumber = 1;
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
auto maybeConstValue = getConstantIntValue(otherOperand);
nit: spell out the type
This is a 'direct' alternative to the upstream PR llvm/llvm-project#161615. It directly folds the case where the linalg.index is known to always be less than the unpadded dimension size. This happens when `partial_reduction` perfectly divides the reduction dimension. The upstream approach (see link above) folds at a later stage, where, after vectorization, there is a vector.step compared to a constant.

Signed-off-by: James Newling <[email protected]>
934606f to c48b49f
LGTM % remaining nits
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
/// false at ALL indices we fold. If the constant was 1, then
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not fold, conservatively

Suggested change:

/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result is
/// false at ALL indices we fold. If the constant was 1, then
/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold, conservatively
LGTM, thanks!
Personally I would reduce the number of tests by checking fewer `_lhs` vs `_rhs` variants. There's also one other [nit] inline 😅
Signed-off-by: James Newling <[email protected]>
914d75e to b851e07
I've reduced the number of tests by about a third, by only testing 2 cases for each variant instead of 3.
Nice contribution! Couldn't we make
Interesting, do you mean make it something like :
Yes, there is a