@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809809 vector::UnrollVectorOptions options;
810810};
811811
812+ struct UnrollStepPattern : public OpRewritePattern <vector::StepOp> {
813+ UnrollStepPattern (MLIRContext *context,
814+ const vector::UnrollVectorOptions &options,
815+ PatternBenefit benefit = 1 )
816+ : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
817+
818+ LogicalResult matchAndRewrite (vector::StepOp stepOp,
819+ PatternRewriter &rewriter) const override {
820+ auto targetShape = getTargetShape (options, stepOp);
821+ if (!targetShape)
822+ return failure ();
823+
824+ VectorType vecType = stepOp.getType ();
825+ if (vecType.isScalable ()) {
826+ // Scalable vectors are not supported by this pattern.
827+ return failure ();
828+ }
829+ int64_t originalSize = vecType.getShape ()[0 ];
830+ Location loc = stepOp.getLoc ();
831+ SmallVector<int64_t > strides (1 , 1 );
832+
833+ Value result = arith::ConstantOp::create (rewriter, loc, vecType,
834+ rewriter.getZeroAttr (vecType));
835+
836+ for (SmallVector<int64_t > offsets :
837+ StaticTileOffsetRange ({originalSize}, *targetShape)) {
838+ int64_t tileOffset = offsets[0 ];
839+ auto targetVecType =
840+ VectorType::get (*targetShape, vecType.getElementType ());
841+ Value baseStep = rewriter.create <vector::StepOp>(loc, targetVecType);
842+ Value offsetVal =
843+ rewriter.create <arith::ConstantIndexOp>(loc, tileOffset);
844+ Value bcastOffset =
845+ rewriter.create <vector::BroadcastOp>(loc, targetVecType, offsetVal);
846+ Value tileStep =
847+ rewriter.create <arith::AddIOp>(loc, baseStep, bcastOffset);
848+
849+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
850+ loc, tileStep, result, offsets, strides);
851+ }
852+ rewriter.replaceOp (stepOp, result);
853+ return success ();
854+ }
855+
856+ private:
857+ vector::UnrollVectorOptions options;
858+ };
859+
812860} // namespace
813861
814862void mlir::vector::populateVectorUnrollPatterns (
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
818866 UnrollContractionPattern, UnrollElementwisePattern,
819867 UnrollReductionPattern, UnrollMultiReductionPattern,
820868 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
821- UnrollStorePattern, UnrollBroadcastPattern>(
869+ UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern >(
822870 patterns.getContext (), options, benefit);
823871}
0 commit comments