@@ -809,6 +809,81 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809809 vector::UnrollVectorOptions options;
810810};
811811
812+ // / This pattern unrolls `vector.step` operations according to the provided
813+ // / target unroll shape. It decomposes a large step vector into smaller step
814+ // / vectors (segments) and assembles the result by inserting each computed
815+ // / segment into the appropriate offset of the original vector.
816+ // /
817+ // / The pattern does not support scalable vectors and will fail to match them.
818+ // /
819+ // / For each segment, it adds the base step vector and the segment's offset,
820+ // / then inserts the result into the output vector at the corresponding
821+ // / position.
822+ // /
823+ // / Example:
824+ // / Given a step operation:
825+ // / %0 = vector.step : vector<8xindex>
826+ // /
827+ // / and a target unroll shape of <4>, the pattern produces:
828+ // /
829+ // / %base = vector.step : vector<4xindex>
830+ // / %zero = arith.constant dense<0> : vector<8xindex>
831+ // / %result0 = vector.insert_strided_slice %base, %zero
832+ // / {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
833+ // / %offset = arith.constant dense<4> : vector<4xindex>
834+ // / %segment1 = arith.addi %base, %offset : vector<4xindex>
835+ // / %result1 = vector.insert_strided_slice %segment1, %result0
836+ // / {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
837+ // /
838+ struct UnrollStepPattern : public OpRewritePattern <vector::StepOp> {
839+ UnrollStepPattern (MLIRContext *context,
840+ const vector::UnrollVectorOptions &options,
841+ PatternBenefit benefit = 1 )
842+ : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
843+
844+ LogicalResult matchAndRewrite (vector::StepOp stepOp,
845+ PatternRewriter &rewriter) const override {
846+ std::optional<SmallVector<int64_t >> targetShape =
847+ getTargetShape (options, stepOp);
848+ if (!targetShape)
849+ return failure ();
850+
851+ VectorType vecType = stepOp.getType ();
852+ if (vecType.isScalable ()) {
853+ // Scalable vectors are not supported by this pattern.
854+ return failure ();
855+ }
856+ int64_t originalSize = vecType.getShape ()[0 ];
857+ Location loc = stepOp.getLoc ();
858+ SmallVector<int64_t > strides (1 , 1 );
859+
860+ Value result = arith::ConstantOp::create (rewriter, loc, vecType,
861+ rewriter.getZeroAttr (vecType));
862+
863+ auto targetVecType =
864+ VectorType::get (*targetShape, vecType.getElementType ());
865+ Value baseStep = vector::StepOp::create (rewriter, loc, targetVecType);
866+ for (const SmallVector<int64_t > &offsets :
867+ StaticTileOffsetRange ({originalSize}, *targetShape)) {
868+ Value bcastOffset = arith::ConstantOp::create (
869+ rewriter, loc, targetVecType,
870+ DenseElementsAttr::get (
871+ targetVecType,
872+ IntegerAttr::get (targetVecType.getElementType (), offsets[0 ])));
873+ Value tileStep =
874+ arith::AddIOp::create (rewriter, loc, baseStep, bcastOffset);
875+
876+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
877+ loc, tileStep, result, offsets, strides);
878+ }
879+ rewriter.replaceOp (stepOp, result);
880+ return success ();
881+ }
882+
883+ private:
884+ vector::UnrollVectorOptions options;
885+ };
886+
812887} // namespace
813888
814889void mlir::vector::populateVectorUnrollPatterns (
@@ -818,6 +893,6 @@ void mlir::vector::populateVectorUnrollPatterns(
818893 UnrollContractionPattern, UnrollElementwisePattern,
819894 UnrollReductionPattern, UnrollMultiReductionPattern,
820895 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
821- UnrollStorePattern, UnrollBroadcastPattern>(
896+ UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern >(
822897 patterns.getContext (), options, benefit);
823898}
0 commit comments