diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 9c0f6e5d6469e..3f58bfea23d41 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1100,6 +1100,193 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to move a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+///
+/// ```
+/// BEFORE:
+/// %padded = tensor.pad %input low[0, 1, 1] high[0, 1, 1]
+///     tensor<512x256x256xf32> to tensor<512x258x258xf32>
+/// %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]]
+///     tensor<512x258x258xf32> to tensor<32x16x258x258xf32>
+///
+/// AFTER:
+/// %expanded = tensor.expand_shape %input [[0, 1], [2], [3]]
+///     tensor<512x256x256xf32> to tensor<32x16x256x256xf32>
+/// %padded = tensor.pad %expanded low[0, 0, 1, 1] high[0, 0, 1, 1]
+///     tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+/// ```
+struct MoveReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  MoveReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          expandOp, "cannot fold with non-constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    // Padding on a dimension that is expanded into several dimensions cannot
+    // be redistributed, so every multi-dimension group must be pad-free.
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() > 1 && (!isConstantIntValue(low[idx], 0) ||
+                               !isConstantIntValue(high[idx], 0)))
+        return rewriter.notifyMatchFailure(
+            expandOp, "fusion blocked by non-zero padding");
+
+      newLow.append(reInd.size(), low[idx]);
+      newHigh.append(reInd.size(), high[idx]);
+    }
+
+    // Compute the output shape of the new expand_shape: padded (and therefore
+    // unexpanded) dimensions take the size of the pad's source instead.
+    Location loc = expandOp.getLoc();
+    SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[inDimIdx];
+      OpFoldResult h = high[inDimIdx];
+
+      if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
+        assert(reInd.size() == 1 && "expected single dimension");
+        expandedShape[reInd[0]] =
+            tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);
+      }
+    }
+
+    SmallVector<int64_t> staticExpandedShape;
+    std::tie(staticExpandedShape, std::ignore) =
+        decomposeMixedValues(expandedShape);
+
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations, expandedShape);
+
+    rewriter.replaceOpWithNewOp<tensor::PadOp>(
+        expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
+/// Pattern to move a tensor.collapse_shape op with its producer tensor.pad op
+/// by bubbling the collapse_shape before the pad. Non-zero padding is only
+/// supported on dimensions that are not collapsed with any other dimension.
+///
+/// ```
+/// BEFORE:
+/// %padded = tensor.pad %input low[0, 0, 1] high[0, 0, 1]
+///     tensor<32x16x256xf32> to tensor<32x16x258xf32>
+/// %collapsed = tensor.collapse_shape %padded [[0, 1], [2]]
+///     tensor<32x16x258xf32> to tensor<512x258xf32>
+///
+/// AFTER:
+/// %collapsed = tensor.collapse_shape %input [[0, 1], [2]]
+///     tensor<32x16x256xf32> to tensor<512x256xf32>
+/// %padded = tensor.pad %collapsed low[0, 1] high[0, 1]
+///     tensor<512x256xf32> to tensor<512x258xf32>
+/// ```
+struct MoveReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+  MoveReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();
+
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(collapseOp,
+                                         "fusion blocked by control function");
+    }
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          collapseOp, "cannot fold with non-constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        collapseOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+    // Padding inside a collapsed group cannot be expressed as padding of the
+    // collapsed dimension, so such groups block the rewrite.
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() > 1 && llvm::any_of(reInd, [&](int64_t dimIdx) {
+            return !isConstantIntValue(low[dimIdx], 0) ||
+                   !isConstantIntValue(high[dimIdx], 0);
+          })) {
+        return failure();
+      }
+
+      newLow.push_back(low[reInd[0]]);
+      newHigh.push_back(high[reInd[0]]);
+    }
+
+    Location loc = collapseOp.getLoc();
+    auto resultType = collapseOp.getResultType();
+
+    // Compute the result shape of the new collapse_shape: padded (and
+    // therefore uncollapsed) dimensions take the size of the pad's source.
+    ArrayRef<int64_t> finalShape = collapseOp.getResultType().getShape();
+    SmallVector<int64_t> collapsedShape(finalShape.begin(), finalShape.end());
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[reInd[0]];
+      OpFoldResult h = high[reInd[0]];
+      if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
+        collapsedShape[inDimIdx] = padOp.getSourceType().getShape()[reInd[0]];
+      }
+    }
+
+    auto newCollapseType = RankedTensorType::get(
+        collapsedShape, padOp.getSourceType().getElementType());
+    auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+        loc, newCollapseType, padOp.getSource(), reassociations);
+
+    rewriter.replaceOpWithNewOp<tensor::PadOp>(
+        collapseOp, resultType, newCollapseOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2235,6 +2422,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<MoveReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2246,6 +2435,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
       controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<MoveReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
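The new patterns are not exposed on their own; they ride along with the two populate entry points extended above. A minimal usage sketch, not part of this patch (`bubbleReshapesThroughPads` is a hypothetical helper and assumes the greedy driver's `applyPatternsGreedily` entry point):

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Sketch: pull in the reshape-propagation patterns (including the new
// pad-bubbling ones) and apply them greedily below `root`.
static LogicalResult bubbleReshapesThroughPads(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Permissive control function; a real pass would inspect the fused operand.
  linalg::ControlFusionFn control = [](OpOperand *fusedOperand) {
    return true;
  };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, control);
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, control);
  return applyPatternsGreedily(root, std::move(patterns));
}
```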
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21c35526..0ac1686361bf7 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -232,7 +232,7 @@ func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor, %sz0: index, %sz1
   %1 = linalg.generic {
       indexing_maps = [#map0, #map0],
       iterator_types = ["parallel", "parallel"]}
-      ins(%0 : tensor)
+      ins(%0 : tensor)
       outs(%init : tensor) {
     ^bb0(%b0 : f32, %b1 : f32):
       %out = arith.negf %b0 : f32
@@ -858,3 +858,54 @@ func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1:
 // CHECK:       %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
 // CHECK-SAME:      tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
 // CHECK:       return %[[COLLAPSED]] : tensor<512x192x?xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
+  %padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+    tensor.yield %cst : f32
+  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+  %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
+      : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
+  return %collapsed : tensor<512x258x258xf32>
+}
+// CHECK:      func @fold_tensor_pad_with_collapse(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
+// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
+// CHECK:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK-SAME:       : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
+// CHECK:        %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
+// CHECK:        ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index):
+// CHECK:          tensor.yield %[[CST]] : f32
+// CHECK:        } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+// CHECK:        return %[[PADDED]] : tensor<512x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_collapse_dynamic_pad_zero(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
+  %padded = tensor.pad %0 low[%c0, %c0, %c1, %c1] high[%c0, %c0, %c1, %c1] {
+  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+    tensor.yield %cst : f32
+  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+  %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
+      : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
+  return %collapsed : tensor<512x258x258xf32>
+}
+// CHECK:      func @fold_tensor_pad_with_collapse_dynamic_pad_zero(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
+// CHECK:        %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
+// CHECK:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK-SAME:       : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
+// CHECK:        %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
+// CHECK:        ^bb0(
+// CHECK:          tensor.yield %[[CST]] : f32
+// CHECK:        return %[[PADDED]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..3ea0babfa3b9d 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor,
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor,
-                                                    %arg1 : tensor,
+                                                    %arg1 : tensor,
                                                     %sz0: index,
                                                     %sz1: index) ->
                                                     tensor {
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor) -> tensor {
 // -----
 
 func.func @reshape_as_consumer_permutation_with_multiple_results
-  (%a : tensor, %b : tensor, %sz0: index,
+  (%a : tensor, %b : tensor, %sz0: index,
    %sz1: index, %sz2: index, %sz3: index, %sz4: index)
     -> (tensor, tensor) {
   %c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor,
 // CHECK:      %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:     ins(%[[EXPANDED]] :
 // CHECK:      return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+  ^bb0(%i: index, %j: index, %k: index):
+    tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK:      func @fold_tensor_pad_with_expand(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK:        %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+// CHECK:        %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK:        ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:          tensor.yield %[[CST]] : f32
+// CHECK:        } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+// CHECK:        return %[[PADDED]] : tensor<32x16x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
+  ^bb0(%i: index, %j: index, %k: index):
+    tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK:      func @fold_tensor_pad_with_expand_dynamic_pad_zero(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK:        %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK:        %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
+// CHECK:        %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK:        ^bb0(
+// CHECK:          tensor.yield %[[CST]] : f32
+// CHECK:        return %[[PADDED]]
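A note on the control hook that both new patterns consult before rewriting: the callback receives the reshape's source operand, so the tensor.pad being bypassed is the defining op of that operand's value. An illustrative predicate, not from the patch, that only allows the bubbling when all pad amounts are static could look like:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Illustrative ControlFusionFn body: the operand passed in is the
// expand_shape/collapse_shape source, whose value is the tensor.pad result.
static bool allowOnlyStaticPadBubbling(OpOperand *fusedOperand) {
  auto padOp = fusedOperand->get().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return true; // Some other producer; defer to the remaining patterns.
  // tensor.pad keeps dynamic low/high amounts as SSA operands; require none.
  return padOp.getLow().empty() && padOp.getHigh().empty();
}
```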