
Conversation

@ita9naiwa (Contributor) commented Apr 22, 2025

Motivation / Rationale

Why:
tensor.{expand_shape,collapse_shape} sitting between a producer and tensor.pad blocks pad–producer fusion and other canonicalizations. This leaves extra ops (dim, alloc, reshapes) and complicates bufferization/codegen.

What this patch enables:
• Move or remove those reshapes so tensor.pad directly sees the final shape.
• Result: simpler IR and downstream passes (fusion, folding, hoisting) apply cleanly.

Changes

  • Add FoldReshapeWithProducerPadOpByCollapsing pattern for collapse_shape bubbling.
  • Add FoldReshapeWithProducerPadOpByExpansion pattern for expand_shape bubbling.
  • Add appropriate tests.

Before applying this pass: pad, then {expand,collapse}_shape

func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
  %c0 = arith.constant 0.0 : f32
  %producer = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>

  %pad = tensor.pad %producer low[0, 1, 1] high[0, 1, 1] {
    ^bb0(%i: index, %j: index, %k: index):
      tensor.yield %c0 : f32
  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
  %reshape = tensor.expand_shape %pad [[0, 1], [2], [3]]
      output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>

  return %reshape : tensor<32x16x258x258xf32>
}

After applying this pass: {expand,collapse}_shape, then pad

func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
  %c0 = arith.constant 0.0 : f32
  %producer = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>

  %reshape = tensor.expand_shape %producer [[0, 1], [2], [3]]
      output_shape [32, 16, 258, 258] : tensor<512x256x256xf32> into tensor<32x16x256x256xf32>
  %pad = tensor.pad %reshape low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
      tensor.yield %c0 : f32
  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>

  return %pad : tensor<32x16x258x258xf32>
}

CC @Max191 for awareness—would love any pointers on the collapse‑side implementation or dynamic‑shape handling!

github-actions bot commented Apr 22, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@ita9naiwa ita9naiwa changed the title [MLIR] Add reshape propagation through tensor.pad [Draft][MLIR] Add reshape propagation through tensor.pad Apr 22, 2025
@Max191 Max191 left a comment

I left some comments to help with supporting dynamic cases. Lmk if you have more questions!

@ita9naiwa (Contributor, Author)

Apologies for the delay — I’ve been recovering from a medical issue. I’ll resume this soon.
@Max191

@ita9naiwa (Contributor, Author)

Hi @Max191, I updated this PR to support dynamic cases too, following your review. Sorry it took a while for me to get back from my hiatus.

I think it would be better if collapse_shape support were done in a separate PR, since this PR is already quite big and it is a distinct piece of functionality.

What do you think?

@ita9naiwa ita9naiwa marked this pull request as ready for review July 12, 2025 09:31
@llvmbot (Member) commented Jul 12, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Hyunsung Lee (ita9naiwa)

Changes

iree-org/iree#17492 (comment)

I’ve implemented fusion for tensor.expand_shape → tensor.pad, but two gaps remain:

  1. Missing collapse‑side pattern.
    I haven’t yet added the mirror case for tensor.collapse_shape → tensor.pad.
  2. Static‑only support.
    The current pattern only handles fully static shapes and padding.

Before (pad, then expand):

func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
  %c0 = arith.constant 0.0 : f32
  %producer = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>

  %pad = tensor.pad %producer low[0, 1, 1] high[0, 1, 1] {
    ^bb0(%i: index, %j: index, %k: index):
      tensor.yield %c0 : f32
  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
  %reshape = tensor.expand_shape %pad [[0, 1], [2], [3]]
      output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>

  return %reshape : tensor<32x16x258x258xf32>
}

After (reshape then pad):

func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
  %c0 = arith.constant 0.0 : f32
  %producer = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>

  %reshape = tensor.expand_shape %producer [[0, 1], [2], [3]]
      output_shape [32, 16, 258, 258] : tensor<512x256x256xf32> into tensor<32x16x256x256xf32>
  %pad = tensor.pad %reshape low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
      tensor.yield %c0 : f32
  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>

  return %pad : tensor<32x16x258x258xf32>
}

Next steps
• Add a CollapseShapeOp→PadOp pattern to cover the missing collapse‑side fusion.
• Lift the “static‑only” guard so both patterns handle dynamic shapes and pads.

CC @Max191 for awareness—would love any pointers on the collapse‑side implementation or dynamic‑shape handling!


Full diff: https://github.com/llvm/llvm-project/pull/136681.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+142)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+49-2)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 9c0f6e5d6469e..39eed6dd4cba4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1100,6 +1100,146 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    auto isZeroPadding = [](OpFoldResult padValue) -> bool {
+      if (auto attr = dyn_cast<Attribute>(padValue)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+          return intAttr.getInt() == 0;
+      }
+
+      if (auto val = dyn_cast<Value>(padValue)) {
+        if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+          if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+            return attr.getInt() == 0;
+        }
+      }
+
+      // when padding is dynamic and not constant, we don't know if it's zero or
+      // not. so we return false here.
+      return false;
+    };
+
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[idx];
+      OpFoldResult h = high[idx];
+      if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = expandOp.getLoc();
+    auto finalType = cast<RankedTensorType>(expandOp.getType());
+    ArrayRef<int64_t> finalShape = finalType.getShape();
+
+    SmallVector<OpFoldResult> expandedShape;
+    for (int64_t dimSize : finalShape) {
+      if (dimSize == ShapedType::kDynamic) {
+        expandedShape.push_back(OpFoldResult{});
+      } else {
+        expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
+
+    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[inDimIdx];
+      OpFoldResult h = high[inDimIdx];
+
+      if (!isZeroPadding(l) || !isZeroPadding(h)) {
+        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+        int64_t originalSize = srcType.getDimSize(inDimIdx);
+
+        OpFoldResult originalSizeOFR;
+        if (originalSize == ShapedType::kDynamic) {
+          Value orgSizeVal =
+              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
+          originalSizeOFR = orgSizeVal;
+        } else {
+          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+        }
+
+        for (auto outDimIdx : outGroup) {
+          expandedShape[outDimIdx] = originalSizeOFR;
+        }
+      }
+    }
+
+    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+      if (dimSize == ShapedType::kDynamic &&
+          !isa<Value>(expandedShape[outDimIdx]) &&
+          !isa<Attribute>(expandedShape[outDimIdx])) {
+        Value actualSize =
+            rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
+        expandedShape[outDimIdx] = actualSize;
+      }
+    }
+
+    SmallVector<int64_t> staticExpandedShape;
+    for (OpFoldResult dim : expandedShape) {
+      if (auto attr = dyn_cast<Attribute>(dim)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+          staticExpandedShape.push_back(intAttr.getInt());
+        } else {
+          staticExpandedShape.push_back(ShapedType::kDynamic);
+        }
+      } else {
+        staticExpandedShape.push_back(ShapedType::kDynamic);
+      }
+    }
+
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2235,6 +2375,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..3ea0babfa3b9d 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
-                                         %arg1 : tensor<?x?xi32>, 
+                                         %arg1 : tensor<?x?xi32>,
                                          %sz0: index, %sz1: index) ->
                                          tensor<?x?x4x5xi32>
 {
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 // -----
 
 func.func @reshape_as_consumer_permutation_with_multiple_results
-  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index, 
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
    %sz1: index, %sz2: index, %sz3: index, %sz4: index)
     -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
   %c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[EXPANDED]] :
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0   = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+//      CHECK: func @fold_tensor_pad_with_expand(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+//  CHECK-DAG:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+//      CHECK:   %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+//      CHECK:   ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+//      CHECK:     tensor.yield %[[CST]] : f32
+//      CHECK:   } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+//      CHECK:   return %[[PADDED]] : tensor<32x16x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0   = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+//      CHECK: func @fold_tensor_pad_with_expand_dynamic_pad_zero(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+//      CHECK:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
+//      CHECK:   %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+//      CHECK:   ^bb0(
+//      CHECK:     tensor.yield %[[CST]] : f32
+//      CHECK:   return %[[PADDED]]

@llvmbot (Member) commented Jul 12, 2025

@llvm/pr-subscribers-mlir

(Same changes summary and full diff as the @llvm/pr-subscribers-mlir-linalg comment above.)
@ita9naiwa ita9naiwa changed the title [Draft][MLIR] Add reshape propagation through tensor.pad [MLIR] Add reshape_expand propagation through tensor.pad Jul 12, 2025
@ita9naiwa ita9naiwa requested a review from Max191 July 12, 2025 09:54
@ita9naiwa ita9naiwa changed the title [MLIR] Add reshape_expand propagation through tensor.pad [MLIR] Add expand_shape propagation through tensor.pad Jul 12, 2025
@ita9naiwa (Contributor, Author)

collapse_shape support added. It works for both the static and dynamic cases.

There are two limitations:

  • The padding value should be constant (see the sketch after this list).
  • A padded dimension should not participate in the reshape.

I believe these cases are quite rare, so we can skip them at this initial stage.
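For context, a minimal sketch of how the constant-padding-value limitation can be enforced inside such a pattern (hypothetical placement; padOp and rewriter follow the names used in the patch, and getConstantPaddingValue is the existing tensor.pad helper already used in the diff):

// getConstantPaddingValue() returns the yielded padding value when the pad
// body is a single tensor.yield of a value defined above; null otherwise.
Value padValue = padOp.getConstantPaddingValue();
if (!padValue)
  return rewriter.notifyMatchFailure(padOp, "non-constant padding value");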

@ita9naiwa ita9naiwa changed the title [MLIR] Add expand_shape propagation through tensor.pad [MLIR] Add shape propagation through tensor.pad Jul 14, 2025
@Max191 Max191 left a comment

Reviewing the FoldReshapeWithProducerPadOpByExpansion for the first round of comments. I think a lot of the cleanups can apply to both patterns, though. Nice work so far!

Comment on lines 29 to 30
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
Contributor

nit: I don't think these includes are needed

Contributor Author

Done

Contributor

Seems like these are still here?

Contributor Author

Sorry, done.

Comment on lines 1105 to 1117
bool isZero(OpFoldResult value) {
  if (auto attr = dyn_cast<Attribute>(value)) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
      return intAttr.getInt() == 0;
  }
  if (auto val = dyn_cast<Value>(value)) {
    if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
        return attr.getInt() == 0;
    }
  }
  return false;
}
Contributor

You can use isConstantIntValue(value, 0) for this:

bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
  return getConstantIntValue(ofr) == value;
}

Contributor Author

Done

Comment on lines 1155 to 1157
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      OpFoldResult l = low[idx];
      OpFoldResult h = high[idx];
Contributor

nit: you can use llvm::zip_equal

Suggested change
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = low[idx];
-      OpFoldResult h = high[idx];
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {

Contributor Author

Done

Comment on lines 1164 to 1167
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(low[idx]);
        newHigh.push_back(high[idx]);
      }
Contributor

nit: you can use append

Suggested change
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(low[idx]);
-        newHigh.push_back(high[idx]);
-      }
+      newLow.append(reInd.size(), low[idx]);
+      newHigh.append(reInd.size(), high[idx]);

Contributor Author

Done

      OpFoldResult l = low[idx];
      OpFoldResult h = high[idx];
      if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
        return failure();
Contributor

nit: Use rewriter.notifyMatchFailure() like above?
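A sketch of the failure-reporting form, mirroring what the updated patch ends up using later in this thread:

// notifyMatchFailure records a human-readable reason for the bail-out, which
// helps when debugging why the pattern did not fire.
if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
  return rewriter.notifyMatchFailure(
      expandOp, "fusion blocked by non-zero padding");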

Contributor Author

Done

Comment on lines 1200 to 1202
        for (auto outDimIdx : reInd) {
          expandedShape[outDimIdx] = originalSizeOFR;
        }
Contributor

I know that the reInd should have a size of 1 from the previous matching, but I think the logic is more clear if you add an assert here that reInd.size() == 1, and then just do expandedShape[reInd[0]] = originalSizeOFR;
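A sketch of the suggested form (names follow the surrounding patch):

// The earlier match step already rejected padded dims that expand into more
// than one result dim, so the group must be degenerate here.
assert(reInd.size() == 1 && "expected a single result dim for a padded dim");
expandedShape[reInd[0]] = originalSizeOFR;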

Contributor Author

Done

Comment on lines 1206 to 1214
    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
      if (dimSize == ShapedType::kDynamic &&
          !isa<Value>(expandedShape[outDimIdx]) &&
          !isa<Attribute>(expandedShape[outDimIdx])) {
        Value actualSize =
            rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
        expandedShape[outDimIdx] = actualSize;
      }
    }
Contributor

I think this was necessary because some of the expandedShape were null right? I'm pretty sure this shouldn't be necessary if you use getMixedOutputShape as per my above comment.

Contributor Author

Did I understand your comment correctly?

This can be reduced to:

    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
      if (dimSize == ShapedType::kDynamic &&
          !isa<Value>(expandedShape[outDimIdx]) &&
          !isa<Attribute>(expandedShape[outDimIdx])) {
        expandedShape[outDimIdx] =
            tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
      }
    }

Comment on lines 1216 to 1227
    SmallVector<int64_t> staticExpandedShape;
    for (OpFoldResult dim : expandedShape) {
      if (auto attr = dyn_cast<Attribute>(dim)) {
        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
          staticExpandedShape.push_back(intAttr.getInt());
        } else {
          staticExpandedShape.push_back(ShapedType::kDynamic);
        }
      } else {
        staticExpandedShape.push_back(ShapedType::kDynamic);
      }
    }
Contributor

You can use decomposeMixedValues here:

Suggested change
-    SmallVector<int64_t> staticExpandedShape;
-    for (OpFoldResult dim : expandedShape) {
-      if (auto attr = dyn_cast<Attribute>(dim)) {
-        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-          staticExpandedShape.push_back(intAttr.getInt());
-        } else {
-          staticExpandedShape.push_back(ShapedType::kDynamic);
-        }
-      } else {
-        staticExpandedShape.push_back(ShapedType::kDynamic);
-      }
-    }
+    SmallVector<int64_t> staticExpandedShape;
+    std::tie(staticExpandedShape, std::ignore) = decomposeMixedValues(expandedShape);

Contributor Author

Done

Comment on lines 1229 to 1233
    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
        loc,
        RankedTensorType::get(staticExpandedShape,
                              padOp.getSource().getType().getElementType()),
        padOp.getSource(), reassociations);
Contributor

I think you also want to pass the mixed output shape here to use the correct builder:

Suggested change
-    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
-        loc,
-        RankedTensorType::get(staticExpandedShape,
-                              padOp.getSource().getType().getElementType()),
-        padOp.getSource(), reassociations);
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations, expandedShape);

Contributor Author

Done

Comment on lines 1235 to 1239
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOp(expandOp, newPadOp.getResult());
Contributor

nit: use rewriter.replaceOpWithNewOp<tensor::PadOp>?
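A sketch of the combined create-and-replace, reusing the operands from the snippet above (illustrative only):

// Builds the new tensor.pad and replaces all uses of the expand_shape with it
// in a single step, instead of a separate create followed by replaceOp.
rewriter.replaceOpWithNewOp<tensor::PadOp>(
    expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
    padOp.getConstantPaddingValue(), padOp.getNofold());
return success();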

Contributor Author

Done

@ita9naiwa ita9naiwa requested a review from Max191 July 18, 2025 22:37
@banach-space (Contributor)

Thanks for contributing! I have two high-level asks.

iree-org/iree#17492 (comment)

ASK 1: Please use the summary to describe and to justify your change (i.e. provide rationale). The summary should be self-contained and shouldn't refer to any external projects (within reason, but in this case I don't find the reference helpful). You could, for example, explain why "bubbling up" tensor.expand_shape from your example is desirable.

I’ve implemented fusion for tensor.{collapse_shape,expand_shape} → tensor.pad.

Before (expand then pad):

(...)

After (reshape then pad):

ASK 2: "Before" and "after" what? Could you clarify?

Thanks!

@ita9naiwa (Contributor, Author) commented Jul 22, 2025

Thanks @banach-space for the thoughtful comment!
I updated my PR description to be more self-contained.

ASK 2: "Before" and "after" what? Could you clarify?

"Before" and "After" mean before and after applying the pass introduced with this PR.

Comment on lines 1139 to 1150
    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() > 1 &&
          (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
        return rewriter.notifyMatchFailure(
            expandOp, "fusion blocked by non-zero padding");
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      newLow.append(reInd.size(), low[idx]);
      newHigh.append(reInd.size(), high[idx]);
    }
Contributor

nit: combine these 2 loops.
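A sketch of the two loops merged into one pass over the reassociation groups (names follow the snippet above; illustrative only):

SmallVector<OpFoldResult> newLow, newHigh;
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
  // Reject non-zero padding on expanded dims and build the new pad amounts
  // in the same traversal.
  if (reInd.size() > 1 &&
      (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
    return rewriter.notifyMatchFailure(
        expandOp, "fusion blocked by non-zero padding");
  newLow.append(reInd.size(), l);
  newHigh.append(reInd.size(), h);
}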

Comment on lines 1168 to 1175
    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
      if (dimSize == ShapedType::kDynamic &&
          !isa<Value>(expandedShape[outDimIdx]) &&
          !isa<Attribute>(expandedShape[outDimIdx])) {
        expandedShape[outDimIdx] =
            tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
      }
    }
Contributor

You can delete this loop because expandOp.getMixedOutputShape() will already populate the expandedShape with the right dynamic values.
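A sketch of the simplification (assuming getMixedOutputShape behaves as described in the comment above):

// getMixedOutputShape() yields one OpFoldResult per result dim with dynamic
// sizes already resolved, making the fix-up loop over finalShape redundant.
SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();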

Comment on lines 1256 to 1263
    SmallVector<OpFoldResult> collapsedShape;
    for (int64_t dimSize : finalShape) {
      if (dimSize == ShapedType::kDynamic) {
        collapsedShape.push_back(OpFoldResult{});
      } else {
        collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
      }
    }
Contributor

You shouldn't need the mixed sizes for the collapsed shape. It is only used for getting the new type, so you can just collect static sizes instead (SmallVector<int64_t>).
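A sketch of the static-sizes-only variant (names follow the snippet above; illustrative only):

// Only static extents are needed to build the collapsed result type; dynamic
// dims stay as ShapedType::kDynamic, with no mixed OpFoldResult sizes.
SmallVector<int64_t> collapsedStaticShape(finalShape.begin(), finalShape.end());
auto newCollapsedType = RankedTensorType::get(
    collapsedStaticShape, padOp.getSource().getType().getElementType());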

@ita9naiwa (Contributor, Author)

Thanks @Max191 for the thoughtful review!
I think I addressed all of your comments; sorry I missed some points earlier.

@ita9naiwa ita9naiwa requested a review from Max191 July 26, 2025 05:25
@Max191 Max191 left a comment

Looks pretty good now, just a few more small comments!

Comment on lines 1232 to 1237
    SmallVector<OpFoldResult> newLow, newHigh;
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      newLow.push_back(low[reInd[0]]);
      newHigh.push_back(high[reInd[0]]);
    }

Contributor

nit: combine this loop with the loop above.

@ita9naiwa (Contributor, Author)

@Max191 Thanks, all addressed!
The code looks much cleaner now. I learned a lot from your review!

@ita9naiwa ita9naiwa requested a review from Max191 July 29, 2025 09:47
@Max191 Max191 left a comment

LGTM, thanks for addressing all my comments, nice work!

@ita9naiwa (Contributor, Author)

Hi @banach-space, could you please review this PR?

Contributor

To me, "folding" would mean that tensor.expand_shape disappears (i.e. is folded away), but that's not what is happening here, is it? This is merely "bubbling up".

Please update the description accordingly and add some example IR before and after. As an example: https://github.com/banach-space/llvm-project/blob/7d35eb58959c0ab398a9739f38bfb9754c5ba5e5/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp#L305-L317

@banach-space banach-space left a comment

Thanks for updating the summary!

tensor.{expand_shape,collapse_shape} sitting between a producer and tensor.pad blocks pad–producer fusion and other canonicalizations.

OK, then there should be a test demonstrating that this new transformation is unblocking fusion. Otherwise, it feels a bit ad-hoc. Do you have an example that we could turn into a test?

Thanks!

@ita9naiwa (Contributor, Author)

@banach-space

// RUN: mlir-opt %s -tensor-reshape-propagation -linalg-elementwise-fusion -canonicalize | FileCheck %s

func @pad_fuse(%arg0: tensor<2x3xf32>) -> tensor<2x4xf32> {
  %generic = linalg.generic
      { indexing_maps = [affine_map<(d0, d1)->(d0, d1)>],
        iterator_types = ["parallel", "parallel"] }
      ins(%arg0 : tensor<2x3xf32>) outs(%arg0 : tensor<2x3xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<2x3xf32>

  %collapsed = tensor.collapse_shape %generic [[0, 1]]
      : tensor<2x3xf32> into tensor<6xf32>

  %cst = arith.constant 1.0 : f32
  %padded = tensor.pad %collapsed low[1] high[1] {
    ^bb0(%idx: index):
      tensor.yield %cst : f32
  } : tensor<6xf32> to tensor<8xf32>

  %expanded = tensor.expand_shape %padded [2, 4]
      : tensor<8xf32> into tensor<2x4xf32>

  return %expanded : tensor<2x4xf32>
}

// CHECK-LABEL: func @pad_fuse
//   Function result type must be tensor<2x4xf32> (reshape + pad fused)
// CHECK-SAME: (%{{.*}}: tensor<2x3xf32>) -> tensor<2x4xf32>

// No standalone reshapes or pad should survive
// CHECK-NOT: tensor.collapse_shape
// CHECK-NOT: tensor.expand_shape
// CHECK-NOT: tensor.pad

// A single linalg.generic should produce the final tensor<2x4xf32>
// CHECK: linalg.generic
// CHECK-SAME: outs(%{{.*}} : tensor<2x4xf32>)

This would be an example for this purpose, but I am not sure where to put it. Could you recommend an appropriate place to put it?

@banach-space (Contributor)

This would be an example for this purpose, but I am not sure where to put it.

The flags that you used in your example don't exist:

$ bin/mlir-opt -tensor-reshape-propagation -linalg-elementwise-fusion -canonicalize  eample.mlir
mlir-opt: Unknown command line argument '-tensor-reshape-propagation'.  Try: 'bin/mlir-opt --help'
mlir-opt: Did you mean '--sharding-propagation'?
mlir-opt: Unknown command line argument '-linalg-elementwise-fusion'.  Try: 'bin/mlir-opt --help'
mlir-opt: Did you mean '--linalg-fuse-elementwise-ops'?

Where did you take them from? Also, the MLIR file is "broken":

$ bin/mlir-opt bad.mlir
bad.mlir:1:1: error: custom op 'func' is unknown (tried 'builtin.func' as well)
func @pad_fuse(%arg0: tensor<2x3xf32>) -> tensor<2x4xf32> {
^

There are more issues than just this one.

Could you recommend an appropriate place to put it?

Please provide a working example first. Sharing broken examples is a bad use of reviewers' time.

@ita9naiwa (Contributor, Author)

I sincerely apologize — while cleaning up the MLIR code snippet, I relied on an AI assistant and inadvertently broke it.
I take full responsibility for not double-checking the changes. I’ll provide a clearly working example shortly.

@ita9naiwa (Contributor, Author)

This transformation is common in ML workloads such as CNNs, where input tensors are padded (e.g., to match convolution kernel sizes) and then packed into smaller tiles for efficient Tensor Core or SIMD execution.
Layout changes like tensor.expand_shape may appear between padding and packing, blocking optimizations such as linalg-fold-padding.

Bubbling tensor.pad directly before linalg.pack restores adjacency, enabling the padding to be folded into the packing step and unlocking better tiling and vectorization.

Example: Bubbling tensor.pad Before linalg.pack

Before Bubbling

func.func @fold_tensor_pad_with_expand_and_pack(%arg0: tensor<512x256x256xf32>)
    -> tensor<32x16x129x129x2x2xf32> {
  %c0 = arith.constant 0.0 : f32
  %producer = linalg.fill ins(%c0 : f32)
      outs(%arg0 : tensor<512x256x256xf32>)
      -> tensor<512x256x256xf32>

  %pad = tensor.pad %producer low[0, 1, 1] high[0, 1, 1] {
    ^bb0(%i: index, %j: index, %k: index):
      tensor.yield %c0 : f32
  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>

  %reshape = tensor.expand_shape %pad [[0, 1], [2], [3]]
      output_shape [32, 16, 258, 258]
      : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>

  // The pack is blocked: `pad` is not directly adjacent due to reshape
  %dest = tensor.empty() : tensor<32x16x129x129x2x2xf32>
  %packed = linalg.pack %reshape
      inner_dims_pos = [2, 3]
      inner_tiles = [2, 2]
      into %dest
      : tensor<32x16x258x258xf32> -> tensor<32x16x129x129x2x2xf32>

  return %packed : tensor<32x16x129x129x2x2xf32>
}

Here, tensor.pad is separated from linalg.pack by tensor.expand_shape,
so folding patterns such as linalg-fold-padding cannot match.

After Bubbling

command: mlir-opt input.mlir -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-expansion

module {
  func.func @fold_tensor_pad_with_expand_and_pack(%arg0: tensor<512x256x256xf32>)
      -> tensor<32x16x129x129x2x2xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %expanded = tensor.expand_shape %arg0 [[0, 1], [2], [3]]
        output_shape [32, 16, 256, 256]
        : tensor<512x256x256xf32> into tensor<32x16x256x256xf32>
    %0 = linalg.fill ins(%cst : f32)
        outs(%expanded : tensor<32x16x256x256xf32>)
        -> tensor<32x16x256x256xf32>
    %padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
      ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
        tensor.yield %cst : f32
    } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
    %1 = tensor.empty() : tensor<32x16x129x129x2x2xf32>
    %pack = linalg.pack %padded
        inner_dims_pos = [2, 3]
        inner_tiles = [2, 2]
        into %1
        : tensor<32x16x258x258xf32> -> tensor<32x16x129x129x2x2xf32>
    return %pack : tensor<32x16x129x129x2x2xf32>
  }
}

Now tensor.pad is immediately before linalg.pack.
This enables fusion patterns (linalg-fold-padding) and other tiling/vectorization passes.

Benefits of Bubbling in This Case

  1. Pad–Consumer Adjacency
  • Required for linalg-fold-padding and related optimizations.
  • Without bubbling, the reshape blocks the match.
  2. Improved Downstream Codegen
  • linalg.pack can incorporate padding directly into the blocked layout.
  • Avoids materializing a fully padded intermediate tensor.

Summary:

Bubbling tensor.expand_shape above tensor.pad restores pad–consumer adjacency,
enabling pad folding into linalg.pack and improving tiling and vectorization opportunities.

@ita9naiwa (Contributor, Author)

I really apologize for making a comment without testing it comprehensively. @banach-space
