Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions mlir/include/mlir/Transforms/InliningUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,17 @@ class InlinerInterface
/// information. 'shouldCloneInlinedRegion' corresponds to whether the source
/// region should be cloned into the 'inlinePoint' or spliced directly.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the documentation the all the API changes should be updated, can you do it in a follow-up PR please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I will.

LogicalResult inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Operation *inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
InlinerConfig::CloneCallbackTy cloneCallback,
Region *src, Operation *inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
InlinerConfig::CloneCallbackTy cloneCallback,
Region *src, Block *inlineBlock,
Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
Expand All @@ -273,14 +274,16 @@ LogicalResult inlineRegion(InlinerInterface &interface,
/// providing the set of operands ('inlinedOperands') that should be used
/// in-favor of the region arguments when inlining.
LogicalResult inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Operation *inlinePoint, ValueRange inlinedOperands,
InlinerConfig::CloneCallbackTy cloneCallback,
Region *src, Operation *inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc = std::nullopt,
bool shouldCloneInlinedRegion = true);
LogicalResult inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
InlinerConfig::CloneCallbackTy cloneCallback,
Region *src, Block *inlineBlock,
Block::iterator inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc = std::nullopt,
Expand All @@ -293,9 +296,9 @@ LogicalResult inlineRegion(InlinerInterface &interface,
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult inlineCall(InlinerInterface &interface,
const InlinerConfig &config, CallOpInterface call,
CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion = true);
InlinerConfig::CloneCallbackTy cloneCallback,
CallOpInterface call, CallableOpInterface callable,
Region *src, bool shouldCloneInlinedRegion = true);

} // namespace mlir

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/Utils/Inliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

LogicalResult inlineResult =
inlineCall(inlinerIface, inliner.config, call,
inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
Expand Down
68 changes: 35 additions & 33 deletions mlir/lib/Transforms/Utils/InliningUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
}

static LogicalResult
inlineRegionImpl(InlinerInterface &interface, const InlinerConfig &config,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
inlineRegionImpl(InlinerInterface &interface,
InlinerConfig::CloneCallbackTy cloneCallback, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes, std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
Expand Down Expand Up @@ -278,8 +279,8 @@ inlineRegionImpl(InlinerInterface &interface, const InlinerConfig &config,

// Clone the callee's source into the caller.
Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
config.getCloneCallback()(builder, src, inlineBlock, postInsertBlock, mapper,
shouldCloneInlinedRegion);
cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper,
shouldCloneInlinedRegion);

// Get the range of newly inserted blocks.
auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
Expand Down Expand Up @@ -348,8 +349,9 @@ inlineRegionImpl(InlinerInterface &interface, const InlinerConfig &config,
}

static LogicalResult
inlineRegionImpl(InlinerInterface &interface, const InlinerConfig &config,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
inlineRegionImpl(InlinerInterface &interface,
InlinerConfig::CloneCallbackTy cloneCallback, Region *src,
Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands, ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
Expand All @@ -373,54 +375,54 @@ inlineRegionImpl(InlinerInterface &interface, const InlinerConfig &config,
}

// Call into the main region inliner function.
return inlineRegionImpl(interface, config, src, inlineBlock, inlinePoint,
mapper, resultsToReplace, resultsToReplace.getTypes(),
inlineLoc, shouldCloneInlinedRegion, call);
return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
inlinePoint, mapper, resultsToReplace,
resultsToReplace.getTypes(), inlineLoc,
shouldCloneInlinedRegion, call);
}

LogicalResult mlir::inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Operation *inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace,
InlinerConfig::CloneCallbackTy cloneCallback,
Region *src, Operation *inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace,
TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegion(interface, config, src, inlinePoint->getBlock(),
return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
++inlinePoint->getIterator(), mapper, resultsToReplace,
regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
}

LogicalResult mlir::inlineRegion(
InlinerInterface &interface, const InlinerConfig &config, Region *src,
Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
InlinerInterface &interface, InlinerConfig::CloneCallbackTy cloneCallback,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes,
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, config, src, inlineBlock, inlinePoint,
mapper, resultsToReplace, regionResultTypes,
inlineLoc, shouldCloneInlinedRegion);
return inlineRegionImpl(
interface, cloneCallback, src, inlineBlock, inlinePoint, mapper,
resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
}

LogicalResult mlir::inlineRegion(InlinerInterface &interface,
const InlinerConfig &config, Region *src,
Operation *inlinePoint,
InlinerConfig::CloneCallbackTy cloneCallback,
Region *src, Operation *inlinePoint,
ValueRange inlinedOperands,
ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegion(interface, config, src, inlinePoint->getBlock(),
return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
++inlinePoint->getIterator(), inlinedOperands,
resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
}

LogicalResult
mlir::inlineRegion(InlinerInterface &interface, const InlinerConfig &config,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands, ValueRange resultsToReplace,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, config, src, inlineBlock, inlinePoint,
inlinedOperands, resultsToReplace, inlineLoc,
shouldCloneInlinedRegion);
LogicalResult mlir::inlineRegion(
InlinerInterface &interface, InlinerConfig::CloneCallbackTy cloneCallback,
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
ValueRange inlinedOperands, ValueRange resultsToReplace,
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
inlinePoint, inlinedOperands, resultsToReplace,
inlineLoc, shouldCloneInlinedRegion);
}

/// Utility function used to generate a cast operation from the given interface,
Expand Down Expand Up @@ -452,7 +454,7 @@ static Value materializeConversion(const DialectInlinerInterface *interface,
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
const InlinerConfig &config,
InlinerConfig::CloneCallbackTy cloneCallback,
CallOpInterface call,
CallableOpInterface callable, Region *src,
bool shouldCloneInlinedRegion) {
Expand Down Expand Up @@ -529,7 +531,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
return cleanupState();

// Attempt to inline the call.
if (failed(inlineRegionImpl(interface, config, src, call->getBlock(),
if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(),
++call->getIterator(), mapper, callResults,
callableResultTypes, call.getLoc(),
shouldCloneInlinedRegion, call)))
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Transforms/TestInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct InlinerTest
// Inline the functional region operation, but only clone the internal
// region if there is more than one use.
if (failed(inlineRegion(
interface, config, &callee.getBody(), caller,
interface, config.getCloneCallback(), &callee.getBody(), caller,
caller.getArgOperands(), caller.getResults(), caller.getLoc(),
/*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
continue;
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Transforms/TestInliningCallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ struct InlinerCallback
// Inline the functional region operation, but only clone the internal
// region if there is more than one use.
if (failed(inlineRegion(
interface, config, &callee.getBody(), caller,
interface, config.getCloneCallback(), &callee.getBody(), caller,
caller.getArgOperands(), caller.getResults(), caller.getLoc(),
/*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
continue;
Expand Down