Skip to content

Conversation

@IanWood1
Copy link
Contributor

@IanWood1 IanWood1 commented Nov 26, 2024

If expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single collapse_shape op.

@IanWood1 IanWood1 requested review from Max191, pashu123 and qedawkins and removed request for pashu123 and qedawkins November 26, 2024 19:02
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-mlir-tensor

Author: Ian Wood (IanWood1)

Changes

If expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single collapse_shape op and should be converted to a cast.


Full diff: https://github.com/llvm/llvm-project/pull/117768.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+14-10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+28)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba..0357e34a2e0963 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    if (srcRank == resultRank)
       return failure();
 
     auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,16 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape &&
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-          composedReassociation.push_back(srcIndices);
-        } else {
+        if (srcSubShape != resultSubShape ||
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
           return std::nullopt;
         }
+        for (auto dim : llvm::seq<int64_t>(0, srcSubShape.size())) {
+          ReassociationIndices reassoc;
+          reassoc.push_back(srcIndices.front() + dim);
+          composedReassociation.push_back(reassoc);
+        }
+        continue;
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +407,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         return std::nullopt;
 
       // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
+      for (auto &subshapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subshapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
       }
     }
     return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..613ec066337294 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 
 // -----
 
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+  return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1,  10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+  return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-mlir

Author: Ian Wood (IanWood1)

Changes

If expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single collapse_shape op and should be converted to a cast.


Full diff: https://github.com/llvm/llvm-project/pull/117768.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+14-10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+28)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba..0357e34a2e0963 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    if (srcRank == resultRank)
       return failure();
 
     auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,16 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape &&
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-          composedReassociation.push_back(srcIndices);
-        } else {
+        if (srcSubShape != resultSubShape ||
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
           return std::nullopt;
         }
+        for (auto dim : llvm::seq<int64_t>(0, srcSubShape.size())) {
+          ReassociationIndices reassoc;
+          reassoc.push_back(srcIndices.front() + dim);
+          composedReassociation.push_back(reassoc);
+        }
+        continue;
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +407,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         return std::nullopt;
 
       // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
+      for (auto &subshapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subshapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
       }
     }
     return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..613ec066337294 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 
 // -----
 
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+  return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1,  10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+  return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0

If expand(collapse) has a dimension that gets collapsed and then
expanded to the same shape, the pattern would fail to canonicalize this
to a single collapse shape.

Signed-off-by: Ian Wood <[email protected]>
@IanWood1 IanWood1 merged commit fcfdabf into llvm:main Dec 2, 2024
8 checks passed
@IanWood1 IanWood1 deleted the improve_reshape_canon branch December 2, 2024 16:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants