[FunctionSpecialization] Preserve call counts of specialized functions #157768
A function that has been specialized will have its function entry counts preserved as follows:

* Each specialization's count is the sum, over its call sites, of the entry count of the basic block containing the call site, as computed by `BlockFrequencyInfo`.
* The original function's count is decreased by the counts of its specializations.

Tracking issue: llvm#147390
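The two rules above can be sketched as standalone arithmetic. This is a minimal illustration, not the pass's code: `SpecializationSketch`, `specializationCount`, and `remainingOriginalCount` are hypothetical names, and the block counts are passed in directly rather than queried from `BlockFrequencyInfo`. The clamp to zero mirrors the `OriginalCount > CallCount` check in the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one specialization: the profile counts of the
// basic blocks that contain the call sites redirected to it.
struct SpecializationSketch {
  std::vector<uint64_t> CallSiteBlockCounts;
};

// Rule 1: a specialization's entry count is the sum of its call sites'
// block counts.
uint64_t specializationCount(const SpecializationSketch &S) {
  uint64_t Sum = 0;
  for (uint64_t C : S.CallSiteBlockCounts) {
    assert(Sum <= UINT64_MAX - C && "profile count overflow");
    Sum += C;
  }
  return Sum;
}

// Rule 2: the original function keeps whatever is left after subtracting
// every specialization's count, clamped at zero so it never wraps around.
uint64_t remainingOriginalCount(uint64_t OriginalCount,
                                const std::vector<SpecializationSketch> &Specs) {
  uint64_t Remaining = OriginalCount;
  for (const SpecializationSketch &S : Specs) {
    uint64_t C = specializationCount(S);
    Remaining = Remaining > C ? Remaining - C : 0;
  }
  return Remaining;
}
```

With the numbers from the test in this PR, `@bar` has an entry count of 1000 and branch weights of 1:3, so the two call-site blocks have counts 250 and 750. Feeding those in as two single-call-site specializations, with 1234 as the original count of `@foo`, leaves 234 on the original function, which matches the `CHECK` lines.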
@llvm/pr-subscribers-function-specialization
@llvm/pr-subscribers-llvm-transforms

Author: Alan Zhao (alanzhao1)

Full diff: https://github.com/llvm/llvm-project/pull/157768.diff

2 Files Affected:
 diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index a459a9eddbcfc..324723c7942ab 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -784,9 +784,25 @@ bool FunctionSpecializer::run() {
 
     // Update the known call sites to call the clone.
     for (CallBase *Call : S.CallSites) {
+      Function *Clone = S.Clone;
       LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call
-                        << " to call " << S.Clone->getName() << "\n");
+                        << " to call " << Clone->getName() << "\n");
       Call->setCalledFunction(S.Clone);
+      if (std::optional<uint64_t> Count =
+              GetBFI(*Call->getFunction())
+                  .getBlockProfileCount(Call->getParent())) {
+        uint64_t CallCount = *Count + Clone->getEntryCount()->getCount();
+        Clone->setEntryCount(CallCount);
+        if (std::optional<llvm::Function::ProfileCount> MaybeOriginalCount =
+                S.F->getEntryCount()) {
+          uint64_t OriginalCount = MaybeOriginalCount->getCount();
+          if (OriginalCount > CallCount) {
+            S.F->setEntryCount(OriginalCount - CallCount);
+          } else {
+            S.F->setEntryCount(0);
+          }
+        }
+      }
     }
 
     Clones.push_back(S.Clone);
@@ -1043,6 +1059,9 @@ Function *FunctionSpecializer::createSpecialization(Function *F,
   // clone must.
   Clone->setLinkage(GlobalValue::InternalLinkage);
 
+  if (F->getEntryCount())
+    Clone->setEntryCount(0);
+
   // Initialize the lattice state of the arguments of the function clone,
   // marking the argument on which we specialized the function constant
   // with the given value.
diff --git a/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll b/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll
new file mode 100644
index 0000000000000..4a2ad4ff9fe90
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll
@@ -0,0 +1,52 @@
+; RUN: opt -passes="ipsccp<func-spec>" -force-specialization -S < %s | FileCheck %s
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+
+@A = external dso_local constant i32, align 4
+@B = external dso_local constant i32, align 4
+
+; CHECK: define dso_local i32 @bar(i32 %x, i32 %y, ptr %z) !prof ![[BAR_PROF:[0-9]]] {
+define dso_local i32 @bar(i32 %x, i32 %y, ptr %z) !prof !0 {
+entry:
+  %tobool = icmp ne i32 %x, 0
+; CHECK: br i1 %tobool, label %if.then, label %if.else, !prof ![[BRANCH_PROF:[0-9]]]
+  br i1 %tobool, label %if.then, label %if.else, !prof !1
+
+if.then:
+; CHECK: if.then:
+; CHECK: call i32 @foo.specialized.1(i32 %x, ptr @A)
+  %call = call i32 @foo(i32 %x, ptr @A)
+  br label %return
+
+if.else:
+; CHECK: if.else:
+; CHECK: call i32 @foo.specialized.2(i32 %y, ptr @B)
+  %call1 = call i32 @foo(i32 %y, ptr @B)
+  br label %return
+
+; CHECK: return:
+; CHECK: %call2 = call i32 @foo(i32 %x, ptr %z)
+return:
+  %retval.0 = phi i32 [ %call, %if.then ], [ %call1, %if.else ]
+  %call2 = call i32 @foo(i32 %x, ptr %z);
+  %add = add i32 %retval.0, %call2
+  ret i32 %add
+}
+
+; CHECK: define internal i32 @foo(i32 %x, ptr %b) !prof ![[FOO_UNSPEC_PROF:[0-9]]]
+; CHECK: define internal i32 @foo.specialized.1(i32 %x, ptr %b) !prof ![[FOO_SPEC_1_PROF:[0-9]]]
+; CHECK: define internal i32 @foo.specialized.2(i32 %x, ptr %b) !prof ![[FOO_SPEC_2_PROF:[0-9]]]
+define internal i32 @foo(i32 %x, ptr %b) !prof !2 {
+entry:
+  %0 = load i32, ptr %b, align 4
+  %add = add nsw i32 %x, %0
+  ret i32 %add
+}
+
+; CHECK: ![[BAR_PROF]] = !{!"function_entry_count", i64 1000}
+; CHECK: ![[BRANCH_PROF]] = !{!"branch_weights", i32 1, i32 3}
+; CHECK: ![[FOO_UNSPEC_PROF]] =  !{!"function_entry_count", i64 234}
+; CHECK: ![[FOO_SPEC_1_PROF]] = !{!"function_entry_count", i64 250}
+; CHECK: ![[FOO_SPEC_2_PROF]] = !{!"function_entry_count", i64 750}
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 1, i32 3}
+!2 = !{!"function_entry_count", i64 1234}
just nits. thanks!
The previous fix in llvm#157768 had a bug: instead of subtracting each specialization call site's count from the original function's entry count, we were subtracting the running total of the specialization's calls.

Tracking issue: llvm#147390
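The difference only shows up when one specialization has more than one call site. The sketch below models both behaviors as plain arithmetic. It is an assumption-laden illustration: the function names are hypothetical, the first function follows the loop in the diff above, and the second reflects the behavior described in this comment rather than the actual follow-up patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Saturating subtraction, mirroring the clamp-to-zero branch in the diff.
uint64_t saturatingSub(uint64_t A, uint64_t B) { return A > B ? A - B : 0; }

// Models the loop from the diff for ONE specialization: the value subtracted
// from the original function is the clone's running total so far.
uint64_t originalCountRunningTotal(uint64_t Original,
                                   const std::vector<uint64_t> &CallSiteCounts) {
  uint64_t CloneCount = 0;
  for (uint64_t Count : CallSiteCounts) {
    assert(CloneCount <= UINT64_MAX - Count && "profile count overflow");
    uint64_t CallCount = Count + CloneCount; // running total, as in the diff
    CloneCount = CallCount;
    Original = saturatingSub(Original, CallCount);
  }
  return Original;
}

// Models the described correction: subtract only the current call site's
// count on each iteration.
uint64_t originalCountPerCallSite(uint64_t Original,
                                  const std::vector<uint64_t> &CallSiteCounts) {
  for (uint64_t Count : CallSiteCounts)
    Original = saturatingSub(Original, Count);
  return Original;
}
```

For an original count of 1000 and two call sites with counts 100 and 200, the running-total version leaves 600 (it subtracts 100, then 300), while the per-call-site version leaves 700. With a single call site the two agree, which is why the test added in this PR passes either way.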
@alanzhao1 this is causing clang crashes when building clang itself. Reduced test case: https://gcc.godbolt.org/z/h5qjdW4c9
    
… to zero We were hitting an assert discovered in #157768 (comment)
          
Should be fixed by 7d748a9
    
…count equal to zero We were hitting an assert discovered in llvm/llvm-project#157768 (comment)