@@ -61,6 +61,9 @@ static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
6161static cl::opt<std::string> ModelSelector (" ml-inliner-model-selector" ,
6262 cl::Hidden, cl::init(" " ));
6363
64+ static cl::opt<bool > StopImmediatelyForTest (" ml-inliner-stop-immediately" ,
65+ cl::Hidden);
66+
6467#if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
6568// codegen-ed file
6669#include " InlinerSizeModel.h" // NOLINT
@@ -214,6 +217,7 @@ MLInlineAdvisor::MLInlineAdvisor(
214217 return ;
215218 }
216219 ModelRunner->switchContext (" " );
220+ ForceStop = StopImmediatelyForTest;
217221}
218222
219223unsigned MLInlineAdvisor::getInitialFunctionLevel (const Function &F) const {
@@ -379,9 +383,17 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
379383 auto &ORE = FAM.getResult <OptimizationRemarkEmitterAnalysis>(Caller);
380384
381385 if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
382- if (!PSI.isFunctionEntryCold (&Caller))
383- return std::make_unique<InlineAdvice>(this , CB, ORE,
384- GetDefaultAdvice (CB));
386+ if (!PSI.isFunctionEntryCold (&Caller)) {
387+ // Return a MLInlineAdvice, despite delegating to the default advice,
388+ // because we need to keep track of the internal state. This is different
389+ // from the other instances where we return a "default" InlineAdvice,
390+ // which happen at points we won't come back to the MLAdvisor for
391+ // decisions requiring that state.
392+ return ForceStop ? std::make_unique<InlineAdvice>(this , CB, ORE,
393+ GetDefaultAdvice (CB))
394+ : std::make_unique<MLInlineAdvice>(this , CB, ORE,
395+ GetDefaultAdvice (CB));
396+ }
385397 }
386398 auto MandatoryKind = InlineAdvisor::getMandatoryKind (CB, FAM, ORE);
387399 // If this is a "never inline" case, there won't be any changes to internal
0 commit comments