@@ -5594,6 +5594,29 @@ LogicalResult ShapeCastOp::verify() {
55945594 return success ();
55955595}
55965596
5597+ namespace {
5598+
5599+ // / Return true if `transpose` does not permute a pair of dimensions that are
5600+ // / both not of size 1. By `order preserving` we mean that the flattened
5601+ // / versions of the input and output vectors are (numerically) identical.
5602+ // / In other words `transpose` is effectively a shape cast.
5603+ bool isOrderPreserving (TransposeOp transpose) {
5604+ ArrayRef<int64_t > permutation = transpose.getPermutation ();
5605+ ArrayRef<int64_t > inShape = transpose.getSourceVectorType ().getShape ();
5606+ int64_t current = 0 ;
5607+ for (auto p : permutation) {
5608+ if (inShape[p] != 1 ) {
5609+ if (p < current) {
5610+ return false ;
5611+ }
5612+ current = p;
5613+ }
5614+ }
5615+ return true ;
5616+ }
5617+
5618+ } // namespace
5619+
55975620OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
55985621
55995622 // No-op shape cast.
@@ -5602,13 +5625,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56025625
56035626 VectorType resultType = getType ();
56045627
5605- // Canceling shape casts.
5606- if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5607-
5608- // Only allows valid transitive folding (expand/collapse dimensions).
5609- VectorType srcType = otherOp.getSource ().getType ();
5628+ // shape_cast(something(x)) -> x, or
5629+ // -> shape_cast(x).
5630+ //
5631+ // Confirms that a new shape_cast will have valid semantics (expands OR
5632+ // collapses dimensions).
5633+ auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
5634+ VectorType srcType = source.getType ();
56105635 if (resultType == srcType)
5611- return otherOp. getSource () ;
5636+ return source ;
56125637 if (srcType.getRank () < resultType.getRank ()) {
56135638 if (!isValidShapeCast (srcType.getShape (), resultType.getShape ()))
56145639 return {};
@@ -5618,8 +5643,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56185643 } else {
56195644 return {};
56205645 }
5621- setOperand (otherOp. getSource () );
5646+ setOperand (source );
56225647 return getResult ();
5648+ };
5649+
5650+ // Canceling shape casts.
5651+ if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5652+ TypedValue<VectorType> source = otherOp.getSource ();
5653+ return maybeFold (source);
5654+ }
5655+
5656+ // shape_cast(transpose(x)) -> shape_cast(x)
5657+ if (auto transpose = getSource ().getDefiningOp <TransposeOp>()) {
5658+ if (transpose.getType ().isScalable ())
5659+ return {};
5660+ if (isOrderPreserving (transpose)) {
5661+ TypedValue<VectorType> source = transpose.getVector ();
5662+ return maybeFold (source);
5663+ }
5664+ return {};
56235665 }
56245666
56255667 // Cancelling broadcast and shape cast ops.
@@ -5646,7 +5688,7 @@ namespace {
56465688// / Helper function that computes a new vector type based on the input vector
56475689// / type by removing the trailing one dims:
56485690// /
5649- // / vector<4x1x1xi1> --> vector<4x1 >
5691+ // / vector<4x1x1xi1> --> vector<4x1xi1 >
56505692// /
56515693static VectorType trimTrailingOneDims (VectorType oldType) {
56525694 ArrayRef<int64_t > oldShape = oldType.getShape ();
@@ -6113,6 +6155,34 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
61136155 }
61146156};
61156157
6158+ // / Folds transpose(shape_cast) into a new shape_cast.
6159+ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6160+ public:
6161+ using OpRewritePattern::OpRewritePattern;
6162+
6163+ LogicalResult matchAndRewrite (TransposeOp transposeOp,
6164+ PatternRewriter &rewriter) const override {
6165+ auto shapeCastOp =
6166+ transposeOp.getVector ().getDefiningOp <vector::ShapeCastOp>();
6167+ if (!shapeCastOp)
6168+ return failure ();
6169+ if (!isOrderPreserving (transposeOp))
6170+ return failure ();
6171+ if (transposeOp.getType ().isScalable ())
6172+ return failure ();
6173+
6174+ VectorType resultType = transposeOp.getType ();
6175+
6176+ // We don't need to check isValidShapeCast at this point, because it is
6177+ // guaranteed that merging the transpose into the the shape_cast is a valid
6178+ // shape_cast, because the transpose just inserts/removes ones.
6179+
6180+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(transposeOp, resultType,
6181+ shapeCastOp.getSource ());
6182+ return success ();
6183+ }
6184+ };
6185+
61166186// / Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
61176187// / 'order preserving', where 'order preserving' means the flattened
61186188// / inputs and outputs of the transpose have identical (numerical) values.
@@ -6211,8 +6281,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62116281
62126282void vector::TransposeOp::getCanonicalizationPatterns (
62136283 RewritePatternSet &results, MLIRContext *context) {
6214- results.add <FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat ,
6215- FoldTransposeBroadcast>(context);
6284+ results.add <FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder ,
6285+ FoldTransposeSplat, FoldTransposeBroadcast>(context);
62166286}
62176287
62186288// ===----------------------------------------------------------------------===//
0 commit comments