@@ -809,6 +809,81 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809
809
vector::UnrollVectorOptions options;
810
810
};
811
811
812
+ // / This pattern unrolls `vector.step` operations according to the provided
813
+ // / target unroll shape. It decomposes a large step vector into smaller step
814
+ // / vectors (segments) and assembles the result by inserting each computed
815
+ // / segment into the appropriate offset of the original vector.
816
+ // /
817
+ // / The pattern does not support scalable vectors and will fail to match them.
818
+ // /
819
+ // / For each segment, it adds the base step vector and the segment's offset,
820
+ // / then inserts the result into the output vector at the corresponding
821
+ // / position.
822
+ // /
823
+ // / Example:
824
+ // / Given a step operation:
825
+ // / %0 = vector.step : vector<8xindex>
826
+ // /
827
+ // / and a target unroll shape of <4>, the pattern produces:
828
+ // /
829
+ // / %base = vector.step : vector<4xindex>
830
+ // / %zero = arith.constant dense<0> : vector<8xindex>
831
+ // / %result0 = vector.insert_strided_slice %base, %zero
832
+ // / {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
833
+ // / %offset = arith.constant dense<4> : vector<4xindex>
834
+ // / %segment1 = arith.addi %base, %offset : vector<4xindex>
835
+ // / %result1 = vector.insert_strided_slice %segment1, %result0
836
+ // / {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
837
+ // /
838
+ struct UnrollStepPattern : public OpRewritePattern <vector::StepOp> {
839
+ UnrollStepPattern (MLIRContext *context,
840
+ const vector::UnrollVectorOptions &options,
841
+ PatternBenefit benefit = 1 )
842
+ : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
843
+
844
+ LogicalResult matchAndRewrite (vector::StepOp stepOp,
845
+ PatternRewriter &rewriter) const override {
846
+ std::optional<SmallVector<int64_t >> targetShape =
847
+ getTargetShape (options, stepOp);
848
+ if (!targetShape)
849
+ return failure ();
850
+
851
+ VectorType vecType = stepOp.getType ();
852
+ if (vecType.isScalable ()) {
853
+ // Scalable vectors are not supported by this pattern.
854
+ return failure ();
855
+ }
856
+ int64_t originalSize = vecType.getShape ()[0 ];
857
+ Location loc = stepOp.getLoc ();
858
+ SmallVector<int64_t > strides (1 , 1 );
859
+
860
+ Value result = arith::ConstantOp::create (rewriter, loc, vecType,
861
+ rewriter.getZeroAttr (vecType));
862
+
863
+ auto targetVecType =
864
+ VectorType::get (*targetShape, vecType.getElementType ());
865
+ Value baseStep = vector::StepOp::create (rewriter, loc, targetVecType);
866
+ for (const SmallVector<int64_t > &offsets :
867
+ StaticTileOffsetRange ({originalSize}, *targetShape)) {
868
+ Value bcastOffset = arith::ConstantOp::create (
869
+ rewriter, loc, targetVecType,
870
+ DenseElementsAttr::get (
871
+ targetVecType,
872
+ IntegerAttr::get (targetVecType.getElementType (), offsets[0 ])));
873
+ Value tileStep =
874
+ arith::AddIOp::create (rewriter, loc, baseStep, bcastOffset);
875
+
876
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
877
+ loc, tileStep, result, offsets, strides);
878
+ }
879
+ rewriter.replaceOp (stepOp, result);
880
+ return success ();
881
+ }
882
+
883
+ private:
884
+ vector::UnrollVectorOptions options;
885
+ };
886
+
812
887
} // namespace
813
888
814
889
void mlir::vector::populateVectorUnrollPatterns (
@@ -818,6 +893,6 @@ void mlir::vector::populateVectorUnrollPatterns(
818
893
UnrollContractionPattern, UnrollElementwisePattern,
819
894
UnrollReductionPattern, UnrollMultiReductionPattern,
820
895
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
821
- UnrollStorePattern, UnrollBroadcastPattern>(
896
+ UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern >(
822
897
patterns.getContext (), options, benefit);
823
898
}
0 commit comments