@@ -28,6 +28,27 @@ static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
2828 shape, [](int64_t v) { return ShapedType::isDynamic (v) || v > 1 ; });
2929}
3030
31+ // / Returns success() if there is only 1 dimension size in non-packed domain
32+ // / being greater than 1 and packing only happens on the dimension.
33+ // / Note: this method should only be used by pack/unpack to reshape conversion.
34+ // / It assumes that non-unit inner tile size must be used by the non-unit
35+ // / dimension.
36+ static LogicalResult isPackOn1D (RewriterBase &rewriter, Operation *op,
37+ ArrayRef<int64_t > srcShape,
38+ ArrayRef<int64_t > innerPackTileSize) {
39+ if (getNumGtOneDims (srcShape) > 1 ) {
40+ return rewriter.notifyMatchFailure (
41+ op, " expects non-packed domain to have at most one non-unit dims" );
42+ }
43+ // Non-unit inner tile size must be used by the non-unit dimension. If not, it
44+ // will faill on getting reassociation maps.
45+ if (getNumGtOneDims (innerPackTileSize) > 1 ) {
46+ return rewriter.notifyMatchFailure (
47+ op, " expects at most one non-unit inner tiles" );
48+ }
49+ return success ();
50+ }
51+
3152// / Packing one-dimensional tensor can be expressed as an expand shape op.
3253struct SimplifyPackToExpandShape : public OpRewritePattern <PackOp> {
3354 using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -59,40 +80,18 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
5980 return success ();
6081 }
6182
62- // / Returns success() if there is only 1 dimension size in source being
63- // / greater than 1 and packing only happens on the dimension. It assumes that
64- // / the pack op does not have padding value.
65- LogicalResult isPack1DSrc (RewriterBase &rewriter, PackOp packOp) const {
66- assert (!packOp.getPaddingValue () &&
67- " expect the op does not have padding value." );
68- ArrayRef<int64_t > srcShape = packOp.getSourceType ().getShape ();
69- if (getNumGtOneDims (srcShape) > 1 ) {
70- return rewriter.notifyMatchFailure (
71- packOp, " expects source to have at most one non-unit dims" );
72- }
73-
74- // The pack op does not have padding value. Non-unit inner tile size must be
75- // be used by the non-unit dimension.
76- SmallVector<int64_t > innerTiles = packOp.getStaticTiles ();
77- if (getNumGtOneDims (innerTiles) > 1 ) {
78- return rewriter.notifyMatchFailure (
79- packOp, " expects at most one non-unit inner tiles" );
80- }
81-
82- return success ();
83- }
84-
8583 LogicalResult matchAndRewrite (PackOp packOp,
8684 PatternRewriter &rewriter) const override {
8785 if (packOp.getPaddingValue ())
8886 return rewriter.notifyMatchFailure (packOp, " expects no padding value" );
8987
88+ RankedTensorType sourceType = packOp.getSourceType ();
9089 if (failed (isPackOnInnerMostDim (rewriter, packOp)) &&
91- failed (isPack1DSrc (rewriter, packOp))) {
90+ failed (isPackOn1D (rewriter, packOp, sourceType.getShape (),
91+ packOp.getStaticTiles ()))) {
9292 return failure ();
9393 }
9494
95- RankedTensorType sourceType = packOp.getSourceType ();
9695 RankedTensorType destType = packOp.getDestType ();
9796 auto reassociation =
9897 getReassociationIndicesForReshape (sourceType, destType);
@@ -117,8 +116,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
117116 operand, reassociation);
118117 }
119118
120- LogicalResult matchAndRewrite (UnPackOp unpackOp,
121- PatternRewriter &rewriter) const override {
119+ // / Returns success() if it is unpacking on the innermost dimension.
120+ LogicalResult isUnpackOnInnerMostDim (RewriterBase &rewriter,
121+ UnPackOp unpackOp) const {
122122 auto outerDimsPerm = unpackOp.getOuterDimsPerm ();
123123 if (!outerDimsPerm.empty () && !isIdentityPermutation (outerDimsPerm)) {
124124 return rewriter.notifyMatchFailure (
@@ -134,9 +134,22 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
134134 ArrayRef<int64_t > dimsPos = unpackOp.getInnerDimsPos ();
135135 if (dimsPos.size () != 1 || (dimsPos[0 ] + 1 != destType.getRank ())) {
136136 return rewriter.notifyMatchFailure (
137- unpackOp, " expects unpacking at the innermost dimension" );
137+ unpackOp, " expects unpacking on the innermost dimension" );
138138 }
139139
140+ return success ();
141+ }
142+
143+ LogicalResult matchAndRewrite (UnPackOp unpackOp,
144+ PatternRewriter &rewriter) const override {
145+ RankedTensorType destType = unpackOp.getDestType ();
146+ if (failed (isUnpackOnInnerMostDim (rewriter, unpackOp)) &&
147+ failed (isPackOn1D (rewriter, unpackOp, destType.getShape (),
148+ unpackOp.getStaticTiles ()))) {
149+ return failure ();
150+ }
151+
152+ RankedTensorType sourceType = unpackOp.getSourceType ();
140153 auto reassociation =
141154 getReassociationIndicesForReshape (sourceType, destType);
142155 if (!reassociation)
0 commit comments