Skip to content

Commit 91e0055

Browse files
authored
[mlir] Implement inferResultRanges for vector.step op (#151536)
Implements the `inferResultRanges` method from the `InferIntRangeInterface` interface for `vector.step`. The implementation is similar to that of arith.constant, since the exact result values are statically known. Signed-off-by: Max Dawkins <[email protected]>
1 parent f23c10f commit 91e0055

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2877,7 +2877,10 @@ def Vector_ScanOp :
28772877
// VectorStepOp
28782878
//===----------------------------------------------------------------------===//
28792879

2880-
def Vector_StepOp : Vector_Op<"step", [Pure]> {
2880+
def Vector_StepOp : Vector_Op<"step", [
2881+
Pure,
2882+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
2883+
]> {
28812884
let summary = "A linear sequence of values from 0 to N";
28822885
let description = [{
28832886
A `step` operation produces an index vector, i.e. a 1-D vector of values of

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7202,6 +7202,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
72027202
return selectPassthru(b, mask, result, acc);
72037203
}
72047204

7205+
//===----------------------------------------------------------------------===//
7206+
// StepOp
7207+
//===----------------------------------------------------------------------===//
7208+
7209+
void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7210+
SetIntRangeFn setResultRanges) {
7211+
auto resultType = cast<VectorType>(getType());
7212+
if (resultType.isScalable()) {
7213+
return;
7214+
}
7215+
unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7216+
APInt zero(bitwidth, 0);
7217+
APInt high(bitwidth, resultType.getDimSize(0) - 1);
7218+
ConstantIntRanges result = {zero, high, zero, high};
7219+
setResultRanges(getResult(), result);
7220+
}
7221+
72057222
//===----------------------------------------------------------------------===//
72067223
// Vector Masking Utilities
72077224
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/int-range-interface.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> {
108108
%2 = test.reflect_bounds %1 : vector<2xi32>
109109
func.return %2 : vector<2xi32>
110110
}
111+
112+
// CHECK-LABEL: func @vector_step
113+
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
114+
func.func @vector_step() -> vector<8xindex> {
115+
%0 = vector.step : vector<8xindex>
116+
%1 = test.reflect_bounds %0 : vector<8xindex>
117+
func.return %1 : vector<8xindex>
118+
}

0 commit comments

Comments
 (0)