diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index e68a3c77881fb..5d45508af5c06 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2877,7 +2877,10 @@ def Vector_ScanOp : // VectorStepOp //===----------------------------------------------------------------------===// -def Vector_StepOp : Vector_Op<"step", [Pure]> { +def Vector_StepOp : Vector_Op<"step", [ + Pure, + DeclareOpInterfaceMethods + ]> { let summary = "A linear sequence of values from 0 to N"; let description = [{ A `step` operation produces an index vector, i.e. a 1-D vector of values of diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 25ce292f16e45..86fbb76790312 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7202,6 +7202,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, return selectPassthru(b, mask, result, acc); } +//===----------------------------------------------------------------------===// +// StepOp +//===----------------------------------------------------------------------===// + +void StepOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + auto resultType = cast(getType()); + if (resultType.isScalable()) { + return; + } + unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType); + APInt zero(bitwidth, 0); + APInt high(bitwidth, resultType.getDimSize(0) - 1); + ConstantIntRanges result = {zero, high, zero, high}; + setResultRanges(getResult(), result); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index f89d307590e0b..b2f16bb3dac9c 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -108,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> { %2 = test.reflect_bounds %1 : vector<2xi32> func.return %2 : vector<2xi32> } + +// CHECK-LABEL: func @vector_step +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} +func.func @vector_step() -> vector<8xindex> { + %0 = vector.step : vector<8xindex> + %1 = test.reflect_bounds %0 : vector<8xindex> + func.return %1 : vector<8xindex> +}