Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions mlir/include/mlir/Transforms/Inliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,40 @@ class Inliner {
/// this hook's interface might need to be extended in future.
using ProfitabilityCallbackTy = std::function<bool(const ResolvedCall &)>;

/// Type of the callback that determines if the inliner can inline a function
/// containing multiple blocks into a region that requires a single block. By
/// default, it is not allowed.
/// If this function return true, the static member function doClone()
/// should perform the actual transformation with its support.
using canHandleMultipleBlocksCbTy = std::function<bool()>;

using CloneCallbackTy =
std::function<void(OpBuilder &builder, Region *src, Block *inlineBlock,
Block *postInsertBlock, IRMapping &mapper,
bool shouldCloneInlinedRegion)>;

Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am,
RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config,
ProfitabilityCallbackTy isProfitableToInline)
ProfitabilityCallbackTy isProfitableToInline,
canHandleMultipleBlocksCbTy canHandleMultipleBlocks)
: op(op), cg(cg), pass(pass), am(am),
runPipelineHelper(std::move(runPipelineHelper)), config(config),
isProfitableToInline(std::move(isProfitableToInline)) {}
isProfitableToInline(std::move(isProfitableToInline)),
canHandleMultipleBlocks(std::move(canHandleMultipleBlocks)) {}
Inliner(Inliner &) = delete;
void operator=(const Inliner &) = delete;

/// Perform inlining on a OpTrait::SymbolTable operation.
LogicalResult doInlining();

/// This function provides a callback mechanism to clone the instructions and
/// other information from the callee function into the caller function.
static CloneCallbackTy &doClone();

/// Set the clone callback function.
/// The provided function "func" will be invoked by Inliner::doClone().
void setCloneCallback(CloneCallbackTy func) { doClone() = func; }

private:
/// An OpTrait::SymbolTable operation to run the inlining on.
Operation *op;
Expand All @@ -119,10 +141,14 @@ class Inliner {
/// Returns true, if it is profitable to inline the callable operation
/// at the call site.
ProfitabilityCallbackTy isProfitableToInline;
/// Return true, if functions with multiple blocks can be inlined
/// into a region with the SingleBlock trait.
canHandleMultipleBlocksCbTy canHandleMultipleBlocks;

/// Forward declaration of the class providing the actual implementation.
class Impl;
};

} // namespace mlir

#endif // MLIR_TRANSFORMS_INLINER_H
6 changes: 5 additions & 1 deletion mlir/lib/Transforms/InlinerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,13 @@ void InlinerPass::runOnOperation() {
return isProfitableToInline(call, inliningThreshold);
};

// By default, prevent inlining a function containing multiple blocks into a
// region that requires a single block.
auto canHandleMultipleBlocksCb = [=]() { return false; };

// Get an instance of the inliner.
Inliner inliner(op, cg, *this, getAnalysisManager(), runPipelineHelper,
config, profitabilityCb);
config, profitabilityCb, canHandleMultipleBlocksCb);

// Run the inlining.
if (failed(inliner.doInlining()))
Expand Down
51 changes: 38 additions & 13 deletions mlir/lib/Transforms/Utils/Inliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,28 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
}
}
}
//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
// Initialize doClone function with the default implementation
Inliner::CloneCallbackTy &Inliner::doClone() {
static Inliner::CloneCallbackTy doWork =
[](OpBuilder &builder, Region *src, Block *inlineBlock,
Block *postInsertBlock, IRMapping &mapper,
bool shouldCloneInlinedRegion) {
// Check to see if the region is being cloned, or moved inline. In
// either case, move the new blocks after the 'insertBlock' to improve
// IR readability.
Region *insertRegion = inlineBlock->getParent();
if (shouldCloneInlinedRegion)
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
else
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
src->getBlocks(), src->begin(),
src->end());
};
return doWork;
}

//===----------------------------------------------------------------------===//
// InlinerInterfaceImpl
Expand Down Expand Up @@ -729,19 +751,22 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {

// Don't allow inlining if the callee has multiple blocks (unstructured
// control flow) but we cannot be sure that the caller region supports that.
bool calleeHasMultipleBlocks =
llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
// If both parent ops have the same type, it is safe to inline. Otherwise,
// decide based on whether the op has the SingleBlock trait or not.
// Note: This check does currently not account for SizedRegion/MaxSizedRegion.
auto callerRegionSupportsMultipleBlocks = [&]() {
return callableRegion->getParentOp()->getName() ==
resolvedCall.call->getParentOp()->getName() ||
!resolvedCall.call->getParentOp()
->mightHaveTrait<OpTrait::SingleBlock>();
};
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
return false;
if (!inliner.canHandleMultipleBlocks()) {
bool calleeHasMultipleBlocks =
llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
// If both parent ops have the same type, it is safe to inline. Otherwise,
// decide based on whether the op has the SingleBlock trait or not.
// Note: This check does currently not account for
// SizedRegion/MaxSizedRegion.
auto callerRegionSupportsMultipleBlocks = [&]() {
return callableRegion->getParentOp()->getName() ==
resolvedCall.call->getParentOp()->getName() ||
!resolvedCall.call->getParentOp()
->mightHaveTrait<OpTrait::SingleBlock>();
};
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
return false;
}

if (!inliner.isProfitableToInline(resolvedCall))
return false;
Expand Down
13 changes: 4 additions & 9 deletions mlir/lib/Transforms/Utils/InliningUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Inliner.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
Expand Down Expand Up @@ -275,16 +276,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
if (call && callable)
handleArgumentImpl(interface, builder, call, callable, mapper);

// Check to see if the region is being cloned, or moved inline. In either
// case, move the new blocks after the 'insertBlock' to improve IR
// readability.
// Clone the callee's source into the caller.
Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
if (shouldCloneInlinedRegion)
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
else
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
src->getBlocks(), src->begin(),
src->end());
Inliner::doClone()(builder, src, inlineBlock, postInsertBlock, mapper,
shouldCloneInlinedRegion);

// Get the range of newly inserted blocks.
auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Transforms/test-inlining-callback.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -test-inline-callback | FileCheck %s

// Test inlining with multiple blocks and scf.execute_region transformation
// CHECK-LABEL: func @test_inline_multiple_blocks
func.func @test_inline_multiple_blocks(%arg0: i32) -> i32 {
// CHECK: %[[RES:.*]] = scf.execute_region -> i32
// CHECK-NEXT: %[[ADD1:.*]] = arith.addi %arg0, %arg0
// CHECK-NEXT: cf.br ^bb1(%[[ADD1]] : i32)
// CHECK: ^bb1(%[[ARG:.*]]: i32):
// CHECK-NEXT: %[[ADD2:.*]] = arith.addi %[[ARG]], %[[ARG]]
// CHECK-NEXT: scf.yield %[[ADD2]]
// CHECK: return %[[RES]]
%fn = "test.functional_region_op"() ({
^bb0(%a : i32):
%b = arith.addi %a, %a : i32
cf.br ^bb1(%b: i32)
^bb1(%c: i32):
%d = arith.addi %c, %c : i32
"test.return"(%d) : (i32) -> ()
}) : () -> ((i32) -> i32)

%0 = call_indirect %fn(%arg0) : (i32) -> i32
return %0 : i32
}
1 change: 1 addition & 0 deletions mlir/test/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
TestInliningCallback.cpp
TestMakeIsolatedFromAbove.cpp
TestTransformsOps.cpp
${MLIRTestTransformsPDLSrc}
Expand Down
152 changes: 152 additions & 0 deletions mlir/test/lib/Transforms/TestInliningCallback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//===- TestInliningCallback.cpp - Pass to inline calls in the test dialect
//--------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file implements a pass to test inlining callbacks including
// canHandleMultipleBlocks and doClone.
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Inliner.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSet.h"

using namespace mlir;
using namespace test;

namespace {
struct InlinerCallback
: public PassWrapper<InlinerCallback, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InlinerCallback)

StringRef getArgument() const final { return "test-inline-callback"; }
StringRef getDescription() const final {
return "Test inlining region calls with call back functions";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect>();
}

static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline,
Operation *op) {
return mlir::cast<InlinerCallback>(pass).runPipeline(pipeline, op);
}

// Customize the implementation of Inliner::doClone
// Wrap the callee into scf.execute_region operation
static void testDoClone(OpBuilder &builder, Region *src, Block *inlineBlock,
Block *postInsertBlock, IRMapping &mapper,
bool shouldCloneInlinedRegion) {
// Create a new scf.execute_region operation
mlir::Operation &call = inlineBlock->back();
builder.setInsertionPointAfter(&call);

auto executeRegionOp = builder.create<mlir::scf::ExecuteRegionOp>(
call.getLoc(), call.getResultTypes());
mlir::Region &region = executeRegionOp.getRegion();

// Move the inlined blocks into the region
src->cloneInto(&region, mapper);

// Split block before scf operation.
Block *continueBlock =
inlineBlock->splitBlock(executeRegionOp.getOperation());

// Replace all test.return with scf.yield
for (mlir::Block &block : region) {

for (mlir::Operation &op : llvm::make_early_inc_range(block)) {
if (test::TestReturnOp returnOp =
llvm::dyn_cast<test::TestReturnOp>(&op)) {
mlir::OpBuilder returnBuilder(returnOp);
returnBuilder.create<mlir::scf::YieldOp>(returnOp.getLoc(),
returnOp.getOperands());
returnOp.erase();
}
}
}

// Add test.return after scf.execute_region
builder.setInsertionPointAfter(executeRegionOp);
builder.create<test::TestReturnOp>(executeRegionOp.getLoc(),
executeRegionOp.getResults());
}

void runOnOperation() override {
InlinerConfig config;
CallGraph &cg = getAnalysis<CallGraph>();

auto function = getOperation();

// By default, assume that any inlining is profitable.
auto profitabilityCb = [&](const mlir::Inliner::ResolvedCall &call) {
return true;
};

// This customized inliner can turn multiple blocks into a single block.
auto canHandleMultipleBlocksCb = [&]() { return true; };

// Get an instance of the inliner.
Inliner inliner(function, cg, *this, getAnalysisManager(),
runPipelineHelper, config, profitabilityCb,
canHandleMultipleBlocksCb);

// Customize the implementation of Inliner::doClone
inliner.setCloneCallback([](OpBuilder &builder, Region *src,
Block *inlineBlock, Block *postInsertBlock,
IRMapping &mapper,
bool shouldCloneInlinedRegion) {
return testDoClone(builder, src, inlineBlock, postInsertBlock, mapper,
shouldCloneInlinedRegion);
});

// Collect each of the direct function calls within the module.
SmallVector<func::CallIndirectOp, 16> callers;
function.walk(
[&](func::CallIndirectOp caller) { callers.push_back(caller); });

// Build the inliner interface.
InlinerInterface interface(&getContext());

// Try to inline each of the call operations.
for (auto caller : callers) {
auto callee = dyn_cast_or_null<FunctionalRegionOp>(
caller.getCallee().getDefiningOp());
if (!callee)
continue;

// Inline the functional region operation, but only clone the internal
// region if there is more than one use.
if (failed(inlineRegion(
interface, &callee.getBody(), caller, caller.getArgOperands(),
caller.getResults(), caller.getLoc(),
/*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
continue;

// If the inlining was successful then erase the call and callee if
// possible.
caller.erase();
if (callee.use_empty())
callee.erase();
}
}
};
} // namespace

namespace mlir {
namespace test {
void registerInlinerCallback() { PassRegistration<InlinerCallback>(); }
} // namespace test
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerConvertFuncOpPass();
void registerInliner();
void registerInlinerCallback();
void registerMemRefBoundCheck();
void registerPatternsTestPass();
void registerSimpleParametricTilingPass();
Expand Down Expand Up @@ -215,6 +216,7 @@ void registerTestPasses() {
mlir::test::registerConvertCallOpPass();
mlir::test::registerConvertFuncOpPass();
mlir::test::registerInliner();
mlir::test::registerInlinerCallback();
mlir::test::registerMemRefBoundCheck();
mlir::test::registerPatternsTestPass();
mlir::test::registerSimpleParametricTilingPass();
Expand Down