-
Notifications
You must be signed in to change notification settings - Fork 15k
[MLIR][RemoveDeadValues] Privatize public function with NonLive arguments before RDV. #162038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
52cf8cb
be516b7
6c62398
3f4d69f
4904fb9
8f92fee
443dff2
12322aa
f6a9b62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -46,6 +46,7 @@ | |||||||||||
| #include "mlir/Interfaces/ControlFlowInterfaces.h" | ||||||||||||
| #include "mlir/Interfaces/FunctionInterfaces.h" | ||||||||||||
| #include "mlir/Interfaces/SideEffectInterfaces.h" | ||||||||||||
| #include "mlir/Pass/AnalysisManager.h" | ||||||||||||
| #include "mlir/Pass/Pass.h" | ||||||||||||
| #include "mlir/Support/LLVM.h" | ||||||||||||
| #include "mlir/Transforms/FoldUtils.h" | ||||||||||||
|
|
@@ -869,10 +870,151 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> { | |||||||||||
| }; | ||||||||||||
| } // namespace | ||||||||||||
|
|
||||||||||||
| /// If the target of CallOp is a public function and at least one argument is | ||||||||||||
| /// NonLive, privatize the function. Our strategy here is separation interface | ||||||||||||
| /// and implementation. eg. | ||||||||||||
| /// | ||||||||||||
| /// public void foo(int unused){...} | ||||||||||||
| /// => | ||||||||||||
| /// public void foo(int unused) { // old function, interface | ||||||||||||
| /// return __foo_privatized(unused); | ||||||||||||
| /// } | ||||||||||||
| /// | ||||||||||||
| /// private void __foo_privatized(int unused) { // the new private function, or | ||||||||||||
| /// implementation. | ||||||||||||
| /// ... // the function body of the | ||||||||||||
| /// original function. | ||||||||||||
| /// } | ||||||||||||
| /// | ||||||||||||
| /// Returns true if any IR changes were made, false otherwise. | ||||||||||||
| static bool processCallOp(CallOpInterface callOp, ModuleOp moduleOp, | ||||||||||||
| RunLivenessAnalysis &la) { | ||||||||||||
| Operation *callableOp = callOp.resolveCallable(); | ||||||||||||
| auto funcOp = dyn_cast<FunctionOpInterface>(callableOp); | ||||||||||||
| if (!funcOp || !funcOp.isPublic()) | ||||||||||||
| return false; | ||||||||||||
|
|
||||||||||||
| LDBG() << "Processing callOp " << callOp << " target is a public function: " | ||||||||||||
| << funcOp.getOperation()->getName(); | ||||||||||||
|
|
||||||||||||
| // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`. | ||||||||||||
| SmallVector<Value> arguments(callOp.getArgOperands()); | ||||||||||||
| BitVector nonLiveArgs = markLives(arguments, DenseSet<Value>(), la); | ||||||||||||
| nonLiveArgs = nonLiveArgs.flip(); | ||||||||||||
|
|
||||||||||||
| if (nonLiveArgs.count() > 0) { | ||||||||||||
| OpBuilder rewriter(moduleOp.getContext()); | ||||||||||||
|
|
||||||||||||
| // Clone function and create private version | ||||||||||||
| FunctionOpInterface clonedFunc = cast<FunctionOpInterface>(funcOp.clone()); | ||||||||||||
|
|
||||||||||||
| // Set visibility = 'private' and a new name for the cloned function | ||||||||||||
| SymbolTable::setSymbolVisibility(clonedFunc, | ||||||||||||
| SymbolTable::Visibility::Private); | ||||||||||||
| std::string newName = "__" + funcOp.getName().str() + "_privatized"; | ||||||||||||
| clonedFunc.setName(newName); | ||||||||||||
|
|
||||||||||||
| // Insert the cloned function into the module | ||||||||||||
| rewriter.setInsertionPointAfter(funcOp); | ||||||||||||
| rewriter.insert(clonedFunc); | ||||||||||||
|
|
||||||||||||
| // Replace ALL callsites of the original function to call the cloned | ||||||||||||
| // function directly | ||||||||||||
| LogicalResult result = SymbolTable::replaceAllSymbolUses( | ||||||||||||
| funcOp, clonedFunc.getNameAttr(), moduleOp); | ||||||||||||
|
|
||||||||||||
| if (result.failed()) { | ||||||||||||
| LDBG() << "Failed to replace all symbol uses for " << funcOp.getName(); | ||||||||||||
| return false; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| LDBG() << "Redirected all callsites from " << funcOp.getName() << " to " | ||||||||||||
| << newName; | ||||||||||||
|
|
||||||||||||
| // Transform the original funcOp into a wrapper that calls the cloned | ||||||||||||
| // function | ||||||||||||
| Region &funcBody = funcOp.getFunctionBody(); | ||||||||||||
|
|
||||||||||||
| // Clean the original function body | ||||||||||||
| funcBody.dropAllReferences(); | ||||||||||||
| funcBody.getBlocks().clear(); | ||||||||||||
|
|
||||||||||||
| // Create a new entry block for the wrapper function | ||||||||||||
| Block *wrapperBlock = rewriter.createBlock(&funcBody); | ||||||||||||
|
|
||||||||||||
| // Add block arguments that match the function signature | ||||||||||||
| for (Type argType : funcOp.getArgumentTypes()) { | ||||||||||||
| wrapperBlock->addArgument(argType, funcOp.getLoc()); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Set insertion point to the new block | ||||||||||||
| rewriter.setInsertionPointToStart(wrapperBlock); | ||||||||||||
|
|
||||||||||||
| // Clone the original call operation and update its callee | ||||||||||||
| auto clonedCallOp = cast<CallOpInterface>(callOp->clone()); | ||||||||||||
| // Update the callee symbol reference to point to the new private function | ||||||||||||
| auto symbolRef = | ||||||||||||
| SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName()); | ||||||||||||
| clonedCallOp.setCalleeFromCallable(symbolRef); | ||||||||||||
| // Set the call arguments to use the wrapper block's arguments | ||||||||||||
| clonedCallOp->setOperands(wrapperBlock->getArguments()); | ||||||||||||
| rewriter.insert(clonedCallOp); | ||||||||||||
|
|
||||||||||||
| // Create return operation of the same type as the original function's | ||||||||||||
| // return | ||||||||||||
| Operation *returnOp = nullptr; | ||||||||||||
| for (Block &block : clonedFunc.getFunctionBody()) { | ||||||||||||
| if (block.getNumSuccessors() > 0) | ||||||||||||
| continue; | ||||||||||||
|
|
||||||||||||
| Operation *terminator = block.getTerminator(); | ||||||||||||
| if (terminator && terminator->hasTrait<OpTrait::ReturnLike>()) { | ||||||||||||
| returnOp = terminator; | ||||||||||||
| break; // Use first return as template | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| if (returnOp) { | ||||||||||||
| Operation *newReturnOp = returnOp->clone(); | ||||||||||||
| newReturnOp->setOperands(clonedCallOp->getResults()); | ||||||||||||
| newReturnOp->setLoc(funcOp.getLoc()); | ||||||||||||
| rewriter.insert(newReturnOp); | ||||||||||||
| } | ||||||||||||
| return true; // Changes were made | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| return false; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| void RemoveDeadValues::runOnOperation() { | ||||||||||||
| auto &la = getAnalysis<RunLivenessAnalysis>(); | ||||||||||||
| AnalysisManager am = getAnalysisManager(); | ||||||||||||
| RunLivenessAnalysis *la = &am.getAnalysis<RunLivenessAnalysis>(); | ||||||||||||
| Operation *module = getOperation(); | ||||||||||||
|
|
||||||||||||
| // Only privatize public functions if liveness analysis is inter-procedural. | ||||||||||||
| if (la->getSolverConfig().isInterprocedural()) { | ||||||||||||
| bool changed = false; | ||||||||||||
| module->walk([&](CallOpInterface callOp) { | ||||||||||||
| if (processCallOp(callOp, cast<ModuleOp>(module), *la)) { | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This cast does not seem safe to me: the pass isn't a modulePass right now. |
||||||||||||
| changed = true; | ||||||||||||
| } | ||||||||||||
| }); | ||||||||||||
|
|
||||||||||||
| if (changed) { | ||||||||||||
| LDBG() << "IR has changed, invalidate RunLivenessAnalysis only"; | ||||||||||||
| auto &pa = getPassState().preservedAnalyses; | ||||||||||||
| bool preserved = pa.isPreserved<RunLivenessAnalysis>(); | ||||||||||||
| la->invalidate(); | ||||||||||||
| am.invalidate(pa); | ||||||||||||
| la = &am.getAnalysis<RunLivenessAnalysis>(); | ||||||||||||
| // If RunLivenessAnalysis was previously preserved, preserve the updated | ||||||||||||
| // results. | ||||||||||||
| if (preserved) { | ||||||||||||
| pa.preserve<RunLivenessAnalysis>(); | ||||||||||||
| } | ||||||||||||
|
Comment on lines
+1012
to
+1014
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Nit: no-trivial-braces in MLIR |
||||||||||||
| } | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Tracks values eligible for erasure - complements liveness analysis to | ||||||||||||
| // identify "droppable" values. | ||||||||||||
| DenseSet<Value> deadVals; | ||||||||||||
|
|
@@ -883,19 +1025,19 @@ void RemoveDeadValues::runOnOperation() { | |||||||||||
|
|
||||||||||||
| module->walk([&](Operation *op) { | ||||||||||||
| if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { | ||||||||||||
| processFuncOp(funcOp, module, la, deadVals, finalCleanupList); | ||||||||||||
| processFuncOp(funcOp, module, *la, deadVals, finalCleanupList); | ||||||||||||
| } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) { | ||||||||||||
| processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList); | ||||||||||||
| processRegionBranchOp(regionBranchOp, *la, deadVals, finalCleanupList); | ||||||||||||
| } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { | ||||||||||||
| processBranchOp(branchOp, la, deadVals, finalCleanupList); | ||||||||||||
| processBranchOp(branchOp, *la, deadVals, finalCleanupList); | ||||||||||||
| } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { | ||||||||||||
| // Nothing to do here because this is a terminator op and it should be | ||||||||||||
| // honored with respect to its parent | ||||||||||||
| } else if (isa<CallOpInterface>(op)) { | ||||||||||||
| // Nothing to do because this op is associated with a function op and gets | ||||||||||||
| // cleaned when the latter is cleaned. | ||||||||||||
| } else { | ||||||||||||
| processSimpleOp(op, la, deadVals, finalCleanupList); | ||||||||||||
| processSimpleOp(op, *la, deadVals, finalCleanupList); | ||||||||||||
| } | ||||||||||||
| }); | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -571,6 +571,54 @@ module @return_void_with_unused_argument { | |
| } | ||
| } | ||
|
|
||
| // Check that public functions with non-live arguments are handled correctly. | ||
| module @public_function_with_nonlive_arguments { | ||
| // the function signature is immutable because it is public. | ||
| func.func public @public_fn_with_unused_argument(%unused: i32) -> () { | ||
| return | ||
| } | ||
| // CHECK-LABEL: func.func @main | ||
| // CHECK: call @__public_fn_with_unused_argument_privatized() : () -> () | ||
| func.func @main() -> () { | ||
| %zero = arith.constant 0 : i32 | ||
| call @public_fn_with_unused_argument(%zero) : (i32) -> () | ||
| return | ||
| } | ||
|
|
||
| // CHECK-LABEL: func.func @main2 | ||
| // CHECK: call @__public_fn_with_unused_argument_privatized() : () -> () | ||
| func.func @main2(%arg0: i1) { | ||
| %0 = scf.if %arg0 -> (i32) { | ||
| %c1_i32 = arith.constant 1 : i32 | ||
| scf.yield %c1_i32 : i32 | ||
| } else { | ||
| %c0_i32 = arith.constant 0 : i32 | ||
| scf.yield %c0_i32 : i32 | ||
| } | ||
|
|
||
| call @public_fn_with_unused_argument(%0) : (i32) -> () | ||
| return | ||
| } | ||
|
|
||
| func.func public @fn_return_multiple(%arg0: i32) -> (i32, i32, i32) { | ||
| %one = arith.constant 1 : i32 | ||
| %two = arith.constant 2 : i32 | ||
| %three = arith.constant 4 : i32 | ||
|
|
||
| return %one, %two, %three: i32, i32, i32 | ||
| } | ||
|
|
||
| // CHECK-LABEL: func.func @main3 | ||
| // CHECK: call @__fn_return_multiple_privatized() : () -> (i32, i32, i32) | ||
| func.func @main3(%arg: i32) -> () { | ||
| %one = arith.constant 1 : i32 | ||
| %scalar = arith.addi %arg, %one: i32 | ||
|
|
||
| call @fn_return_multiple(%scalar) : (i32) -> (i32, i32, i32) | ||
| return | ||
| } | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add two CFG examples where the blocks are listed in a different order, to ensure you're not sensitive to the order the blocks are in memory? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, @joker-eph. Here is a test case only for RegionBranchOpInterface. Since %0 is live at line 11, we need to mark %0 as live at line 2. After that, we need to mark %c1_i32 at line 4 and %c0_i32 at line 7 as live as well. In other words, we need to walk function @main3 preorder + backward. I managed to fix this in propagateBackward. It pretty much redoes what the liveness analysis has already done. TBH, I don't think this is the right way to proceed. RemoveDeadValues should keep its own single responsibility. I took a step back and thought about why we ended up here. The very reason we try to propagate liveness in it is because:
How about we just introduce a new pass, 'privatize-public-function', right before 'remove-dead-values'?
Here is a demo of what this pass transforms. This is my prototype here. Do you think it's a more feasible solution? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems reasonable, but this needs to be callable from RemoveDeadValues itself (the pass can't crash itself here). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I figured out how to invalidate an analysis in the analysis manager, so I can combine them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated this branch. Could you take a look at the new implementation? |
||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: module @dynamically_unreachable | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect this is expensive: can we thread a SymbolTable somehow?