From d4a83dd7c6e8e9b7f93b918404f8b3526b406cc3 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 20 Nov 2025 14:53:00 +0000 Subject: [PATCH 1/4] [mlir] Add missing pad reshape propagation patterns Signed-off-by: Max Dawkins --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 283 +++++++++++++++--- .../fuse-with-reshape-by-collapsing.mlir | 39 +++ mlir/test/Dialect/Linalg/reshape_fusion.mlir | 41 +++ 3 files changed, 314 insertions(+), 49 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 05fc7cbbb90af..8c5a0c1474408 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion ControlFusionFn controlFoldingReshapes; }; +/// Carries information about a padded dimension. +struct PadDimInfo { + // The resulting shape after padding each dimension. + SmallVector paddedShape; + + // Low and high padding amounts for each dimension. + SmallVector lowPad; + SmallVector highPad; +}; + +/// Computes the expanded padding information for the given pad operation based +/// on the provided expanded shape and reassociation indices. Returns a list of +/// PaddedDimInfo containing the low and high padding amounts and the padded +/// size for each dimension, or failure if the expansion is not possible. +static FailureOr +computeExpandedPadding(tensor::PadOp padOp, ArrayRef expandedShape, + ArrayRef reassociations, + PatternRewriter &rewriter) { + ArrayRef low = padOp.getStaticLow(); + ArrayRef high = padOp.getStaticHigh(); + + // Expanded dimensions cannot have padding because the resulting padding may + // not be representable by a tensor.pad op. There are some special cases where + // it is possible (like expanding unit dims), but supporting these cases is + // NYI, so disallow it for now. 
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { + if (reInd.size() != 1 && (l != 0 || h != 0)) + return failure(); + } + + SmallVector mixedLowPad(padOp.getMixedLowPad()); + SmallVector mixedHighPad(padOp.getMixedHighPad()); + ArrayRef paddedShape = padOp.getResultType().getShape(); + PadDimInfo padDimInfo; + padDimInfo.paddedShape.assign(expandedShape); + padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0)); + padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0)); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() == 1) { + padDimInfo.paddedShape[reInd[0]] = paddedShape[idx]; + padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx]; + padDimInfo.highPad[reInd[0]] = mixedHighPad[idx]; + } + } + + return padDimInfo; +} + class FoldPadWithProducerReshapeOpByExpansion : public OpRewritePattern { public: @@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion "fusion blocked by control function"); } - ArrayRef low = padOp.getStaticLow(); - ArrayRef high = padOp.getStaticHigh(); + RankedTensorType expandedType = reshapeOp.getSrcType(); SmallVector reassociations = reshapeOp.getReassociationIndices(); + FailureOr maybeExpandedPadding = computeExpandedPadding( + padOp, expandedType.getShape(), reassociations, rewriter); + if (failed(maybeExpandedPadding)) + return failure(); + PadDimInfo expandedPadding = maybeExpandedPadding.value(); - for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { - if (reInd.size() != 1 && (l != 0 || h != 0)) - return failure(); + Location loc = padOp->getLoc(); + RankedTensorType expandedPaddedType = + padOp.getResultType().clone(expandedPadding.paddedShape); + + auto newPadOp = tensor::PadOp::create( + rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), + expandedPadding.lowPad, expandedPadding.highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOpWithNewOp( + padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); + + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + +class FoldExpandShapeWithProducerPadOp + : public OpRewritePattern { +public: + FoldExpandShapeWithProducerPadOp(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = expandOp.getSrc().getDefiningOp(); + if (!padOp) + return failure(); + if (!padOp->hasOneUse()) + return failure(); + + if (!controlFoldingReshapes(&expandOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(expandOp, + "fusion blocked by control function"); } - SmallVector newLow, newHigh; - RankedTensorType expandedType = reshapeOp.getSrcType(); - RankedTensorType paddedType = padOp.getResultType(); - SmallVector expandedPaddedShape(expandedType.getShape()); + RankedTensorType expandedType = expandOp.getResultType(); + SmallVector reassociations = + expandOp.getReassociationIndices(); + FailureOr maybeExpandedPadding = computeExpandedPadding( + padOp, expandedType.getShape(), reassociations, rewriter); + if (failed(maybeExpandedPadding)) + return failure(); + PadDimInfo expandedPadding = maybeExpandedPadding.value(); + + Location loc = expandOp->getLoc(); + SmallVector newExpandedSizes = expandOp.getMixedOutputShape(); + SmallVector newExpandedShape(expandedType.getShape()); + 
rewriter.setInsertionPointAfterValue(padOp.getSource()); + SmallVector padSrcSizes = + tensor::getMixedSizes(rewriter, loc, padOp.getSource()); for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + // We know that any reassociation with multiple dims is not padded because + // of the requirements of computeExpandedPadding. if (reInd.size() == 1) { - expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx]; - } - for (size_t i = 0; i < reInd.size(); ++i) { - newLow.push_back(padOp.getMixedLowPad()[idx]); - newHigh.push_back(padOp.getMixedHighPad()[idx]); + newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx); + newExpandedSizes[reInd[0]] = padSrcSizes[idx]; } } - - Location loc = padOp->getLoc(); - RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape); + RankedTensorType newExpandedType = expandedType.clone(newExpandedShape); + auto newExpandOp = tensor::ExpandShapeOp::create( + rewriter, loc, newExpandedType, padOp.getSource(), reassociations, + newExpandedSizes); + RankedTensorType expandedPaddedType = + padOp.getResultType().clone(expandedPadding.paddedShape); + rewriter.setInsertionPoint(expandOp); auto newPadOp = tensor::PadOp::create( - rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + rewriter, loc, expandedPaddedType, newExpandOp.getResult(), + expandedPadding.lowPad, expandedPadding.highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); - rewriter.replaceOpWithNewOp( - padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); + rewriter.replaceOp(expandOp, newPadOp.getResult()); return success(); } @@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing ControlFusionFn controlFoldingReshapes; }; +/// Computes the collapsed padding information for the given pad operation based +/// on the provided collapsed shape and reassociation indices. Returns a +/// PadDimInfo containing the low and high padding amounts and the collapsed +/// shape for each dimension, or failure if the collapse is not possible. +static FailureOr +computeCollapsedPadding(tensor::PadOp padOp, + ArrayRef reassociations, + PatternRewriter &rewriter) { + ArrayRef low = padOp.getStaticLow(); + ArrayRef high = padOp.getStaticHigh(); + + // Collapsed dimensions cannot have padding because this can produce strided + // padding that isn't representable by a tensor.pad op. There are some special + // cases where it it possible (like collapsing unit dims), but supporting + // these cases is NYI, so disallow it for now. + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + for (int64_t dim : reInd) { + if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1) + return failure(); + } + } + + // Initialize padding values for collapsed tensors with zeros + ArrayRef expandedPaddedShape = padOp.getType().getShape(); + PadDimInfo padDimInfo; + padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0)); + padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0)); + + // Update padding for dimensions that are not being collapsed, and compute + // the collapsed padded shape. 
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() == 1) { + padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]]; + padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]]; + } + SaturatedInteger collapsedSize = SaturatedInteger::wrap(1); + for (int64_t dim : reInd) { + collapsedSize = + collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]); + } + padDimInfo.paddedShape.push_back(collapsedSize.asInteger()); + } + + return padDimInfo; +} + class FoldPadWithProducerReshapeOpByCollapsing : public OpRewritePattern { public: @@ -1944,49 +2092,34 @@ class FoldPadWithProducerReshapeOpByCollapsing "fusion blocked by control function"); } - ArrayRef low = padOp.getStaticLow(); - ArrayRef high = padOp.getStaticHigh(); SmallVector reassociations = reshapeOp.getReassociationIndices(); + FailureOr maybeCollapsedPadding = + computeCollapsedPadding(padOp, reassociations, rewriter); + if (failed(maybeCollapsedPadding)) + return failure(); + PadDimInfo collapsedPadding = maybeCollapsedPadding.value(); - for (auto reInd : reassociations) { - if (reInd.size() == 1) - continue; - if (llvm::any_of(reInd, [&](int64_t ind) { - return low[ind] != 0 || high[ind] != 0; - })) { - return failure(); - } - } - - SmallVector newLow, newHigh; - RankedTensorType collapsedType = reshapeOp.getSrcType(); - RankedTensorType paddedType = padOp.getResultType(); - SmallVector collapsedPaddedShape(collapsedType.getShape()); - SmallVector expandedPaddedSizes( - getMixedValues(reshapeOp.getStaticOutputShape(), - reshapeOp.getOutputShape(), rewriter)); + SmallVector expandedPaddedSizes = + reshapeOp.getMixedOutputShape(); AffineExpr d0, d1, d2; bindDims(rewriter.getContext(), d0, d1, d2); auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2}); Location loc = reshapeOp->getLoc(); - for (auto [idx, reInd] : llvm::enumerate(reassociations)) { - OpFoldResult l = padOp.getMixedLowPad()[reInd[0]]; - OpFoldResult h = padOp.getMixedHighPad()[reInd[0]]; + for (auto [reInd, l, h] : + llvm::zip_equal(reassociations, collapsedPadding.lowPad, + collapsedPadding.highPad)) { if (reInd.size() == 1) { - collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]]; - OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply( + expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply( rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]}); - expandedPaddedSizes[reInd[0]] = paddedSize; } - newLow.push_back(l); - newHigh.push_back(h); } RankedTensorType collapsedPaddedType = - paddedType.clone(collapsedPaddedShape); + padOp.getType().clone(collapsedPadding.paddedShape); auto newPadOp = tensor::PadOp::create( - rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), + collapsedPadding.lowPad, collapsedPadding.highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); rewriter.replaceOpWithNewOp( @@ -2000,6 +2133,54 @@ class FoldPadWithProducerReshapeOpByCollapsing ControlFusionFn controlFoldingReshapes; }; +class FoldReshapeWithProducerPadOpByCollapsing + : public OpRewritePattern { +public: + FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp(); + if (!padOp) + return 
failure(); + if (!padOp->hasOneUse()) + return failure(); + + if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(padOp, + "fusion blocked by control function"); + } + + SmallVector reassociations = + reshapeOp.getReassociationIndices(); + RankedTensorType collapsedPaddedType = reshapeOp.getResultType(); + FailureOr maybeCollapsedPadding = + computeCollapsedPadding(padOp, reassociations, rewriter); + if (failed(maybeCollapsedPadding)) + return failure(); + PadDimInfo collapsedPadding = maybeCollapsedPadding.value(); + + Location loc = reshapeOp->getLoc(); + auto newCollapseOp = tensor::CollapseShapeOp::create( + rewriter, loc, padOp.getSource(), reassociations); + + auto newPadOp = tensor::PadOp::create( + rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(), + collapsedPadding.lowPad, collapsedPadding.highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOp(reshapeOp, newPadOp.getResult()); + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + /// Pattern to collapse dimensions. template class CollapseLinalgDimensions : public OpRewritePattern { @@ -2239,6 +2420,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( controlFoldingReshapes); patterns.add(patterns.getContext(), controlFoldingReshapes); + patterns.add(patterns.getContext(), + controlFoldingReshapes); patterns.add(patterns.getContext(), controlFoldingReshapes); } @@ -2250,6 +2433,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( controlFoldingReshapes); patterns.add( patterns.getContext(), controlFoldingReshapes); + patterns.add( + patterns.getContext(), controlFoldingReshapes); patterns.add(patterns.getContext(), controlFoldingReshapes); } diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir index 2bf3d21c35526..923bb2ca9c28a 100644 --- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -639,6 +639,45 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor, // CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor into tensor // CHECK: return %[[EXPAND]] +// ----- + +func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> { + %cst = arith.constant 0 : i32 + %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, + %arg5: index, %arg6: index, %arg7: index, %arg8: index): + tensor.yield %cst : i32 + } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32> + %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] + : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32> + return %collapsed : tensor<8x12x17x336x14xi32> +} +// CHECK: func @collapse_shape_with_producer_pad +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32> +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] +// CHECK: return %[[PAD]] + +// ----- + +func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor, + %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor { + %cst = arith.constant 0.0 : f32 + %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, 
%h1, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): + tensor.yield %cst : f32 + } : tensor to tensor + %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]] + : tensor into tensor + return %collapsed : tensor +} +// CHECK: func @collapse_shape_with_producer_pad_dynamic +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0] +// CHECK: return %[[PAD]] + // ----- // Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes. #map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index 67b4f2b32bad5..f6572674d10e2 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -863,6 +863,47 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor, %l0: i // ----- +func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> { + %cst = arith.constant 0 : i32 + %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index): + tensor.yield %cst : i32 + } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32> + %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14] + : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32> + return %expanded : tensor<8x3x4x17x6x7x8x14xi32> +} +// CHECK: func @expand_shape_with_producer_pad +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32> +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] +// CHECK: return %[[PAD]] + +// ----- + +func.func @expand_shape_with_producer_pad_dynamic(%arg0: tensor, + %s0: index, %s1: index, %s2: index, %s3: index, %s4: index, %s5: index, + %l0: index, %l1: index, %h0: index, %h1: index) -> tensor { + %cst = arith.constant 0.0 : f32 + %padded = tensor.pad %arg0 low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor to tensor + %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] + : tensor into tensor + return %expanded : tensor +} +// CHECK: func @expand_shape_with_producer_pad_dynamic +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0:.+]] : tensor +// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2:.+]] : tensor +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]] output_shape [%[[DIM0]], %[[S1]], %[[S2]], %[[DIM2]], %[[S4]], %[[S5]]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0] +// CHECK: 
return %[[PAD]] + +// ----- + func.func @move_operand_deps(%arg0 : tensor, %arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> { %c0 = arith.constant 0 : index From edf7e3b6d4ef0412f3b1bc83fe01fe212f6ca8e7 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 20 Nov 2025 21:51:08 +0000 Subject: [PATCH 2/4] address comments Signed-off-by: Max Dawkins --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 56 ++++++++++--------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 8c5a0c1474408..1e110d1c6b113 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1050,19 +1050,24 @@ struct PadDimInfo { /// Computes the expanded padding information for the given pad operation based /// on the provided expanded shape and reassociation indices. Returns a list of -/// PaddedDimInfo containing the low and high padding amounts and the padded +/// PadDimInfo containing the low and high padding amounts and the padded /// size for each dimension, or failure if the expansion is not possible. static FailureOr computeExpandedPadding(tensor::PadOp padOp, ArrayRef expandedShape, ArrayRef reassociations, PatternRewriter &rewriter) { - ArrayRef low = padOp.getStaticLow(); - ArrayRef high = padOp.getStaticHigh(); + // If the padding value depends on the index values of the pad operation, + // then it may not be valid to expand the dimensions, since it will change + // the index values on which the padding value depends. + if (!padOp.getConstantPaddingValue()) + return failure(); // Expanded dimensions cannot have padding because the resulting padding may // not be representable by a tensor.pad op. There are some special cases where // it is possible (like expanding unit dims), but supporting these cases is // NYI, so disallow it for now. 
+ ArrayRef low = padOp.getStaticLow(); + ArrayRef high = padOp.getStaticHigh(); for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { if (reInd.size() != 1 && (l != 0 || h != 0)) return failure(); @@ -1101,8 +1106,6 @@ class FoldPadWithProducerReshapeOpByExpansion padOp.getSource().getDefiningOp(); if (!reshapeOp) return failure(); - if (!reshapeOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&padOp.getSourceMutable())) { return rewriter.notifyMatchFailure(padOp, @@ -1116,7 +1119,7 @@ class FoldPadWithProducerReshapeOpByExpansion padOp, expandedType.getShape(), reassociations, rewriter); if (failed(maybeExpandedPadding)) return failure(); - PadDimInfo expandedPadding = maybeExpandedPadding.value(); + PadDimInfo &expandedPadding = maybeExpandedPadding.value(); Location loc = padOp->getLoc(); RankedTensorType expandedPaddedType = @@ -1137,12 +1140,12 @@ class FoldPadWithProducerReshapeOpByExpansion ControlFusionFn controlFoldingReshapes; }; -class FoldExpandShapeWithProducerPadOp +class FoldReshapeWithProducerPadOpByExpansion : public OpRewritePattern { public: - FoldExpandShapeWithProducerPadOp(MLIRContext *context, - ControlFusionFn foldReshapes, - PatternBenefit benefit = 1) + FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), controlFoldingReshapes(std::move(foldReshapes)) {} @@ -1151,8 +1154,6 @@ class FoldExpandShapeWithProducerPadOp tensor::PadOp padOp = expandOp.getSrc().getDefiningOp(); if (!padOp) return failure(); - if (!padOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&expandOp.getSrcMutable())) { return rewriter.notifyMatchFailure(expandOp, @@ -1166,7 +1167,7 @@ class FoldExpandShapeWithProducerPadOp padOp, expandedType.getShape(), reassociations, rewriter); if (failed(maybeExpandedPadding)) return failure(); - PadDimInfo expandedPadding = maybeExpandedPadding.value(); + PadDimInfo &expandedPadding = maybeExpandedPadding.value(); Location loc = expandOp->getLoc(); SmallVector newExpandedSizes = expandOp.getMixedOutputShape(); @@ -2031,13 +2032,18 @@ static FailureOr computeCollapsedPadding(tensor::PadOp padOp, ArrayRef reassociations, PatternRewriter &rewriter) { - ArrayRef low = padOp.getStaticLow(); - ArrayRef high = padOp.getStaticHigh(); + // If the padding value depends on the index values of the pad operation, + // then it may not be valid to collapse the dimensions, since it will change + // the index values on which the padding value depends. + if (!padOp.getConstantPaddingValue()) + return failure(); // Collapsed dimensions cannot have padding because this can produce strided // padding that isn't representable by a tensor.pad op. There are some special - // cases where it it possible (like collapsing unit dims), but supporting + // cases where it is possible (like collapsing unit dims), but supporting // these cases is NYI, so disallow it for now. + ArrayRef low = padOp.getStaticLow(); + ArrayRef high = padOp.getStaticHigh(); for (auto [idx, reInd] : llvm::enumerate(reassociations)) { for (int64_t dim : reInd) { if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1) @@ -2053,10 +2059,12 @@ computeCollapsedPadding(tensor::PadOp padOp, // Update padding for dimensions that are not being collapsed, and compute // the collapsed padded shape. 
+ SmallVector mixedLowPad(padOp.getMixedLowPad()); + SmallVector mixedHighPad(padOp.getMixedHighPad()); for (auto [idx, reInd] : llvm::enumerate(reassociations)) { if (reInd.size() == 1) { - padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]]; - padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]]; + padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]]; + padDimInfo.highPad[idx] = mixedHighPad[reInd[0]]; } SaturatedInteger collapsedSize = SaturatedInteger::wrap(1); for (int64_t dim : reInd) { @@ -2084,8 +2092,6 @@ class FoldPadWithProducerReshapeOpByCollapsing padOp.getSource().getDefiningOp(); if (!reshapeOp) return failure(); - if (!reshapeOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&padOp.getSourceMutable())) { return rewriter.notifyMatchFailure(padOp, @@ -2098,7 +2104,7 @@ class FoldPadWithProducerReshapeOpByCollapsing computeCollapsedPadding(padOp, reassociations, rewriter); if (failed(maybeCollapsedPadding)) return failure(); - PadDimInfo collapsedPadding = maybeCollapsedPadding.value(); + PadDimInfo &collapsedPadding = maybeCollapsedPadding.value(); SmallVector expandedPaddedSizes = reshapeOp.getMixedOutputShape(); @@ -2147,8 +2153,6 @@ class FoldReshapeWithProducerPadOpByCollapsing tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp(); if (!padOp) return failure(); - if (!padOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) { return rewriter.notifyMatchFailure(padOp, @@ -2162,7 +2166,7 @@ class FoldReshapeWithProducerPadOpByCollapsing computeCollapsedPadding(padOp, reassociations, rewriter); if (failed(maybeCollapsedPadding)) return failure(); - PadDimInfo collapsedPadding = maybeCollapsedPadding.value(); + PadDimInfo &collapsedPadding = maybeCollapsedPadding.value(); Location loc = reshapeOp->getLoc(); auto newCollapseOp = tensor::CollapseShapeOp::create( @@ -2420,8 +2424,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( controlFoldingReshapes); patterns.add(patterns.getContext(), controlFoldingReshapes); - patterns.add(patterns.getContext(), - controlFoldingReshapes); + patterns.add(patterns.getContext(), + controlFoldingReshapes); patterns.add(patterns.getContext(), controlFoldingReshapes); } From 9f3df77784b2c58f8a88ae274b490163e0b0d85c Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 20 Nov 2025 22:23:53 +0000 Subject: [PATCH 3/4] add tests for non const pad val Signed-off-by: Max Dawkins --- .../fuse-with-reshape-by-collapsing.mlir | 36 +++++++++++++++++++ mlir/test/Dialect/Linalg/reshape_fusion.mlir | 34 ++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir index 923bb2ca9c28a..77c7d7d69a77d 100644 --- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -594,6 +594,24 @@ func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x // ----- +func.func @no_fuse_by_collapsing_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> { + %expand = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xi32> into tensor<2x3x4xi32> + %cst = arith.constant 0 : i32 + %padded_0 = tensor.pad %expand low[1, 0, 0] high[5, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x3x4xi32> to tensor<8x3x4xi32> + return %padded_0 
: tensor<8x3x4xi32> +} +// CHECK: func @no_fuse_by_collapsing_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>) +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] +// CHECK: return %[[PAD]] + +// ----- + func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> { %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> %cst = arith.constant 0 : i32 @@ -678,6 +696,24 @@ func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor) -> tensor<8x12xi32> { + %cst = arith.constant 0 : i32 + %padded_0 = tensor.pad %arg0 low[1, 0, 0] high[5, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x3x4xi32> to tensor<8x3x4xi32> + %collapsed = tensor.collapse_shape %padded_0 [[0], [1, 2]] : tensor<8x3x4xi32> into tensor<8x12xi32> + return %collapsed : tensor<8x12xi32> +} +// CHECK: func @collapse_shape_with_producer_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>) +// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PAD]] +// CHECK: return %[[COLLAPSED]] + // ----- // Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes. #map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index f6572674d10e2..3fb7225069983 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -822,6 +822,23 @@ func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor< // ----- +func.func @no_fuse_by_expanding_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> { + %collapse = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<2x3x4xi32> into tensor<2x12xi32> + %padded_0 = tensor.pad %collapse low[1, 0] high[5, 0] { + ^bb0(%arg1: index, %arg2: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x12xi32> to tensor<8x12xi32> + return %padded_0 : tensor<8x12xi32> +} +// CHECK: func @no_fuse_by_expanding_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>) +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] +// CHECK: return %[[PAD]] + +// ----- + func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> { %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32> %cst = arith.constant 0 : i32 @@ -904,6 +921,23 @@ func.func @expand_shape_with_producer_pad_dynamic(%arg0: tensor, // ----- +func.func @expand_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> { + %padded_0 = tensor.pad %arg0 low[1, 0] high[5, 0] { + ^bb0(%arg1: index, %arg2: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x12xi32> to tensor<8x12xi32> + %expand = tensor.expand_shape %padded_0 [[0], [1, 2]] output_shape [8, 3, 4] : tensor<8x12xi32> into tensor<8x3x4xi32> + return %expand : tensor<8x3x4xi32> +} +// CHECK: func 
@expand_shape_with_producer_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>) +// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] +// CHECK: return %[[EXPAND]] + +// ----- + func.func @move_operand_deps(%arg0 : tensor, %arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> { %c0 = arith.constant 0 : index From 3594ff06e7f1ad4c7ef1f5b0afbf49bc773523de Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Mon, 1 Dec 2025 17:11:26 +0000 Subject: [PATCH 4/4] clarify comments Signed-off-by: Max Dawkins --- .../Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 1e110d1c6b113..421ab5e3760a7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1058,7 +1058,10 @@ computeExpandedPadding(tensor::PadOp padOp, ArrayRef expandedShape, PatternRewriter &rewriter) { // If the padding value depends on the index values of the pad operation, // then it may not be valid to expand the dimensions, since it will change - // the index values on which the padding value depends. + // the index values on which the padding value depends. This is not currently + // supported by the pad expansion patterns, but it could be implemented + // similarly to the expansion of linalg.generic ops with linalg.index ops in + // the body, as is done in `updateExpandedGenericOpRegion`. if (!padOp.getConstantPaddingValue()) return failure(); @@ -2034,7 +2037,10 @@ computeCollapsedPadding(tensor::PadOp padOp, PatternRewriter &rewriter) { // If the padding value depends on the index values of the pad operation, // then it may not be valid to collapse the dimensions, since it will change - // the index values on which the padding value depends. + // the index values on which the padding value depends. This is not currently + // supported by the pad collapsing patterns, but it could be implemented + // similarly to the collapsing of linalg.generic ops with linalg.index ops in + // the body, as is done in `generateCollapsedIndexingRegion`. if (!padOp.getConstantPaddingValue()) return failure();
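
Note on usage (illustrative sketch, not part of the patch series): the new
patterns are registered through the existing entry points
populateFoldReshapeOpsByExpansionPatterns and
populateFoldReshapeOpsByCollapsingPatterns, so downstream passes pick them up
automatically once rebuilt against this change. The fragment below is a
minimal sketch of how a downstream pass might enable the expansion-direction
patterns; the pass boilerplate, the permissive control function, and the call
to applyPatternsGreedily are assumptions about the surrounding code, not
something this patch adds. The collapsing-direction patterns would normally be
populated in a separate pattern set, since mixing both directions in one
greedy rewrite can undo each other's work.

  // Hypothetical runOnOperation() body of a downstream pass (assumed headers:
  // mlir/Dialect/Linalg/Transforms/Transforms.h and
  // mlir/Transforms/GreedyPatternRewriteDriver.h).
  void runOnOperation() {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // The control function receives the candidate fused OpOperand; returning
    // true here allows every pad/reshape propagation candidate.
    linalg::ControlFusionFn allowAll = [](OpOperand *) { return true; };
    linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, allowAll);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }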