Skip to content

Commit ff1cac2

Browse files
committed
[mlir] Implement inferResultRanges for vector.step op
Signed-off-by: Max Dawkins <[email protected]>
1 parent 9fdd1d3 commit ff1cac2

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2876,7 +2876,10 @@ def Vector_ScanOp :
28762876
// VectorStepOp
28772877
//===----------------------------------------------------------------------===//
28782878

2879-
def Vector_StepOp : Vector_Op<"step", [Pure]> {
2879+
def Vector_StepOp : Vector_Op<"step", [
2880+
Pure,
2881+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
2882+
]> {
28802883
let summary = "A linear sequence of values from 0 to N";
28812884
let description = [{
28822885
A `step` operation produces an index vector, i.e. a 1-D vector of values of

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7197,6 +7197,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
71977197
return selectPassthru(b, mask, result, acc);
71987198
}
71997199

7200+
//===----------------------------------------------------------------------===//
7201+
// StepOp
7202+
//===----------------------------------------------------------------------===//
7203+
7204+
void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7205+
SetIntRangeFn setResultRanges) {
7206+
auto resultType = cast<VectorType>(getType());
7207+
if (resultType.isScalable()) {
7208+
return;
7209+
}
7210+
unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7211+
APInt zero(bitwidth, 0);
7212+
APInt high(bitwidth, resultType.getDimSize(0) - 1);
7213+
ConstantIntRanges result = {zero, high, zero, high};
7214+
setResultRanges(getResult(), result);
7215+
}
7216+
72007217
//===----------------------------------------------------------------------===//
72017218
// Vector Masking Utilities
72027219
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/int-range-interface.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,11 @@ func.func @test_vector_extsi() -> vector<2xi32> {
9999
%2 = test.reflect_bounds %1 : vector<2xi32>
100100
func.return %2 : vector<2xi32>
101101
}
102+
103+
// CHECK-LABEL: func @vector_step
104+
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
105+
func.func @vector_step() -> vector<8xindex> {
106+
%0 = vector.step : vector<8xindex>
107+
%1 = test.reflect_bounds %0 : vector<8xindex>
108+
func.return %1 : vector<8xindex>
109+
}

0 commit comments

Comments
 (0)