From ff1cac21afcf4e981baeac7306c3296849070994 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 31 Jul 2025 15:11:49 +0000 Subject: [PATCH] [mlir] Implement inferResultRanges for vector.step op Signed-off-by: Max Dawkins --- .../include/mlir/Dialect/Vector/IR/VectorOps.td | 5 ++++- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 +++++++++++++++++ .../Dialect/Vector/int-range-interface.mlir | 8 ++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3885439e11f89..1d9a9d3f699ac 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2876,7 +2876,10 @@ def Vector_ScanOp : // VectorStepOp //===----------------------------------------------------------------------===// -def Vector_StepOp : Vector_Op<"step", [Pure]> { +def Vector_StepOp : Vector_Op<"step", [ + Pure, + DeclareOpInterfaceMethods + ]> { let summary = "A linear sequence of values from 0 to N"; let description = [{ A `step` operation produces an index vector, i.e. a 1-D vector of values of diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8789f55707267..83f3d8e65b785 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7197,6 +7197,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, return selectPassthru(b, mask, result, acc); } +//===----------------------------------------------------------------------===// +// StepOp +//===----------------------------------------------------------------------===// + +void StepOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + auto resultType = cast(getType()); + if (resultType.isScalable()) { + return; + } + unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType); + APInt zero(bitwidth, 0); + APInt high(bitwidth, resultType.getDimSize(0) - 1); + ConstantIntRanges result = {zero, high, zero, high}; + setResultRanges(getResult(), result); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 2563b48cdd506..c60c21fadb668 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -99,3 +99,11 @@ func.func @test_vector_extsi() -> vector<2xi32> { %2 = test.reflect_bounds %1 : vector<2xi32> func.return %2 : vector<2xi32> } + +// CHECK-LABEL: func @vector_step +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} +func.func @vector_step() -> vector<8xindex> { + %0 = vector.step : vector<8xindex> + %1 = test.reflect_bounds %0 : vector<8xindex> + func.return %1 : vector<8xindex> +}