1111// ===----------------------------------------------------------------------===//
1212
1313#include " mlir/Transforms/InliningUtils.h"
14+ #include " mlir/Transforms/Inliner.h"
1415
1516#include " mlir/IR/Builders.h"
1617#include " mlir/IR/IRMapping.h"
@@ -266,10 +267,11 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
266267}
267268
268269static LogicalResult
269- inlineRegionImpl (InlinerInterface &interface, Region *src, Block *inlineBlock,
270- Block::iterator inlinePoint, IRMapping &mapper,
271- ValueRange resultsToReplace, TypeRange regionResultTypes,
272- std::optional<Location> inlineLoc,
270+ inlineRegionImpl (InlinerInterface &interface,
271+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
272+ Region *src, Block *inlineBlock, Block::iterator inlinePoint,
273+ IRMapping &mapper, ValueRange resultsToReplace,
274+ TypeRange regionResultTypes, std::optional<Location> inlineLoc,
273275 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
274276 assert (resultsToReplace.size () == regionResultTypes.size ());
275277 // We expect the region to have at least one block.
@@ -296,16 +298,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
296298 if (call && callable)
297299 handleArgumentImpl (interface, builder, call, callable, mapper);
298300
299- // Check to see if the region is being cloned, or moved inline. In either
300- // case, move the new blocks after the 'insertBlock' to improve IR
301- // readability.
301+ // Clone the callee's source into the caller.
302302 Block *postInsertBlock = inlineBlock->splitBlock (inlinePoint);
303- if (shouldCloneInlinedRegion)
304- src->cloneInto (insertRegion, postInsertBlock->getIterator (), mapper);
305- else
306- insertRegion->getBlocks ().splice (postInsertBlock->getIterator (),
307- src->getBlocks (), src->begin (),
308- src->end ());
303+ cloneCallback (builder, src, inlineBlock, postInsertBlock, mapper,
304+ shouldCloneInlinedRegion);
309305
310306 // Get the range of newly inserted blocks.
311307 auto newBlocks = llvm::make_range (std::next (inlineBlock->getIterator ()),
@@ -374,9 +370,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
374370}
375371
376372static LogicalResult
377- inlineRegionImpl (InlinerInterface &interface, Region *src, Block *inlineBlock,
378- Block::iterator inlinePoint, ValueRange inlinedOperands,
379- ValueRange resultsToReplace, std::optional<Location> inlineLoc,
373+ inlineRegionImpl (InlinerInterface &interface,
374+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
375+ Region *src, Block *inlineBlock, Block::iterator inlinePoint,
376+ ValueRange inlinedOperands, ValueRange resultsToReplace,
377+ std::optional<Location> inlineLoc,
380378 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
381379 // We expect the region to have at least one block.
382380 if (src->empty ())
@@ -398,53 +396,54 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
398396 }
399397
400398 // Call into the main region inliner function.
401- return inlineRegionImpl (interface, src, inlineBlock, inlinePoint, mapper,
402- resultsToReplace, resultsToReplace.getTypes (),
403- inlineLoc, shouldCloneInlinedRegion, call);
399+ return inlineRegionImpl (interface, cloneCallback, src, inlineBlock,
400+ inlinePoint, mapper, resultsToReplace,
401+ resultsToReplace.getTypes (), inlineLoc,
402+ shouldCloneInlinedRegion, call);
404403}
405404
406- LogicalResult mlir::inlineRegion (InlinerInterface &interface, Region *src,
407- Operation *inlinePoint, IRMapping &mapper ,
408- ValueRange resultsToReplace ,
409- TypeRange regionResultTypes ,
410- std::optional<Location> inlineLoc,
411- bool shouldCloneInlinedRegion) {
412- return inlineRegion (interface, src, inlinePoint->getBlock (),
405+ LogicalResult mlir::inlineRegion (
406+ InlinerInterface &interface ,
407+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src ,
408+ Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace ,
409+ TypeRange regionResultTypes, std::optional<Location> inlineLoc,
410+ bool shouldCloneInlinedRegion) {
411+ return inlineRegion (interface, cloneCallback, src, inlinePoint->getBlock (),
413412 ++inlinePoint->getIterator (), mapper, resultsToReplace,
414413 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
415414}
416- LogicalResult mlir::inlineRegion (InlinerInterface &interface, Region *src,
417- Block *inlineBlock,
418- Block::iterator inlinePoint, IRMapping &mapper ,
419- ValueRange resultsToReplace ,
420- TypeRange regionResultTypes ,
421- std::optional<Location> inlineLoc ,
422- bool shouldCloneInlinedRegion) {
423- return inlineRegionImpl (interface, src, inlineBlock, inlinePoint, mapper,
424- resultsToReplace, regionResultTypes, inlineLoc ,
425- shouldCloneInlinedRegion);
415+
416+ LogicalResult mlir::inlineRegion (
417+ InlinerInterface &interface ,
418+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src ,
419+ Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper ,
420+ ValueRange resultsToReplace, TypeRange regionResultTypes ,
421+ std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
422+ return inlineRegionImpl (
423+ interface, cloneCallback, src, inlineBlock, inlinePoint, mapper ,
424+ resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
426425}
427426
428- LogicalResult mlir::inlineRegion (InlinerInterface &interface, Region *src,
429- Operation *inlinePoint ,
430- ValueRange inlinedOperands ,
431- ValueRange resultsToReplace ,
432- std::optional<Location> inlineLoc,
433- bool shouldCloneInlinedRegion) {
434- return inlineRegion (interface, src, inlinePoint->getBlock (),
427+ LogicalResult mlir::inlineRegion (
428+ InlinerInterface &interface ,
429+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src ,
430+ Operation *inlinePoint, ValueRange inlinedOperands ,
431+ ValueRange resultsToReplace, std::optional<Location> inlineLoc,
432+ bool shouldCloneInlinedRegion) {
433+ return inlineRegion (interface, cloneCallback, src, inlinePoint->getBlock (),
435434 ++inlinePoint->getIterator (), inlinedOperands,
436435 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
437436}
438- LogicalResult mlir::inlineRegion (InlinerInterface &interface, Region *src,
439- Block *inlineBlock,
440- Block::iterator inlinePoint ,
441- ValueRange inlinedOperands ,
442- ValueRange resultsToReplace ,
443- std::optional<Location> inlineLoc,
444- bool shouldCloneInlinedRegion) {
445- return inlineRegionImpl (interface, src, inlineBlock, inlinePoint ,
446- inlinedOperands, resultsToReplace, inlineLoc ,
447- shouldCloneInlinedRegion);
437+
438+ LogicalResult mlir::inlineRegion (
439+ InlinerInterface &interface ,
440+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src ,
441+ Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands ,
442+ ValueRange resultsToReplace, std::optional<Location> inlineLoc,
443+ bool shouldCloneInlinedRegion) {
444+ return inlineRegionImpl (interface, cloneCallback, src, inlineBlock ,
445+ inlinePoint, inlinedOperands, resultsToReplace ,
446+ inlineLoc, shouldCloneInlinedRegion);
448447}
449448
450449// / Utility function used to generate a cast operation from the given interface,
@@ -475,10 +474,11 @@ static Value materializeConversion(const DialectInlinerInterface *interface,
475474// / failure, no changes are made to the module. 'shouldCloneInlinedRegion'
476475// / corresponds to whether the source region should be cloned into the 'call' or
477476// / spliced directly.
478- LogicalResult mlir::inlineCall (InlinerInterface &interface,
479- CallOpInterface call,
480- CallableOpInterface callable, Region *src,
481- bool shouldCloneInlinedRegion) {
477+ LogicalResult
478+ mlir::inlineCall (InlinerInterface &interface,
479+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
480+ CallOpInterface call, CallableOpInterface callable,
481+ Region *src, bool shouldCloneInlinedRegion) {
482482 // We expect the region to have at least one block.
483483 if (src->empty ())
484484 return failure ();
@@ -552,7 +552,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
552552 return cleanupState ();
553553
554554 // Attempt to inline the call.
555- if (failed (inlineRegionImpl (interface, src, call->getBlock (),
555+ if (failed (inlineRegionImpl (interface, cloneCallback, src, call->getBlock (),
556556 ++call->getIterator (), mapper, callResults,
557557 callableResultTypes, call.getLoc (),
558558 shouldCloneInlinedRegion, call)))
0 commit comments