Skip to content

Commit 81e387c

Browse files
committed
Use SymbolTable and also don't clone function body.
one optimization I made is that I don't clone function body. The function body migrates to the new private function. After that, I create wrapper body for the original function.
1 parent 44ccd98 commit 81e387c

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -887,11 +887,16 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
887887
/// original function.
888888
/// }
889889
///
890-
/// Returns true if any IR changes were made, false otherwise.
890+
/// changed = true if any IR changes were made.
891+
///
892+
/// Cloning has to be Interface-based because downstream projects may use their
893+
/// own func/call/return ops.
891894
static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
892-
RunLivenessAnalysis &la, bool &changed) {
893-
Operation *callableOp = callOp.resolveCallable();
894-
auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
895+
RunLivenessAnalysis &la,
896+
SymbolTableCollection *symbolTable,
897+
bool &changed) {
898+
Operation *callableOp = callOp.resolveCallableInTable(symbolTable);
899+
auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callableOp);
895900
if (!funcOp || !funcOp.isPublic())
896901
return LogicalResult::success();
897902

@@ -907,7 +912,8 @@ static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
907912
OpBuilder rewriter(moduleOp.getContext());
908913

909914
// Clone function and create private version
910-
FunctionOpInterface clonedFunc = cast<FunctionOpInterface>(funcOp.clone());
915+
FunctionOpInterface clonedFunc =
916+
cast<FunctionOpInterface>(funcOp->cloneWithoutRegions());
911917

912918
// Set visibility = 'private' and a new name for the cloned function
913919
SymbolTable::setSymbolVisibility(clonedFunc,
@@ -930,20 +936,15 @@ static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
930936
<< funcOp.getName();
931937
return result;
932938
}
933-
934939
LDBG() << "Redirected all callsites from " << funcOp.getName() << " to "
935940
<< newName;
936941

937-
// Transform the original funcOp into a wrapper that calls the cloned
938-
// function
939-
Region &funcBody = funcOp.getFunctionBody();
942+
Region &clonedFuncBody = clonedFunc.getFunctionBody();
943+
// Move the body from funcOp to clonedFunc
944+
clonedFuncBody.takeBody(funcOp.getFunctionBody());
940945

941-
// Clean the original function body
942-
funcBody.dropAllReferences();
943-
funcBody.getBlocks().clear();
944-
945-
// Create a new entry block for the wrapper function
946-
Block *wrapperBlock = rewriter.createBlock(&funcBody);
946+
// Create a new entry block for the wrapper function in funcOp
947+
Block *wrapperBlock = rewriter.createBlock(&funcOp.getFunctionBody());
947948

948949
// Add block arguments that match the function signature
949950
for (Type argType : funcOp.getArgumentTypes()) {
@@ -964,9 +965,9 @@ static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
964965
rewriter.insert(clonedCallOp);
965966

966967
// Create return operation of the same type as the original function's
967-
// return
968+
// returnOp.
968969
Operation *returnOp = nullptr;
969-
for (Block &block : clonedFunc.getFunctionBody()) {
970+
for (Block &block : clonedFuncBody) {
970971
if (block.getNumSuccessors() > 0)
971972
continue;
972973

@@ -980,7 +981,7 @@ static LogicalResult processCallOp(CallOpInterface callOp, ModuleOp moduleOp,
980981
if (returnOp) {
981982
Operation *newReturnOp = returnOp->clone();
982983
newReturnOp->setOperands(clonedCallOp->getResults());
983-
newReturnOp->setLoc(funcOp.getLoc());
984+
newReturnOp->setLoc(returnOp->getLoc());
984985
rewriter.insert(newReturnOp);
985986
}
986987
changed = true; // Changes were made
@@ -998,11 +999,12 @@ void RemoveDeadValues::runOnOperation() {
998999
// inter-procedural.
9991000
if (la->getSolverConfig().isInterprocedural() && isa<ModuleOp>(module)) {
10001001
bool changed = false;
1002+
SymbolTableCollection symbolTable;
10011003
WalkResult walkResult =
10021004
module->walk([&](CallOpInterface callOp) -> WalkResult {
1003-
return processCallOp(callOp, cast<ModuleOp>(module), *la, changed);
1005+
return processCallOp(callOp, cast<ModuleOp>(module), *la,
1006+
&symbolTable, changed);
10041007
});
1005-
10061008
if (walkResult.wasInterrupted()) {
10071009
signalPassFailure();
10081010
return;

0 commit comments

Comments
 (0)