
Conversation

@dan-garvey (Contributor) commented Jul 25, 2025

Previously, when dropping a unit dim from a pad op with mixed dynamic/static input/output shapes, the new pad op was created without an explicit result type, so its result type was inferred from the collapsed input. Dynamic input dims then leaked into what should have been a static result type, producing invalid IR.
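
To make the failure mode concrete, here is a hand-written sketch (based on the drop_unit_dim_mixed_static_dynamic test added below, not actual compiler output) of the rewrite the pattern used to produce. Because the new tensor.pad carried no explicit result type, the dynamic input dim leaked into its result:

    %collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
    // Inferred result type is tensor<?xf32>, but the original pad produced
    // tensor<1x16xf32>, so the collapsed result must be the static tensor<16xf32>.
    %padded = tensor.pad %collapsed low[1] high[0] { ... } : tensor<?xf32> to tensor<?xf32>
    // The trailing tensor.expand_shape back to tensor<1x16xf32> then mismatches,
    // and the rewritten IR fails to verify.

With this change, the new pad op is created with an explicit result type (tensor<16xf32> here), computed by dropping the same unit dims from the original result shape.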

Also did some minor formatting cleanup to the drop_unit_dim_corresponding_to_dynamic_dim test so that it matches the rest of the file.

@llvmbot (Member) commented Jul 25, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

Author: Daniel Garvey (dan-garvey)

Changes

Previously when dropping unit dim from a pad with mixed dynamic/static input/output shapes, the resulting shape would take on the Type of the input, resulting in invalid IR.


Full diff: https://github.com/llvm/llvm-project/pull/150706.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+10-4)
  • (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+38-17)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..7194662145bcc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -603,6 +603,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
     }
 
     ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+    ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
     int64_t padRank = sourceShape.size();
 
     auto isStaticZero = [](OpFoldResult f) {
@@ -613,16 +614,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
                                                  allowedUnitDims.end());
     llvm::SmallDenseSet<unsigned> unitDims;
     SmallVector<int64_t> newShape;
+    SmallVector<int64_t> newResultShape;
     SmallVector<OpFoldResult> newLowPad;
     SmallVector<OpFoldResult> newHighPad;
-    for (const auto [dim, size, low, high] :
-         zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
-                   padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+    for (const auto [dim, size, outSize, low, high] : zip_equal(
+             llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+             resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
       if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
           isStaticZero(high)) {
         unitDims.insert(dim);
       } else {
         newShape.push_back(size);
+        newResultShape.push_back(outSize);
         newLowPad.push_back(low);
         newHighPad.push_back(high);
       }
@@ -652,8 +655,11 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
         collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
                       reassociationMap, options.rankReductionStrategy);
 
+    auto resultType = RankedTensorType::get(
+        newResultShape, padOp.getResultType().getElementType());
     auto newPadOp = rewriter.create<tensor::PadOp>(
-        padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+        padOp.getLoc(), /*result=*/resultType, collapsedSource, newLowPad,
         newHighPad, paddingVal, padOp.getNofold());
 
     Value dest = padOp.getResult();
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index a00c798197e5a..5df113b038469 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 
 // -----
 
+func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?xf32> to tensor<1x16xf32>
+  return %padded : tensor<1x16xf32>
+}
+// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic
+//       CHECK:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape 
+//       CHECK:   %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] {
+//       CHECK:   ^bb0(%[[IDX:.*]]: index):
+//       CHECK:     tensor.yield %[[CST]] : f32
+//       CHECK:   } : tensor<?xf32> to tensor<16xf32>
+//       CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]]
+//       CHECK:   return %[[EXPANDED]] : tensor<1x16xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+  func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
+    %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
+    %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
+    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %2 = arith.mulf %in, %in_0 : f32
+      %3 = arith.addf %out, %2 : f32
+      linalg.yield %3 : f32
+    } -> tensor<?x1x61x1xf32>
+    return %1 : tensor<?x1x61x1xf32>
+  }
+}
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
 
@@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 // CHECK:           return %[[VAL_14]] : tensor<?x1x61x1xf32>
 // CHECK:         }
 
-#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-module {
-  func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
-    %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
-    %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
-    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
-    ^bb0(%in: f32, %in_0: f32, %out: f32):
-      %2 = arith.mulf %in, %in_0 : f32
-      %3 = arith.addf %out, %2 : f32
-      linalg.yield %3 : f32
-    } -> tensor<?x1x61x1xf32>
-    return %1 : tensor<?x1x61x1xf32>
-  }
-}
-
 // -----
 
 func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {

@dan-garvey (Contributor, Author) commented

(The change to the unrelated test was just to make the formatting match the other tests in the file)

@dan-garvey changed the title from "[MLIR] Specify outputType in pad op unitDim drop" to "[MLIR] Specify new padOp's output type in DropPadUnitDims" on Jul 25, 2025
@banach-space (Contributor) left a comment

Hey Dan, long time no see!

Thanks for the contribution, LGTM % some nits.

Previously when dropping unit dim from a pad with mixed dynamic/static input/output shapes, the resulting shape would take on the Type of the input, resulting in invalid IR.

Additionally, fixed the formatting on the drop_unit_dim_corresponding_to_dynamic_dim test in order to match the rest of the tests in that file.

Signed-off-by: dan <[email protected]>
@dan-garvey force-pushed the handle_pad_mixed_dynamic_static branch from a7b2f0d to df03118 on July 29, 2025 at 22:06
@dan-garvey (Contributor, Author) commented

> Hey Dan, long time no see!
>
> Thanks for the contribution, LGTM % some nits.

Thanks for the review! I hope we can have lunch again this year.

If you feel I've adequately addressed your concerns, I'd appreciate it if you also hit the merge button for me, as I do not have access yet. = )

@banach-space (Contributor) commented

> If you feel I've adequately addressed your concerns, I'd appreciate it if you also hit the merge button for me, as I do not have access yet. = )

Sure, though I am getting:

> This commit will be authored by [email protected].

So, don't forget to follow https://llvm.org/docs/GitHub.html#before-your-first-pr :)

@dan-garvey (Contributor, Author) commented

> > If you feel I've adequately addressed your concerns, I'd appreciate it if you also hit the merge button for me, as I do not have access yet. = )
>
> Sure, though I am getting:
>
> > This commit will be authored by [email protected].
>
> So, don't forget to follow https://llvm.org/docs/GitHub.html#before-your-first-pr :)

That is a comprehensive doc. I un-privated my email, but I'm not certain whether I need to recommit for it to update. Does it still show the same authored-by?

@banach-space merged commit 1e504be into llvm:main on Jul 31, 2025 (9 checks passed)
@banach-space (Contributor) commented

> I'm not certain if I need to recommit to get it to update. Does it still show the same authored by?

No, no need for that :)

@dan-garvey deleted the handle_pad_mixed_dynamic_static branch on July 31, 2025 at 17:27