28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

/// Add patterns to fuse a linalg.fill into its consuming linalg.reduce.
Contributor

In general this is not the preferred way of fusing a fill with a reduction. The preferred way is to use a tile + fuse approach to fuse at tile granularity (since fusing a fill with its consumer reduction operation results in an inherently imperfectly nested loop computation). The main issue here is that this adds a conditional to the innermost loop computation, which is generally not performant.
But this still seems valid. Could you add some comments explaining the alternatives?
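
For reference, here is a rough sketch (not part of this patch) of the kind of code the tile + fuse approach the comment refers to would lead to for the example below, assuming the reduction dimension is tiled with an illustrative tile size of 1024 and the original input is named %input. The fill initializes the accumulator once, outside the tiled loop, so no per-element conditional is introduced:

%c0 = arith.constant 0 : index
%c1024 = arith.constant 1024 : index
%c147456 = arith.constant 147456 : index
%c0_i8 = arith.constant 0 : i8
%empty = tensor.empty() : tensor<i8>
%filled = linalg.fill ins(%c0_i8 : i8) outs(%empty : tensor<i8>) -> tensor<i8>
%reduced = scf.for %iv = %c0 to %c147456 step %c1024
    iter_args(%acc = %filled) -> (tensor<i8>) {
  // Reduce one tile of the input into the running accumulator.
  %slice = tensor.extract_slice %input[%iv] [1024] [1]
      : tensor<147456xi8> to tensor<1024xi8>
  %partial = linalg.reduce ins(%slice : tensor<1024xi8>)
      outs(%acc : tensor<i8>) dimensions = [0]
    (%in: i8, %init: i8) {
      %s = arith.addi %in, %init : i8
      linalg.yield %s : i8
    }
  scf.yield %partial : tensor<i8>
}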

/// The rewrite folds the linalg.fill into the linalg.reduce by creating a
/// fused linalg.generic operation. The fill value is expected to be consumed
/// only on the first index of the reduction dimension. Currently only a
/// single reduction dimension is supported. Given the pattern:
///   %empty = tensor.empty() : tensor<i8>
///   %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>)
///             -> tensor<i8>
///   %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
///                            outs(%filled : tensor<i8>) dimensions = [0]
///     (%in: i8, %init: i8) {
///       %3 = arith.addi %in, %init : i8
///       linalg.yield %3 : i8
///     }
/// The pattern is rewritten into:
///   %empty = tensor.empty() : tensor<i8>
///   %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
///                             outs(%empty : tensor<i8>) {
///   ^bb0(%in: i8, %init: i8):
///     %cst = arith.constant 0 : index
///     %index = linalg.index 0 : index
///     %cmp = arith.cmpi eq, %cst, %index : index
Comment on lines +1916 to +1917
Member

The iteration order of reduction and parallel iterators in a linalg operation is undefined. I don't think you can assume that the first iteration of the iterator is 0.

Member

To be clear, I don't think you can do this fusion while adding linalg.index to the body, because that would mean you are assuming the first iteration index to be something.
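
A tiny standalone C++ model (not part of this patch; fusedSum and order are made-up names) of the fused body, illustrating the concern: order stands for one legal visitation order of the reduction dimension, which linalg leaves unspecified.

#include <cstdint>
#include <vector>

int8_t fusedSum(const std::vector<int8_t> &in,
                const std::vector<size_t> &order) {
  int8_t acc = 0; // whatever happens to be in the un-filled init tensor
  for (size_t i : order) {
    // arith.select on linalg.index: take the fill value (0) at index 0.
    int8_t init = (i == 0) ? int8_t(0) : acc;
    acc = static_cast<int8_t>(in[i] + init); // arith.addi
  }
  return acc;
}

// fusedSum({1, 2, 3}, {0, 1, 2}) == 6, the correct sum, but
// fusedSum({1, 2, 3}, {2, 1, 0}) == 1: when index 0 is visited last, the
// select discards everything accumulated before it.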

///     %sum = arith.select %cmp, %c0, %init : i8
///     %res = arith.addi %in, %sum : i8
///     linalg.yield %res : i8
///   }
void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);

} // namespace linalg
} // namespace mlir

1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
EraseUnusedOperandsAndResults.cpp
FoldAddIntoDest.cpp
FuseFillOpWithReduceOp.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
Hoisting.cpp
107 changes: 107 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
@@ -0,0 +1,107 @@
//===- FuseFillOpWithReduceOp.cpp - Fuse linalg.fill into linalg.reduce ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pattern that folds a linalg.fill into its consumer
// linalg.reduce by rewriting the pair into a single fused linalg.generic
// operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {

/// Fold linalg.fill into linalg.reduce by creating a fused linalg.generic
/// operation. The fill value is expected to be consumed only on the first
/// index of the reduction dimension. Currently only a single reduction
/// dimension is supported. Given the pattern:
///   %empty = tensor.empty() : tensor<i8>
///   %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>)
///             -> tensor<i8>
///   %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
///                            outs(%filled : tensor<i8>) dimensions = [0]
///     (%in: i8, %init: i8) {
///       %3 = arith.addi %in, %init : i8
///       linalg.yield %3 : i8
///     }
/// The pattern is rewritten into:
///   %empty = tensor.empty() : tensor<i8>
///   %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
///                             outs(%empty : tensor<i8>) {
///   ^bb0(%in: i8, %init: i8):
///     %cst = arith.constant 0 : index
///     %index = linalg.index 0 : index
///     %cmp = arith.cmpi eq, %cst, %index : index
///     %sum = arith.select %cmp, %c0, %init : i8
///     %res = arith.addi %in, %sum : i8
///     linalg.yield %res : i8
///   }
struct FoldFillWithReduceOp : public OpRewritePattern<linalg::ReduceOp> {
using OpRewritePattern<linalg::ReduceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp,
PatternRewriter &rewriter) const override {
if (!reduceOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
reduceOp, "skip reduce op with non-pure tensor semantics");
if (reduceOp.getDimensions().size() != 1)
return rewriter.notifyMatchFailure(
reduceOp, "skip reduce op with more than one reduction dimension");
if (reduceOp.getNumDpsInputs() != 1 || reduceOp.getNumDpsInits() != 1)
return rewriter.notifyMatchFailure(
reduceOp, "skip reduce op with more than one input or init");
auto fillOp = reduceOp.getInits()[0].getDefiningOp<linalg::FillOp>();
if (!fillOp)
return rewriter.notifyMatchFailure(
reduceOp,
"skip reduce op with inits not directly based on fill operation");

int64_t dim = reduceOp.getDimensions()[0];
// Note: on success, the `reduceOp` is replaced with a genericOp and no
// longer valid.
auto failureOrGenericOp = linalg::generalizeNamedOp(rewriter, reduceOp);
if (failed(failureOrGenericOp))
return rewriter.notifyMatchFailure(reduceOp,
"failed to generalize reduce op");

linalg::GenericOp genericReduceOp = *failureOrGenericOp;
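// Find which operand of the generalized op is produced by the fill.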
auto operandIdx = -1;
for (auto &use : genericReduceOp->getOpOperands()) {
if (use.get().getDefiningOp() == fillOp)
operandIdx = use.getOperandNumber();
}
assert(operandIdx != -1 && "fill op not found in reduce op uses");

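// Rebuild the reduction body: at reduction index 0 an arith.select substitutes
// the fill value for the init element, so the separate fill is no longer needed.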
Location loc = genericReduceOp.getLoc();
auto blockArg = genericReduceOp.getMatchingBlockArgument(
&genericReduceOp->getOpOperand(operandIdx));
rewriter.setInsertionPointToStart(genericReduceOp.getBody());
auto constZeroIndexOp = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto linalgIndexOp = rewriter.create<linalg::IndexOp>(loc, dim);
auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
constZeroIndexOp.getResult(),
linalgIndexOp.getResult());
auto selectOp =
rewriter.create<arith::SelectOp>(loc, cmpIOp, fillOp.value(), blockArg);
rewriter.replaceAllUsesExcept(blockArg, selectOp.getResult(), selectOp);
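// Read the init directly from the fill's destination (e.g. the tensor.empty);
// the now-unused fill is left to be cleaned up as dead code.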
genericReduceOp->setOperand(operandIdx, fillOp.getDpsInitOperand(0)->get());

return success();
}
};

} // namespace

void mlir::linalg::populateFuseFillOpWithReduceOpPatterns(
RewritePatternSet &patterns) {
patterns.add<FoldFillWithReduceOp>(patterns.getContext());
}
88 changes: 88 additions & 0 deletions mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
@@ -0,0 +1,88 @@
// RUN: mlir-opt -test-linalg-fuse-fill-op-with-reduce-op -split-input-file %s | FileCheck %s

// CHECK-LABEL: func.func private @test_reduce_sum_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
// CHECK: %[[VAL_7:.*]] = linalg.index 0 : index
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : i8
// CHECK: linalg.yield %[[VAL_10]] : i8
// CHECK: } -> tensor<i8>
// CHECK: return %[[VAL_11:.*]] : tensor<i8>
// CHECK: }

func.func private @test_reduce_sum_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
%1 = tensor.empty() : tensor<i8>
%c0_i8 = arith.constant 0 : i8
%2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
%reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
(%in: i8, %init: i8) {
%3 = arith.addi %in, %init : i8
linalg.yield %3 : i8
}
return %reduced : tensor<i8>
}

// -----

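// Negative test: the reduce init is not produced by a linalg.fill, so the op is not rewritten.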
func.func private @test_missing_fill(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
%1 = tensor.empty() : tensor<i8>
// CHECK: linalg.reduce
%reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%1 : tensor<i8>) dimensions = [0]
(%in: i8, %init: i8) {
%3 = arith.addi %in, %init : i8
linalg.yield %3 : i8
}
return %reduced : tensor<i8>
}

// -----

// CHECK-LABEL: func.func private @test_reduce_multiply_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i8
// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
// CHECK: %[[VAL_7:.*]] = linalg.index 0 : index
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_5]], %[[VAL_9]] : i8
// CHECK: linalg.yield %[[VAL_10]] : i8
// CHECK: } -> tensor<i8>
// CHECK: return %[[VAL_11:.*]] : tensor<i8>
// CHECK: }

func.func private @test_reduce_multiply_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
%1 = tensor.empty() : tensor<i8>
%c1_i8 = arith.constant 1 : i8
%2 = linalg.fill ins(%c1_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
%reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
(%in: i8, %init: i8) {
%3 = arith.muli %in, %init : i8
linalg.yield %3 : i8
}
return %reduced : tensor<i8>
}

// -----

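// Negative test: more than one reduction dimension, so the fill and reduce are not rewritten.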
func.func private @test_reduce_sum_on_multiple_dims(%arg0: tensor<2x147456xi8>) -> (tensor<i8>) {
%1 = tensor.empty() : tensor<i8>
%c0_i8 = arith.constant 0 : i8
// CHECK: linalg.fill
%2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
// CHECK: linalg.reduce
%reduced = linalg.reduce ins(%arg0 : tensor<2x147456xi8>) outs(%2 : tensor<i8>) dimensions = [0, 1]
(%in: i8, %init: i8) {
%3 = arith.addi %in, %init : i8
linalg.yield %3 : i8
}
return %reduced : tensor<i8>
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgElementwiseFusion.cpp
TestLinalgFuseFillOpWithReduceOp.cpp
TestLinalgFusionTransforms.cpp
TestLinalgRankReduceContractionOps.cpp
TestLinalgTransforms.cpp
TestPadFusion.cpp

63 changes: 63 additions & 0 deletions mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp
@@ -0,0 +1,63 @@
//===- TestLinalgFuseFillOpWithReduceOp.cpp -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing the fusion of linalg.fill with
// linalg.reduce into a new linalg.generic operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {

struct TestLinalgFuseFillOpWithReduceOp
: public PassWrapper<TestLinalgFuseFillOpWithReduceOp,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseFillOpWithReduceOp)

TestLinalgFuseFillOpWithReduceOp() = default;
TestLinalgFuseFillOpWithReduceOp(
const TestLinalgFuseFillOpWithReduceOp &pass) = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, linalg::LinalgDialect,
tensor::TensorDialect>();
}
StringRef getArgument() const final {
return "test-linalg-fuse-fill-op-with-reduce-op";
}
StringRef getDescription() const final {
return "Test fuse linalg fill with linalg reduce into a new linalg generic "
"operation";
}

void runOnOperation() override {
MLIRContext *context = &this->getContext();
func::FuncOp funcOp = this->getOperation();

RewritePatternSet patterns(context);
linalg::populateFuseFillOpWithReduceOpPatterns(patterns);
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
return signalPassFailure();
}
};

} // namespace

namespace mlir {
namespace test {
void registerTestLinalgFuseFillOpWithReduceOp() {
PassRegistration<TestLinalgFuseFillOpWithReduceOp>();
}
} // namespace test
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgFuseFillOpWithReduceOp();
void registerTestLinalgGreedyFusion();
void registerTestLinalgRankReduceContractionOps();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgRankReduceContractionOps();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
Expand Down