Skip to content

Commit 0e5c32b

Browse files
authored
[MLIR][Vector] Add unrolling pattern for vector StepOp (#157752)
This PR adds an unrolling pattern for the `vector.step` op to the VectorUnroll transform.
1 parent 4ab8dab commit 0e5c32b

File tree

4 files changed

+103
-1
lines changed

4 files changed

+103
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3019,6 +3019,7 @@ def Vector_ScanOp :
30193019

30203020
def Vector_StepOp : Vector_Op<"step", [
30213021
Pure,
3022+
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
30223023
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
30233024
]> {
30243025
let summary = "A linear sequence of values from 0 to N";

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,81 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809809
vector::UnrollVectorOptions options;
810810
};
811811

812+
/// This pattern unrolls `vector.step` operations according to the provided
813+
/// target unroll shape. It decomposes a large step vector into smaller step
814+
/// vectors (segments) and assembles the result by inserting each computed
815+
/// segment into the appropriate offset of the original vector.
816+
///
817+
/// The pattern does not support scalable vectors and will fail to match them.
818+
///
819+
/// For each segment, it adds the base step vector and the segment's offset,
820+
/// then inserts the result into the output vector at the corresponding
821+
/// position.
822+
///
823+
/// Example:
824+
/// Given a step operation:
825+
/// %0 = vector.step : vector<8xindex>
826+
///
827+
/// and a target unroll shape of <4>, the pattern produces:
828+
///
829+
/// %base = vector.step : vector<4xindex>
830+
/// %zero = arith.constant dense<0> : vector<8xindex>
831+
/// %result0 = vector.insert_strided_slice %base, %zero
832+
/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
833+
/// %offset = arith.constant dense<4> : vector<4xindex>
834+
/// %segment1 = arith.addi %base, %offset : vector<4xindex>
835+
/// %result1 = vector.insert_strided_slice %segment1, %result0
836+
/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
837+
///
838+
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
839+
UnrollStepPattern(MLIRContext *context,
840+
const vector::UnrollVectorOptions &options,
841+
PatternBenefit benefit = 1)
842+
: OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
843+
844+
LogicalResult matchAndRewrite(vector::StepOp stepOp,
845+
PatternRewriter &rewriter) const override {
846+
std::optional<SmallVector<int64_t>> targetShape =
847+
getTargetShape(options, stepOp);
848+
if (!targetShape)
849+
return failure();
850+
851+
VectorType vecType = stepOp.getType();
852+
if (vecType.isScalable()) {
853+
// Scalable vectors are not supported by this pattern.
854+
return failure();
855+
}
856+
int64_t originalSize = vecType.getShape()[0];
857+
Location loc = stepOp.getLoc();
858+
SmallVector<int64_t> strides(1, 1);
859+
860+
Value result = arith::ConstantOp::create(rewriter, loc, vecType,
861+
rewriter.getZeroAttr(vecType));
862+
863+
auto targetVecType =
864+
VectorType::get(*targetShape, vecType.getElementType());
865+
Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
866+
for (const SmallVector<int64_t> &offsets :
867+
StaticTileOffsetRange({originalSize}, *targetShape)) {
868+
Value bcastOffset = arith::ConstantOp::create(
869+
rewriter, loc, targetVecType,
870+
DenseElementsAttr::get(
871+
targetVecType,
872+
IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
873+
Value tileStep =
874+
arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
875+
876+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
877+
loc, tileStep, result, offsets, strides);
878+
}
879+
rewriter.replaceOp(stepOp, result);
880+
return success();
881+
}
882+
883+
private:
884+
vector::UnrollVectorOptions options;
885+
};
886+
812887
} // namespace
813888

814889
void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +893,6 @@ void mlir::vector::populateVectorUnrollPatterns(
818893
UnrollContractionPattern, UnrollElementwisePattern,
819894
UnrollReductionPattern, UnrollMultiReductionPattern,
820895
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
821-
UnrollStorePattern, UnrollBroadcastPattern>(
896+
UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
822897
patterns.getContext(), options, benefit);
823898
}

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,23 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
420420
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
421421
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
422422
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
423+
424+
425+
// Unroll test input for `vector.step`: builds a 32-element index step vector
// and returns it, so the unrolled form can be matched by the lines that follow.
func.func @vector_step() -> vector<32xindex> {
426+
%0 = vector.step : vector<32xindex>
427+
return %0 : vector<32xindex>
428+
}
429+
// CHECK-LABEL: func @vector_step
430+
// CHECK: %[[CST:.*]] = arith.constant dense<24> : vector<8xindex>
431+
// CHECK: %[[CST0:.*]] = arith.constant dense<16> : vector<8xindex>
432+
// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
433+
// CHECK: %[[CST2:.*]] = arith.constant dense<0> : vector<32xindex>
434+
// CHECK: %[[STEP:.*]] = vector.step : vector<8xindex>
435+
// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP]], %[[CST2]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
436+
// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP]], %[[CST1]] : vector<8xindex>
437+
// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
438+
// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP]], %[[CST0]] : vector<8xindex>
439+
// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
440+
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
441+
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
442+
// CHECK: return %[[INS3]] : vector<32xindex>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
172172
.setFilterConstraint([](Operation *op) {
173173
return success(isa<vector::ReductionOp>(op));
174174
}));
175+
populateVectorUnrollPatterns(patterns,
176+
UnrollVectorOptions()
177+
.setNativeShape(ArrayRef<int64_t>{8})
178+
.setFilterConstraint([](Operation *op) {
179+
return success(isa<vector::StepOp>(op));
180+
}));
175181
populateVectorUnrollPatterns(
176182
patterns, UnrollVectorOptions()
177183
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

0 commit comments

Comments
 (0)