175 changes: 175 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1100,6 +1100,174 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to bubble a tensor.expand_shape op up through its producer
/// tensor.pad op: the expand_shape is recreated on the un-padded source and
/// the padding is re-applied on the expanded dimensions.
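///
/// A sketch of the rewrite, mirroring the fold_tensor_pad_with_expand test
/// added in this patch (padding only on dimensions that are not split):
///
///   %padded = tensor.pad %src low[0, 1, 1] high[0, 1, 1] { ... }
///       : tensor<512x256x256xf32> to tensor<512x258x258xf32>
///   %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]]
///       output_shape [32, 16, 258, 258]
///       : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
///
/// is rewritten into:
///
///   %expanded = tensor.expand_shape %src [[0, 1], [2], [3]]
///       output_shape [32, 16, 256, 256]
///       : tensor<512x256x256xf32> into tensor<32x16x256x256xf32>
///   %padded = tensor.pad %expanded low[0, 0, 1, 1] high[0, 0, 1, 1] { ... }
///       : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>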
Contributor:
To me, "folding" would mean that tensor.expand_shape disappears (i.e. is folded away), but that's not what is happening here, is it? This is merely "bubbling up".

Please update the description accordingly and add some example IR before and after. As an example: https://github.com/banach-space/llvm-project/blob/7d35eb58959c0ab398a9739f38bfb9754c5ba5e5/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp#L305-L317

struct FoldReshapeWithProducerPadOpByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {

FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();

if (!padOp->hasOneUse())
return failure();

if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(expandOp,
"fusion blocked by control function");
}

Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue) {
return rewriter.notifyMatchFailure(
expandOp, "cannot fold with non-constant padding value");
}

SmallVector<ReassociationIndices> reassociations =
expandOp.getReassociationIndices();
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();

SmallVector<OpFoldResult> newLow, newHigh;
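    // Padding on a dimension that gets split by the reassociation cannot be
    // distributed over the expanded dimensions, so such dimensions must have
    // zero padding. The padding of every source dimension is then replicated
    // over each dimension it expands into (zeros for split dimensions).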
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() > 1 && (!isConstantIntValue(low[idx], 0) ||
!isConstantIntValue(high[idx], 0)))
return rewriter.notifyMatchFailure(
expandOp, "fusion blocked by non-zero padding");

newLow.append(reInd.size(), low[idx]);
newHigh.append(reInd.size(), high[idx]);
}

Location loc = expandOp.getLoc();
SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
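    // Padded dimensions are kept 1-to-1 (checked above), and the new
    // expand_shape is built on the un-padded source, so for those dimensions
    // the output size must be the source size rather than the padded size.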
for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[inDimIdx];
OpFoldResult h = high[inDimIdx];

if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
assert(reInd.size() == 1 && "expected single dimension");
        expandedShape[reInd[0]] =
            tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);
}
}

SmallVector<int64_t> staticExpandedShape;
std::tie(staticExpandedShape, std::ignore) =
decomposeMixedValues(expandedShape);

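    // Expand the un-padded source first, then re-apply the padding on the
    // expanded type.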
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
loc,
RankedTensorType::get(staticExpandedShape,
padOp.getSource().getType().getElementType()),
padOp.getSource(), reassociations, expandedShape);

rewriter.replaceOpWithNewOp<tensor::PadOp>(
expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());
return success();
}

private:
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to bubble a tensor.collapse_shape op up through its producer
/// tensor.pad op: the collapse_shape is recreated on the un-padded source and
/// the padding is re-applied on the collapsed dimensions.
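///
/// A sketch of the rewrite, mirroring the fold_tensor_pad_with_collapse test
/// added in this patch (padding only on dimensions that are collapsed
/// 1-to-1):
///
///   %padded = tensor.pad %src low[0, 0, 1, 1] high[0, 0, 1, 1] { ... }
///       : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
///   %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
///       : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
///
/// is rewritten into:
///
///   %collapsed = tensor.collapse_shape %src [[0, 1], [2], [3]]
///       : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
///   %padded = tensor.pad %collapsed low[0, 1, 1] high[0, 1, 1] { ... }
///       : tensor<512x256x256xf32> to tensor<512x258x258xf32>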
struct FoldReshapeWithProducerPadOpByCollapsing
: public OpRewritePattern<tensor::CollapseShapeOp> {

FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
PatternRewriter &rewriter) const override {
tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();

if (!padOp)
return failure();

if (!padOp->hasOneUse())
return failure();

if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(collapseOp,
"fusion blocked by control function");
}

Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue) {
return rewriter.notifyMatchFailure(
collapseOp, "cannot fold with non-constant padding value");
}

SmallVector<ReassociationIndices> reassociations =
collapseOp.getReassociationIndices();
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();

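    // Bail out if any dimension belonging to a collapsed group of size > 1
    // carries non-zero padding: such padding cannot be expressed on the
    // collapsed result.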
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() > 1) {
for (auto dimIdx : reInd) {
if (!isConstantIntValue(low[dimIdx], 0) ||
!isConstantIntValue(high[dimIdx], 0)) {
return failure();
}
}
}
}

SmallVector<OpFoldResult> newLow, newHigh;
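    // After the check above, groups of size > 1 have all-zero padding, so the
    // padding of each collapsed dimension can be taken from the first source
    // dimension in its group.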
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
newLow.push_back(low[reInd[0]]);
newHigh.push_back(high[reInd[0]]);
}

Contributor:
nit: combine this loop with the loop above.

Location loc = collapseOp.getLoc();
auto resultType = collapseOp.getResultType();

    ArrayRef<int64_t> finalShape = resultType.getShape();
SmallVector<int64_t> collapsedShape(finalShape.begin(), finalShape.end());
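    // Padded dimensions are collapsed 1-to-1, so the new collapse_shape
    // (built on the un-padded source) must use the un-padded source size,
    // which is required to be static here.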
for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[reInd[0]];
OpFoldResult h = high[reInd[0]];
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
auto mixedSize =
tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]);
auto dimSize = getConstantIntValue(mixedSize);
assert(dimSize.has_value() && "Expected static dimension");
collapsedShape[inDimIdx] = *dimSize;
}
}

auto newCollapseType = RankedTensorType::get(
collapsedShape, padOp.getSource().getType().getElementType());
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
loc, newCollapseType, padOp.getSource(), reassociations);

rewriter.replaceOpWithNewOp<tensor::PadOp>(
collapseOp, resultType, newCollapseOp.getResult(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());

return success();
}

private:
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -2235,6 +2403,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2246,6 +2416,11 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
      patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
53 changes: 52 additions & 1 deletion mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -232,7 +232,7 @@ func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1
%1 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel"]}
ins(%0 : tensor<?x?xf32>)
outs(%init : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%out = arith.negf %b0 : f32
@@ -858,3 +858,54 @@ func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1:
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
// CHECK-SAME: tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
// CHECK: return %[[COLLAPSED]] : tensor<512x192x?xf32>

// -----

func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
%padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
tensor.yield %cst : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
: tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
return %collapsed : tensor<512x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_collapse(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index):
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
// CHECK: return %[[PADDED]] : tensor<512x258x258xf32>

// -----

func.func @fold_tensor_pad_with_collapse_dynamic_pad_zero(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
%padded = tensor.pad %0 low[%c0, %c0, %c1, %c1] high[%c0, %c0, %c1, %c1] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
tensor.yield %cst : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
: tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
return %collapsed : tensor<512x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_collapse_dynamic_pad_zero(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
// CHECK: ^bb0(
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: return %[[PADDED]]
51 changes: 49 additions & 2 deletions mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
%arg1 : tensor<?x?xi32>,
%sz0: index, %sz1: index) ->
tensor<?x?x4x5xi32>
{
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
// -----

func.func @reshape_as_consumer_permutation_with_multiple_results
(%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
%sz1: index, %sz2: index, %sz3: index, %sz4: index)
-> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
%c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPANDED]] :
// CHECK: return %[[GENERIC]]

// -----

func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
%padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
^bb0(%i: index, %j: index, %k: index):
tensor.yield %cst : f32
} : tensor<512x256x256xf32> to tensor<512x258x258xf32>
%expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
return %expanded : tensor<32x16x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_expand(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
// CHECK: return %[[PADDED]] : tensor<32x16x258x258xf32>

// -----

func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
%padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
^bb0(%i: index, %j: index, %k: index):
tensor.yield %cst : f32
} : tensor<512x256x256xf32> to tensor<512x258x258xf32>
%expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
return %expanded : tensor<32x16x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_expand_dynamic_pad_zero(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
// CHECK: ^bb0(
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: return %[[PADDED]]