From 0d8c636d4fd4d5c9636cfd3599c804e4a89e81e6 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Tue, 22 Apr 2025 19:04:02 +0900 Subject: [PATCH 01/10] Add FoldReshapeWithProducerPadOpByExpansion --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 142 ++++++++++++++++++ mlir/test/Dialect/Linalg/reshape_fusion.mlir | 51 ++++++- 2 files changed, 191 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index bf70597d5ddfe..dd4ac89e98090 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1101,6 +1101,146 @@ class FoldPadWithProducerReshapeOpByExpansion ControlFusionFn controlFoldingReshapes; }; +/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op +/// by bubbling the expand_shape before the pad. +struct FoldReshapeWithProducerPadOpByExpansion + : public OpRewritePattern { + + FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = expandOp.getSrc().getDefiningOp(); + if (!padOp) + return failure(); + + if (!padOp->hasOneUse()) + return failure(); + + if (!controlFoldingReshapes(&expandOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(expandOp, + "fusion blocked by control function"); + } + + SmallVector reassociations = + expandOp.getReassociationIndices(); + SmallVector low = padOp.getMixedLowPad(); + SmallVector high = padOp.getMixedHighPad(); + + auto isZeroPadding = [](OpFoldResult padValue) -> bool { + if (auto attr = dyn_cast(padValue)) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt() == 0; + } + + if (auto val = dyn_cast(padValue)) { + if (auto constOp = val.getDefiningOp()) { + if (auto attr = dyn_cast(constOp.getValue())) + return attr.getInt() == 0; + } + } + + // when padding is dynamic and not constant, we don't know if it's zero or + // not. so we return false here. 
+ return false; + }; + + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + OpFoldResult l = low[idx]; + OpFoldResult h = high[idx]; + if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h))) + return failure(); + } + + SmallVector newLow, newHigh; + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + for (size_t i = 0; i < reInd.size(); ++i) { + newLow.push_back(padOp.getMixedLowPad()[idx]); + newHigh.push_back(padOp.getMixedHighPad()[idx]); + } + } + + Location loc = expandOp.getLoc(); + auto finalType = cast(expandOp.getType()); + ArrayRef finalShape = finalType.getShape(); + + SmallVector expandedShape; + for (int64_t dimSize : finalShape) { + if (dimSize == ShapedType::kDynamic) { + expandedShape.push_back(OpFoldResult{}); + } else { + expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize)); + } + } + + for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) { + OpFoldResult l = low[inDimIdx]; + OpFoldResult h = high[inDimIdx]; + + if (!isZeroPadding(l) || !isZeroPadding(h)) { + auto srcType = cast(padOp.getSource().getType()); + int64_t originalSize = srcType.getDimSize(inDimIdx); + + OpFoldResult originalSizeOFR; + if (originalSize == ShapedType::kDynamic) { + Value orgSizeVal = + rewriter.create(loc, padOp.getSource(), inDimIdx); + originalSizeOFR = orgSizeVal; + } else { + originalSizeOFR = rewriter.getI64IntegerAttr(originalSize); + } + + for (auto outDimIdx : outGroup) { + expandedShape[outDimIdx] = originalSizeOFR; + } + } + } + + for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) { + if (dimSize == ShapedType::kDynamic && + !isa(expandedShape[outDimIdx]) && + !isa(expandedShape[outDimIdx])) { + Value actualSize = + rewriter.create(loc, expandOp.getSrc(), outDimIdx); + expandedShape[outDimIdx] = actualSize; + } + } + + SmallVector staticExpandedShape; + for (OpFoldResult dim : expandedShape) { + if (auto attr = dyn_cast(dim)) { + if (auto intAttr = dyn_cast(attr)) { + staticExpandedShape.push_back(intAttr.getInt()); + } else { + staticExpandedShape.push_back(ShapedType::kDynamic); + } + } else { + staticExpandedShape.push_back(ShapedType::kDynamic); + } + } + + auto newExpandOp = rewriter.create( + loc, + RankedTensorType::get(staticExpandedShape, + padOp.getSource().getType().getElementType()), + padOp.getSource(), reassociations); + + auto newPadOp = rewriter.create( + loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOp(expandOp, newPadOp.getResult()); + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + /// Pattern to fold a tensor.expand_shape op with its producer generic op /// by expanding the dimensionality of the loop in the producer op. 
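/// For example (an editor's sketch with illustrative shapes and maps; the
/// producer's operands are reshaped to the expanded rank before the
/// linalg.generic is recreated on the expanded type):
/// ```
/// BEFORE:
/// %0 = linalg.generic ... ins(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
/// %1 = tensor.expand_shape %0 [[0], [1, 2]] output_shape [%d0, %d1, 4]
///     : tensor<?x?xf32> into tensor<?x?x4xf32>
///
/// AFTER:
/// %e = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%d0, %d1, 4]
///     : tensor<?x?xf32> into tensor<?x?x4xf32>
/// %1 = linalg.generic ... ins(%e : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
/// ```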
struct FoldReshapeWithGenericOpByExpansion @@ -2249,6 +2389,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( controlFoldingReshapes); patterns.add(patterns.getContext(), controlFoldingReshapes); + patterns.add(patterns.getContext(), + controlFoldingReshapes); patterns.add(patterns.getContext(), controlFoldingReshapes); } diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index 67b4f2b32bad5..3ea0babfa3b9d 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor, #map0 = affine_map<(d0, d1) -> (d0, d1)> func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor, - %arg1 : tensor, + %arg1 : tensor, %sz0: index, %sz1: index) -> tensor { @@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor) -> tensor { // ----- func.func @reshape_as_consumer_permutation_with_multiple_results - (%a : tensor, %b : tensor, %sz0: index, + (%a : tensor, %b : tensor, %sz0: index, %sz1: index, %sz2: index, %sz3: index, %sz4: index) -> (tensor, tensor) { %c:2 = linalg.generic { @@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor, // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[EXPANDED]] : // CHECK: return %[[GENERIC]] + +// ----- + +func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32> + %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] { + ^bb0(%i: index, %j: index, %k: index): + tensor.yield %cst : f32 + } : tensor<512x256x256xf32> to tensor<512x258x258xf32> + %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32> + return %expanded : tensor<32x16x258x258xf32> +} +// CHECK: func @fold_tensor_pad_with_expand( +// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32> +// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>) +// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1] +// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index): +// CHECK: tensor.yield %[[CST]] : f32 +// CHECK: } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32> +// CHECK: return %[[PADDED]] : tensor<32x16x258x258xf32> + +// ----- + +func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32> + %padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] { + ^bb0(%i: index, %j: index, %k: index): + tensor.yield %cst : f32 + } : tensor<512x256x256xf32> to tensor<512x258x258xf32> + %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32> + return %expanded : tensor<32x16x258x258xf32> +} +// CHECK: func @fold_tensor_pad_with_expand_dynamic_pad_zero( +// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32> +// 
CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] +// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1] +// CHECK: ^bb0( +// CHECK: tensor.yield %[[CST]] : f32 +// CHECK: return %[[PADDED]] From 57ec65705339764b1a472f32b382c015909b25e8 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Mon, 14 Jul 2025 09:59:33 +0900 Subject: [PATCH 02/10] add collapse_shape --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 176 +++++++++++++++--- .../fuse-with-reshape-by-collapsing.mlir | 53 +++++- 2 files changed, 204 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 39eed6dd4cba4..e65228ae0e3eb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -26,6 +26,8 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/LogicalResult.h" #include #include @@ -1100,6 +1102,20 @@ class FoldPadWithProducerReshapeOpByExpansion ControlFusionFn controlFoldingReshapes; }; +bool isZero(OpFoldResult value) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt() == 0; + } + if (auto val = dyn_cast(value)) { + if (auto constOp = val.getDefiningOp()) { + if (auto attr = dyn_cast(constOp.getValue())) + return attr.getInt() == 0; + } + } + return false; +} + /// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op /// by bubbling the expand_shape before the pad. struct FoldReshapeWithProducerPadOpByExpansion @@ -1125,41 +1141,29 @@ struct FoldReshapeWithProducerPadOpByExpansion "fusion blocked by control function"); } + Value constantPaddingValue = padOp.getConstantPaddingValue(); + if (!constantPaddingValue) { + return rewriter.notifyMatchFailure( + expandOp, "cannot fold with non-constant padding value"); + } + SmallVector reassociations = expandOp.getReassociationIndices(); SmallVector low = padOp.getMixedLowPad(); SmallVector high = padOp.getMixedHighPad(); - auto isZeroPadding = [](OpFoldResult padValue) -> bool { - if (auto attr = dyn_cast(padValue)) { - if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt() == 0; - } - - if (auto val = dyn_cast(padValue)) { - if (auto constOp = val.getDefiningOp()) { - if (auto attr = dyn_cast(constOp.getValue())) - return attr.getInt() == 0; - } - } - - // when padding is dynamic and not constant, we don't know if it's zero or - // not. so we return false here. 
- return false; - }; - for (auto [idx, reInd] : llvm::enumerate(reassociations)) { OpFoldResult l = low[idx]; OpFoldResult h = high[idx]; - if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h))) + if (reInd.size() > 1 && (!isZero(l) || !isZero(h))) return failure(); } SmallVector newLow, newHigh; for (auto [idx, reInd] : llvm::enumerate(reassociations)) { for (size_t i = 0; i < reInd.size(); ++i) { - newLow.push_back(padOp.getMixedLowPad()[idx]); - newHigh.push_back(padOp.getMixedHighPad()[idx]); + newLow.push_back(low[idx]); + newHigh.push_back(high[idx]); } } @@ -1176,11 +1180,11 @@ struct FoldReshapeWithProducerPadOpByExpansion } } - for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) { + for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) { OpFoldResult l = low[inDimIdx]; OpFoldResult h = high[inDimIdx]; - if (!isZeroPadding(l) || !isZeroPadding(h)) { + if (!isZero(l) || !isZero(h)) { auto srcType = cast(padOp.getSource().getType()); int64_t originalSize = srcType.getDimSize(inDimIdx); @@ -1193,7 +1197,7 @@ struct FoldReshapeWithProducerPadOpByExpansion originalSizeOFR = rewriter.getI64IntegerAttr(originalSize); } - for (auto outDimIdx : outGroup) { + for (auto outDimIdx : reInd) { expandedShape[outDimIdx] = originalSizeOFR; } } @@ -1240,6 +1244,125 @@ struct FoldReshapeWithProducerPadOpByExpansion ControlFusionFn controlFoldingReshapes; }; +/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op +/// by bubbling the collapse_shape before the pad. +struct FoldReshapeWithProducerPadOpByCollapsing + : public OpRewritePattern { + + FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp(); + + if (!padOp) + return failure(); + + if (!padOp->hasOneUse()) + return failure(); + + if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(collapseOp, + "fusion blocked by control function"); + } + + Value constantPaddingValue = padOp.getConstantPaddingValue(); + if (!constantPaddingValue) { + return rewriter.notifyMatchFailure( + collapseOp, "cannot fold with non-constant padding value"); + } + + SmallVector reassociations = + collapseOp.getReassociationIndices(); + SmallVector low = padOp.getMixedLowPad(); + SmallVector high = padOp.getMixedHighPad(); + + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() > 1) { + for (auto dimIdx : reInd) { + if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) { + return failure(); + } + } + } + } + + SmallVector newLow, newHigh; + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + newLow.push_back(low[reInd[0]]); + newHigh.push_back(high[reInd[0]]); + } + + Location loc = collapseOp.getLoc(); + auto resultType = collapseOp.getResultType(); + + auto finalType = cast(collapseOp.getType()); + ArrayRef finalShape = finalType.getShape(); + + SmallVector collapsedShape; + for (int64_t dimSize : finalShape) { + if (dimSize == ShapedType::kDynamic) { + collapsedShape.push_back(OpFoldResult{}); + } else { + collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize)); + } + } + + for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) { + OpFoldResult l = low[reInd[0]]; + OpFoldResult h = 
high[reInd[0]]; + + if (!isZero(l) || !isZero(h)) { + auto srcType = cast(padOp.getSource().getType()); + int64_t originalSize = srcType.getDimSize(reInd[0]); + + OpFoldResult originalSizeOFR; + if (originalSize == ShapedType::kDynamic) { + Value orgSizeVal = + rewriter.create(loc, padOp.getSource(), reInd[0]); + originalSizeOFR = orgSizeVal; + } else { + originalSizeOFR = rewriter.getI64IntegerAttr(originalSize); + } + collapsedShape[inDimIdx] = originalSizeOFR; + } + } + + SmallVector staticCollapsedShape; + for (OpFoldResult dim : collapsedShape) { + if (auto attr = dyn_cast(dim)) { + if (auto intAttr = dyn_cast(attr)) { + staticCollapsedShape.push_back(intAttr.getInt()); + } else { + staticCollapsedShape.push_back(ShapedType::kDynamic); + } + } else { + staticCollapsedShape.push_back(ShapedType::kDynamic); + } + } + + auto newCollapseType = RankedTensorType::get( + staticCollapsedShape, padOp.getSource().getType().getElementType()); + auto newCollapseOp = rewriter.create( + loc, newCollapseType, padOp.getSource(), reassociations); + + auto newPadOp = rewriter.create( + loc, resultType, newCollapseOp.getResult(), newLow, newHigh, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOp(collapseOp, newPadOp.getResult()); + + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + /// Pattern to fold a tensor.expand_shape op with its producer generic op /// by expanding the dimensionality of the loop in the producer op. struct FoldReshapeWithGenericOpByExpansion @@ -2388,6 +2511,11 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( controlFoldingReshapes); patterns.add( patterns.getContext(), controlFoldingReshapes); + patterns.add( + patterns.getContext(), controlFoldingReshapes); + + patterns.add( + patterns.getContext(), controlFoldingReshapes); patterns.add(patterns.getContext(), controlFoldingReshapes); } diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir index 2bf3d21c35526..0ac1686361bf7 100644 --- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -232,7 +232,7 @@ func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor, %sz0: index, %sz1 %1 = linalg.generic { indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]} - ins(%0 : tensor) + ins(%0 : tensor) outs(%init : tensor) { ^bb0(%b0 : f32, %b1 : f32): %out = arith.negf %b0 : f32 @@ -858,3 +858,54 @@ func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1: // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]] // CHECK-SAME: tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32> // CHECK: return %[[COLLAPSED]] : tensor<512x192x?xf32> + +// ----- + +func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32> + %padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] { + ^bb0(%i0: index, %i1: index, %i2: index, %i3: index): + tensor.yield %cst : f32 + } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32> + %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]] + : tensor<32x16x258x258xf32> into tensor<512x258x258xf32> + return %collapsed : tensor<512x258x258xf32> +} +// CHECK: func @fold_tensor_pad_with_collapse( +// CHECK-SAME: %[[ARG0:[^:]+]]: 
tensor<32x16x256x256xf32> +// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>) +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}} +// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32> +// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1] +// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index): +// CHECK: tensor.yield %[[CST]] : f32 +// CHECK: } : tensor<512x256x256xf32> to tensor<512x258x258xf32> +// CHECK: return %[[PADDED]] : tensor<512x258x258xf32> + +// ----- + +func.func @fold_tensor_pad_with_collapse_dynamic_pad_zero(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32> + %padded = tensor.pad %0 low[%c0, %c0, %c1, %c1] high[%c0, %c0, %c1, %c1] { + ^bb0(%i0: index, %i1: index, %i2: index, %i3: index): + tensor.yield %cst : f32 + } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32> + %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]] + : tensor<32x16x258x258xf32> into tensor<512x258x258xf32> + return %collapsed : tensor<512x258x258xf32> +} +// CHECK: func @fold_tensor_pad_with_collapse_dynamic_pad_zero( +// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32> +// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>) +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}} +// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32> +// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1] +// CHECK: ^bb0( +// CHECK: tensor.yield %[[CST]] : f32 +// CHECK: return %[[PADDED]] From 737d4a4c776030cf9154aeb10039c870a6a63211 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Sat, 19 Jul 2025 07:37:30 +0900 Subject: [PATCH 03/10] fix upon review --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 75 ++++++------------- 1 file changed, 22 insertions(+), 53 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index e65228ae0e3eb..05dbb7cd7ba43 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1102,20 +1102,6 @@ class FoldPadWithProducerReshapeOpByExpansion ControlFusionFn controlFoldingReshapes; }; -bool isZero(OpFoldResult value) { - if (auto attr = dyn_cast(value)) { - if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt() == 0; - } - if (auto val = dyn_cast(value)) { - if (auto constOp = val.getDefiningOp()) { - if (auto attr = dyn_cast(constOp.getValue())) - return attr.getInt() == 0; - } - } - return false; -} - /// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op /// by bubbling the expand_shape before the pad. 
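///
/// Bubbling is only possible when every padded dimension survives the
/// reshape intact: if a dimension carrying non-zero padding were split
/// across several result dimensions, its pad elements would land in the
/// interior of the expanded tensor instead of on its border. Reassociation
/// groups of size greater than one must therefore have zero low/high
/// padding, which is what the checks below enforce.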
struct FoldReshapeWithProducerPadOpByExpansion @@ -1152,19 +1138,17 @@ struct FoldReshapeWithProducerPadOpByExpansion SmallVector low = padOp.getMixedLowPad(); SmallVector high = padOp.getMixedHighPad(); - for (auto [idx, reInd] : llvm::enumerate(reassociations)) { - OpFoldResult l = low[idx]; - OpFoldResult h = high[idx]; - if (reInd.size() > 1 && (!isZero(l) || !isZero(h))) - return failure(); + for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { + if (reInd.size() > 1 && + (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0))) + return rewriter.notifyMatchFailure( + expandOp, "fusion blocked by non-zero padding"); } SmallVector newLow, newHigh; for (auto [idx, reInd] : llvm::enumerate(reassociations)) { - for (size_t i = 0; i < reInd.size(); ++i) { - newLow.push_back(low[idx]); - newHigh.push_back(high[idx]); - } + newLow.append(reInd.size(), low[idx]); + newHigh.append(reInd.size(), high[idx]); } Location loc = expandOp.getLoc(); @@ -1184,7 +1168,7 @@ struct FoldReshapeWithProducerPadOpByExpansion OpFoldResult l = low[inDimIdx]; OpFoldResult h = high[inDimIdx]; - if (!isZero(l) || !isZero(h)) { + if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { auto srcType = cast(padOp.getSource().getType()); int64_t originalSize = srcType.getDimSize(inDimIdx); @@ -1196,10 +1180,8 @@ struct FoldReshapeWithProducerPadOpByExpansion } else { originalSizeOFR = rewriter.getI64IntegerAttr(originalSize); } - - for (auto outDimIdx : reInd) { - expandedShape[outDimIdx] = originalSizeOFR; - } + assert(reInd.size() == 1 && "expected single dimension"); + expandedShape[reInd[0]] = originalSizeOFR; } } @@ -1207,36 +1189,24 @@ struct FoldReshapeWithProducerPadOpByExpansion if (dimSize == ShapedType::kDynamic && !isa(expandedShape[outDimIdx]) && !isa(expandedShape[outDimIdx])) { - Value actualSize = - rewriter.create(loc, expandOp.getSrc(), outDimIdx); - expandedShape[outDimIdx] = actualSize; + expandedShape[outDimIdx] = + tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx); } } SmallVector staticExpandedShape; - for (OpFoldResult dim : expandedShape) { - if (auto attr = dyn_cast(dim)) { - if (auto intAttr = dyn_cast(attr)) { - staticExpandedShape.push_back(intAttr.getInt()); - } else { - staticExpandedShape.push_back(ShapedType::kDynamic); - } - } else { - staticExpandedShape.push_back(ShapedType::kDynamic); - } - } + std::tie(staticExpandedShape, std::ignore) = + decomposeMixedValues(expandedShape); auto newExpandOp = rewriter.create( loc, RankedTensorType::get(staticExpandedShape, padOp.getSource().getType().getElementType()), - padOp.getSource(), reassociations); + padOp.getSource(), reassociations, expandedShape); - auto newPadOp = rewriter.create( - loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh, + rewriter.replaceOpWithNewOp( + expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh, padOp.getConstantPaddingValue(), padOp.getNofold()); - - rewriter.replaceOp(expandOp, newPadOp.getResult()); return success(); } @@ -1284,7 +1254,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing for (auto [idx, reInd] : llvm::enumerate(reassociations)) { if (reInd.size() > 1) { for (auto dimIdx : reInd) { - if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) { + if (!isConstantIntValue(low[dimIdx], 0) || + !isConstantIntValue(high[dimIdx], 0)) { return failure(); } } @@ -1316,7 +1287,7 @@ struct FoldReshapeWithProducerPadOpByCollapsing OpFoldResult l = low[reInd[0]]; OpFoldResult h = high[reInd[0]]; - if (!isZero(l) || !isZero(h)) { + if 
(!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { auto srcType = cast(padOp.getSource().getType()); int64_t originalSize = srcType.getDimSize(reInd[0]); @@ -1350,12 +1321,10 @@ struct FoldReshapeWithProducerPadOpByCollapsing auto newCollapseOp = rewriter.create( loc, newCollapseType, padOp.getSource(), reassociations); - auto newPadOp = rewriter.create( - loc, resultType, newCollapseOp.getResult(), newLow, newHigh, + rewriter.replaceOpWithNewOp( + collapseOp, resultType, newCollapseOp.getResult(), newLow, newHigh, padOp.getConstantPaddingValue(), padOp.getNofold()); - rewriter.replaceOp(collapseOp, newPadOp.getResult()); - return success(); } From d8ca03657a6445a99cbe4558b7b8990a703c3c57 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Sat, 19 Jul 2025 07:44:43 +0900 Subject: [PATCH 04/10] fix upon review --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 05dbb7cd7ba43..6499a3387efca 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1304,17 +1304,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing } SmallVector staticCollapsedShape; - for (OpFoldResult dim : collapsedShape) { - if (auto attr = dyn_cast(dim)) { - if (auto intAttr = dyn_cast(attr)) { - staticCollapsedShape.push_back(intAttr.getInt()); - } else { - staticCollapsedShape.push_back(ShapedType::kDynamic); - } - } else { - staticCollapsedShape.push_back(ShapedType::kDynamic); - } - } + std::tie(staticCollapsedShape, std::ignore) = + decomposeMixedValues(collapsedShape); auto newCollapseType = RankedTensorType::get( staticCollapsedShape, padOp.getSource().getType().getElementType()); From 17a24473a1c881cd6ac16f8a3cf5d7665ee477fe Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Tue, 22 Jul 2025 15:48:27 +0900 Subject: [PATCH 05/10] fix upon review --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 31 +++---------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 6499a3387efca..1ec3bd2ac8f1d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -26,8 +26,6 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/LogicalResult.h" #include #include @@ -1169,19 +1167,10 @@ struct FoldReshapeWithProducerPadOpByExpansion OpFoldResult h = high[inDimIdx]; if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { - auto srcType = cast(padOp.getSource().getType()); - int64_t originalSize = srcType.getDimSize(inDimIdx); - - OpFoldResult originalSizeOFR; - if (originalSize == ShapedType::kDynamic) { - Value orgSizeVal = - rewriter.create(loc, padOp.getSource(), inDimIdx); - originalSizeOFR = orgSizeVal; - } else { - originalSizeOFR = rewriter.getI64IntegerAttr(originalSize); - } assert(reInd.size() == 1 && "expected single dimension"); - expandedShape[reInd[0]] = originalSizeOFR; + expandedShape[reInd[0]] = + tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx); + ; } } @@ -1288,18 +1277,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing OpFoldResult h = 
high[reInd[0]]; if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { - auto srcType = cast(padOp.getSource().getType()); - int64_t originalSize = srcType.getDimSize(reInd[0]); - - OpFoldResult originalSizeOFR; - if (originalSize == ShapedType::kDynamic) { - Value orgSizeVal = - rewriter.create(loc, padOp.getSource(), reInd[0]); - originalSizeOFR = orgSizeVal; - } else { - originalSizeOFR = rewriter.getI64IntegerAttr(originalSize); - } - collapsedShape[inDimIdx] = originalSizeOFR; + collapsedShape[inDimIdx] = + tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]); } } From 9ee8e08f9a8e48403df169499d2cb9b765156017 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Tue, 22 Jul 2025 16:54:23 +0900 Subject: [PATCH 06/10] fix upon review --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 1ec3bd2ac8f1d..e99de0e78eabe 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1150,17 +1150,8 @@ struct FoldReshapeWithProducerPadOpByExpansion } Location loc = expandOp.getLoc(); - auto finalType = cast(expandOp.getType()); - ArrayRef finalShape = finalType.getShape(); - - SmallVector expandedShape; - for (int64_t dimSize : finalShape) { - if (dimSize == ShapedType::kDynamic) { - expandedShape.push_back(OpFoldResult{}); - } else { - expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize)); - } - } + ArrayRef finalShape = expandOp.getResultType().getShape(); + SmallVector expandedShape = expandOp.getMixedOutputShape(); for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) { OpFoldResult l = low[inDimIdx]; @@ -1260,8 +1251,7 @@ struct FoldReshapeWithProducerPadOpByCollapsing Location loc = collapseOp.getLoc(); auto resultType = collapseOp.getResultType(); - auto finalType = cast(collapseOp.getType()); - ArrayRef finalShape = finalType.getShape(); + ArrayRef finalShape = collapseOp.getResultType().getShape(); SmallVector collapsedShape; for (int64_t dimSize : finalShape) { From 3b916457029e89f871394df0f7a25cdf0b674aff Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Sat, 26 Jul 2025 13:56:22 +0900 Subject: [PATCH 07/10] fix upon review --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index e99de0e78eabe..0687502cd1092 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1136,23 +1136,19 @@ struct FoldReshapeWithProducerPadOpByExpansion SmallVector low = padOp.getMixedLowPad(); SmallVector high = padOp.getMixedHighPad(); - for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { - if (reInd.size() > 1 && - (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0))) + SmallVector newLow, newHigh; + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() > 1 && (!isConstantIntValue(low[idx], 0) || + !isConstantIntValue(high[idx], 0))) return rewriter.notifyMatchFailure( expandOp, "fusion blocked by non-zero padding"); - } - SmallVector newLow, newHigh; - for (auto [idx, reInd] : llvm::enumerate(reassociations)) { newLow.append(reInd.size(), low[idx]); 
newHigh.append(reInd.size(), high[idx]); } Location loc = expandOp.getLoc(); - ArrayRef finalShape = expandOp.getResultType().getShape(); SmallVector expandedShape = expandOp.getMixedOutputShape(); - for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) { OpFoldResult l = low[inDimIdx]; OpFoldResult h = high[inDimIdx]; @@ -1165,15 +1161,6 @@ struct FoldReshapeWithProducerPadOpByExpansion } } - for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) { - if (dimSize == ShapedType::kDynamic && - !isa(expandedShape[outDimIdx]) && - !isa(expandedShape[outDimIdx])) { - expandedShape[outDimIdx] = - tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx); - } - } - SmallVector staticExpandedShape; std::tie(staticExpandedShape, std::ignore) = decomposeMixedValues(expandedShape); From 9c38ad58f93dbf18fc42a2284fd86cf24f34d2cb Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Sat, 26 Jul 2025 14:24:13 +0900 Subject: [PATCH 08/10] fix upon review --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 0687502cd1092..86e287bae6cf5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1239,32 +1239,21 @@ struct FoldReshapeWithProducerPadOpByCollapsing auto resultType = collapseOp.getResultType(); ArrayRef finalShape = collapseOp.getResultType().getShape(); - - SmallVector collapsedShape; - for (int64_t dimSize : finalShape) { - if (dimSize == ShapedType::kDynamic) { - collapsedShape.push_back(OpFoldResult{}); - } else { - collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize)); - } - } - + SmallVector collapsedShape(finalShape.begin(), finalShape.end()); for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) { OpFoldResult l = low[reInd[0]]; OpFoldResult h = high[reInd[0]]; - if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { - collapsedShape[inDimIdx] = + auto mixedSize = tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]); + auto dimSize = getConstantIntValue(mixedSize); + assert(dimSize.has_value() && "Expected static dimension"); + collapsedShape[inDimIdx] = *dimSize; } } - SmallVector staticCollapsedShape; - std::tie(staticCollapsedShape, std::ignore) = - decomposeMixedValues(collapsedShape); - auto newCollapseType = RankedTensorType::get( - staticCollapsedShape, padOp.getSource().getType().getElementType()); + collapsedShape, padOp.getSource().getType().getElementType()); auto newCollapseOp = rewriter.create( loc, newCollapseType, padOp.getSource(), reassociations); From 0faf0849388e7180e3606f9403a6b72cb07048e6 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Tue, 29 Jul 2025 18:46:34 +0900 Subject: [PATCH 09/10] fix upon review --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 86e287bae6cf5..a038d3c95c0d5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1217,20 +1217,15 @@ struct FoldReshapeWithProducerPadOpByCollapsing collapseOp.getReassociationIndices(); SmallVector low = padOp.getMixedLowPad(); SmallVector high = padOp.getMixedHighPad(); - + 
SmallVector newLow, newHigh; for (auto [idx, reInd] : llvm::enumerate(reassociations)) { - if (reInd.size() > 1) { - for (auto dimIdx : reInd) { - if (!isConstantIntValue(low[dimIdx], 0) || - !isConstantIntValue(high[dimIdx], 0)) { - return failure(); - } - } + if (reInd.size() > 1 && llvm::any_of(reInd, [&](int64_t dimIdx) { + return !isConstantIntValue(low[dimIdx], 0) || + !isConstantIntValue(high[dimIdx], 0); + })) { + return failure(); } - } - SmallVector newLow, newHigh; - for (auto [idx, reInd] : llvm::enumerate(reassociations)) { newLow.push_back(low[reInd[0]]); newHigh.push_back(high[reInd[0]]); } @@ -1244,16 +1239,12 @@ struct FoldReshapeWithProducerPadOpByCollapsing OpFoldResult l = low[reInd[0]]; OpFoldResult h = high[reInd[0]]; if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { - auto mixedSize = - tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]); - auto dimSize = getConstantIntValue(mixedSize); - assert(dimSize.has_value() && "Expected static dimension"); - collapsedShape[inDimIdx] = *dimSize; + collapsedShape[inDimIdx] = padOp.getSourceType().getShape()[reInd[0]]; } } auto newCollapseType = RankedTensorType::get( - collapsedShape, padOp.getSource().getType().getElementType()); + collapsedShape, padOp.getSourceType().getElementType()); auto newCollapseOp = rewriter.create( loc, newCollapseType, padOp.getSource(), reassociations); From 9cbd03262a27b6c75e6b229d7e5ca0e1774f3204 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Thu, 7 Aug 2025 07:20:46 +0900 Subject: [PATCH 10/10] fix comments --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 47 ++++++++++++++----- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index a038d3c95c0d5..3f58bfea23d41 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1100,12 +1100,26 @@ class FoldPadWithProducerReshapeOpByExpansion ControlFusionFn controlFoldingReshapes; }; -/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op +/// Pattern to move a tensor.expand_shape op with its producer tensor.pad op /// by bubbling the expand_shape before the pad. -struct FoldReshapeWithProducerPadOpByExpansion +/// +/// ``` +/// BEFORE: +/// %padded = tensor.pad %input low[0, 1, 1] high[0, 1, 1] +/// tensor<512x256x256xf32> to tensor<512x258x258xf32> +/// %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] +/// tensor<512x258x258xf32> to tensor<32x16x258x258xf32> +/// +/// AFTER: +/// %expanded = tensor.expand_shape %input [[0, 1], [2], [3]] +/// tensor<512x256x256xf32> to tensor<32x16x256x256xf32> +/// %padded = tensor.pad %expanded low[0, 0, 1, 1] high[0, 0, 1, 1] +/// tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32> +/// ``` +struct MoveReshapeWithProducerPadOpByExpansion : public OpRewritePattern { - FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context, + MoveReshapeWithProducerPadOpByExpansion(MLIRContext *context, ControlFusionFn foldReshapes, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), @@ -1181,12 +1195,26 @@ struct FoldReshapeWithProducerPadOpByExpansion ControlFusionFn controlFoldingReshapes; }; -/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op +/// Pattern to move a tensor.collapse_shape op with its producer tensor.pad op /// by bubbling the collapse_shape before the pad. 
-struct FoldReshapeWithProducerPadOpByCollapsing
+///
+/// ```
+/// BEFORE:
+/// %padded = tensor.pad %input low[0, 0, 1] high[0, 0, 1]
+///     tensor<32x16x256xf32> to tensor<32x16x258xf32>
+/// %collapsed = tensor.collapse_shape %padded [[0, 1], [2]]
+///     tensor<32x16x258xf32> to tensor<512x258xf32>
+///
+/// AFTER:
+/// %collapsed = tensor.collapse_shape %input [[0, 1], [2]]
+///     tensor<32x16x256xf32> to tensor<512x256xf32>
+/// %padded = tensor.pad %collapsed low[0, 1] high[0, 1]
+///     tensor<512x256xf32> to tensor<512x258xf32>
+/// ```
+struct MoveReshapeWithProducerPadOpByCollapsing
     : public OpRewritePattern {

-  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+  MoveReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
                                            ControlFusionFn foldReshapes,
                                            PatternBenefit benefit = 1)
       : OpRewritePattern(context, benefit),
         controlFoldingReshapes(std::move(foldReshapes)) {}

@@ -2394,7 +2422,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                        controlFoldingReshapes);
   patterns.add(patterns.getContext(),
                                                    controlFoldingReshapes);
-  patterns.add(patterns.getContext(),
+  patterns.add(patterns.getContext(),
                                                         controlFoldingReshapes);
   patterns.add(patterns.getContext(),
                controlFoldingReshapes);
@@ -2407,10 +2435,7 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                          controlFoldingReshapes);
   patterns.add(
       patterns.getContext(), controlFoldingReshapes);
-  patterns.add(
-      patterns.getContext(), controlFoldingReshapes);
-
-  patterns.add(
+  patterns.add(
       patterns.getContext(), controlFoldingReshapes);
   patterns.add(patterns.getContext(),
                controlFoldingReshapes);
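A note for readers trying the series out: the sketch below wires both
populate entry points into a greedy rewrite. It is illustrative and not part
of the patches; `runReshapePadFusion` and the always-true control callback
are made-up names, and the driver call is the standard MLIR API at the time
of writing.

  // Minimal driver sketch (assumes an MLIR checkout; not part of this series).
  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Registers both bubbling directions under a permissive control function.
  // ControlFusionFn receives the OpOperand being fused, so a real pass could
  // inspect the producer/consumer pair here instead of returning true.
  static LogicalResult runReshapePadFusion(Operation *root) {
    RewritePatternSet patterns(root->getContext());
    linalg::ControlFusionFn control = [](OpOperand *fusedOperand) {
      return true;
    };
    linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, control);
    linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, control);
    return applyPatternsGreedily(root, std::move(patterns));
  }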