191 changes: 191 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1100,6 +1100,193 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to bubble a tensor.expand_shape op up through its producer
/// tensor.pad op, so the expand_shape is applied to the unpadded source and
/// the pad is applied afterwards.
///
/// ```
/// BEFORE:
/// %padded = tensor.pad %input low[0, 1, 1] high[0, 1, 1]
/// tensor<512x256x256xf32> to tensor<512x258x258xf32>
/// %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]]
/// tensor<512x258x258xf32> to tensor<32x16x258x258xf32>
///
/// AFTER:
/// %expanded = tensor.expand_shape %input [[0, 1], [2], [3]]
/// tensor<512x256x256xf32> to tensor<32x16x256x256xf32>
/// %padded = tensor.pad %expanded low[0, 0, 1, 1] high[0, 0, 1, 1]
/// tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
/// ```
struct MoveReshapeWithProducerPadOpByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {

MoveReshapeWithProducerPadOpByExpansion(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();

if (!padOp->hasOneUse())
return failure();

if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(expandOp,
"fusion blocked by control function");
}

Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue) {
return rewriter.notifyMatchFailure(
expandOp, "cannot fold with non-constant padding value");
}

SmallVector<ReassociationIndices> reassociations =
expandOp.getReassociationIndices();
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();

SmallVector<OpFoldResult> newLow, newHigh;
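// A padded source dimension can only be bubbled through the expand_shape if
// the expand leaves it intact (reassociation group of size one); dimensions
// that are split must carry zero padding, which then becomes zero padding on
// each of the expanded dimensions.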
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() > 1 && (!isConstantIntValue(low[idx], 0) ||
!isConstantIntValue(high[idx], 0)))
return rewriter.notifyMatchFailure(
expandOp, "fusion blocked by non-zero padding");

newLow.append(reInd.size(), low[idx]);
newHigh.append(reInd.size(), high[idx]);
}

Location loc = expandOp.getLoc();
SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
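// The new expand_shape consumes the unpadded source, so for any padded
// dimension the expanded output size is recomputed from the source instead
// of taken from the padded result.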
for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[inDimIdx];
OpFoldResult h = high[inDimIdx];

if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
assert(reInd.size() == 1 && "expected single dimension");
expandedShape[reInd[0]] =
tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);
}
}

SmallVector<int64_t> staticExpandedShape;
std::tie(staticExpandedShape, std::ignore) =
decomposeMixedValues(expandedShape);

auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
loc,
RankedTensorType::get(staticExpandedShape,
padOp.getSource().getType().getElementType()),
padOp.getSource(), reassociations, expandedShape);

rewriter.replaceOpWithNewOp<tensor::PadOp>(
expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());
return success();
}

private:
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to bubble a tensor.collapse_shape op up through its producer
/// tensor.pad op, so the collapse_shape is applied to the unpadded source and
/// the pad is applied afterwards. Padding is propagated only for dimensions
/// that the collapse leaves intact; dimensions merged with others by the
/// collapse must be unpadded.
///
/// ```
/// BEFORE:
/// %padded = tensor.pad %input low[0, 0, 1] high[0, 0, 1]
///     tensor<32x16x256xf32> to tensor<32x16x258xf32>
/// %collapsed = tensor.collapse_shape %padded [[0, 1], [2]]
///     tensor<32x16x258xf32> to tensor<512x258xf32>
///
/// AFTER:
/// %collapsed = tensor.collapse_shape %input [[0, 1], [2]]
///     tensor<32x16x256xf32> to tensor<512x256xf32>
/// %padded = tensor.pad %collapsed low[0, 1] high[0, 1]
///     tensor<512x256xf32> to tensor<512x258xf32>
/// ```
struct MoveReshapeWithProducerPadOpByCollapsing
: public OpRewritePattern<tensor::CollapseShapeOp> {

MoveReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
PatternRewriter &rewriter) const override {
tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();

if (!padOp)
return failure();

if (!padOp->hasOneUse())
return failure();

if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(collapseOp,
"fusion blocked by control function");
}

Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue) {
return rewriter.notifyMatchFailure(
collapseOp, "cannot fold with non-constant padding value");
}

SmallVector<ReassociationIndices> reassociations =
collapseOp.getReassociationIndices();
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
SmallVector<OpFoldResult> newLow, newHigh;
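// Padding can only be propagated through groups that collapse a single
// source dimension; a multi-dimension group with non-zero padding cannot be
// expressed as a pad of the collapsed result.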
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() > 1 && llvm::any_of(reInd, [&](int64_t dimIdx) {
return !isConstantIntValue(low[dimIdx], 0) ||
!isConstantIntValue(high[dimIdx], 0);
})) {
return failure();
}

newLow.push_back(low[reInd[0]]);
newHigh.push_back(high[reInd[0]]);
}

Location loc = collapseOp.getLoc();
auto resultType = collapseOp.getResultType();

ArrayRef<int64_t> finalShape = collapseOp.getResultType().getShape();
SmallVector<int64_t> collapsedShape(finalShape.begin(), finalShape.end());
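// The new collapse_shape consumes the unpadded source, so padded
// (single-dimension) groups take their collapsed size from the source shape
// rather than from the padded result.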
for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[reInd[0]];
OpFoldResult h = high[reInd[0]];
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
collapsedShape[inDimIdx] = padOp.getSourceType().getShape()[reInd[0]];
}
}

auto newCollapseType = RankedTensorType::get(
collapsedShape, padOp.getSourceType().getElementType());
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
loc, newCollapseType, padOp.getSource(), reassociations);

rewriter.replaceOpWithNewOp<tensor::PadOp>(
collapseOp, resultType, newCollapseOp.getResult(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());

return success();
}

private:
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -2235,6 +2422,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<MoveReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2246,6 +2435,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<MoveReshapeWithProducerPadOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
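
For context, here is a minimal sketch of how a downstream pass might exercise the two populate functions extended above, assuming the standard greedy rewrite driver; the always-true control function and the helper name `bubbleReshapesThroughPads` are illustrative only, not part of this patch:

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect both reshape-propagation pattern sets and run them greedily on `op`.
static LogicalResult bubbleReshapesThroughPads(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // Accept every candidate; a real pass would use a more selective control fn.
  linalg::ControlFusionFn control = [](OpOperand *) { return true; };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, control);
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, control);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```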
53 changes: 52 additions & 1 deletion mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -232,7 +232,7 @@ func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1
%1 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel"]}
ins(%0 : tensor<?x?xf32>)
ins(%0 : tensor<?x?xf32>)
outs(%init : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%out = arith.negf %b0 : f32
@@ -858,3 +858,54 @@ func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1:
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
// CHECK-SAME: tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
// CHECK: return %[[COLLAPSED]] : tensor<512x192x?xf32>

// -----

func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
%padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
tensor.yield %cst : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
: tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
return %collapsed : tensor<512x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_collapse(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index):
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
// CHECK: return %[[PADDED]] : tensor<512x258x258xf32>

// -----

func.func @fold_tensor_pad_with_collapse_dynamic_pad_zero(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
%padded = tensor.pad %0 low[%c0, %c0, %c1, %c1] high[%c0, %c0, %c1, %c1] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
tensor.yield %cst : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
: tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
return %collapsed : tensor<512x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_collapse_dynamic_pad_zero(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
// CHECK: ^bb0(
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: return %[[PADDED]]
51 changes: 49 additions & 2 deletions mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
%arg1 : tensor<?x?xi32>,
%arg1 : tensor<?x?xi32>,
%sz0: index, %sz1: index) ->
tensor<?x?x4x5xi32>
{
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
// -----

func.func @reshape_as_consumer_permutation_with_multiple_results
(%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
(%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
%sz1: index, %sz2: index, %sz3: index, %sz4: index)
-> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
%c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPANDED]] :
// CHECK: return %[[GENERIC]]

// -----

func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
%padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
^bb0(%i: index, %j: index, %k: index):
tensor.yield %cst : f32
} : tensor<512x256x256xf32> to tensor<512x258x258xf32>
%expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
return %expanded : tensor<32x16x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_expand(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
// CHECK: return %[[PADDED]] : tensor<32x16x258x258xf32>

// -----

func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
%padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
^bb0(%i: index, %j: index, %k: index):
tensor.yield %cst : f32
} : tensor<512x256x256xf32> to tensor<512x258x258xf32>
%expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
return %expanded : tensor<32x16x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_expand_dynamic_pad_zero(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
// CHECK: ^bb0(
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: return %[[PADDED]]