@@ -809,6 +809,32 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809809 vector::UnrollVectorOptions options;
810810};
811811
812+ // / This pattern unrolls `vector.step` operations according to the provided
813+ // / target unroll shape. It decomposes a large step vector into smaller step
814+ // / vectors (segments) and assembles the result by inserting each computed
815+ // / segment into the appropriate offset of the original vector.
816+ // /
817+ // / The pattern does not support scalable vectors and will fail to match them.
818+ // /
819+ // / For each segment, it adds the base step vector and the segment's offset,
820+ // / then inserts the result into the output vector at the corresponding
821+ // / position.
822+ // /
823+ // / Example:
824+ // / Given a step operation:
825+ // / %0 = vector.step : vector<8xindex>
826+ // /
827+ // / and a target unroll shape of <4>, the pattern produces:
828+ // /
829+ // / %base = vector.step : vector<4xindex>
830+ // / %zero = arith.constant dense<0> : vector<8xindex>
831+ // / %result0 = vector.insert_strided_slice %base, %zero
832+ // / {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
833+ // / %offset = arith.constant dense<4> : vector<4xindex>
834+ // / %segment1 = arith.addi %base, %offset : vector<4xindex>
835+ // / %result1 = vector.insert_strided_slice %segment1, %result0
836+ // / {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
837+ // /
812838struct UnrollStepPattern : public OpRewritePattern <vector::StepOp> {
813839 UnrollStepPattern (MLIRContext *context,
814840 const vector::UnrollVectorOptions &options,
@@ -817,7 +843,8 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
817843
818844 LogicalResult matchAndRewrite (vector::StepOp stepOp,
819845 PatternRewriter &rewriter) const override {
820- auto targetShape = getTargetShape (options, stepOp);
846+ std::optional<SmallVector<int64_t >> targetShape =
847+ getTargetShape (options, stepOp);
821848 if (!targetShape)
822849 return failure ();
823850
@@ -833,18 +860,18 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
833860 Value result = arith::ConstantOp::create (rewriter, loc, vecType,
834861 rewriter.getZeroAttr (vecType));
835862
863+ VectorType targetVecType =
864+ VectorType::get (*targetShape, vecType.getElementType ());
865+ Value baseStep = vector::StepOp::create (rewriter, loc, targetVecType);
836866 for (SmallVector<int64_t > offsets :
837867 StaticTileOffsetRange ({originalSize}, *targetShape)) {
838- int64_t tileOffset = offsets[0 ];
839- auto targetVecType =
840- VectorType::get (*targetShape, vecType.getElementType ());
841- Value baseStep = rewriter.create <vector::StepOp>(loc, targetVecType);
842- Value offsetVal =
843- rewriter.create <arith::ConstantIndexOp>(loc, tileOffset);
844- Value bcastOffset =
845- rewriter.create <vector::BroadcastOp>(loc, targetVecType, offsetVal);
868+ Value bcastOffset = arith::ConstantOp::create (
869+ rewriter, loc, targetVecType,
870+ DenseElementsAttr::get (
871+ targetVecType,
872+ IntegerAttr::get (targetVecType.getElementType (), offsets[0 ])));
846873 Value tileStep =
847- rewriter. create < arith::AddIOp>( loc, baseStep, bcastOffset);
874+ arith::AddIOp::create (rewriter, loc, baseStep, bcastOffset);
848875
849876 result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
850877 loc, tileStep, result, offsets, strides);
0 commit comments