@@ -804,26 +804,51 @@ static BranchInst *getExpectedExitLoopLatchBranch(Loop *L) {
804
804
return LatchBR;
805
805
}
806
806
807
- // / Return the estimated trip count for any exiting branch which dominates
808
- // / the loop latch.
809
- static std::optional<unsigned > getEstimatedTripCount (BranchInst *ExitingBranch,
810
- Loop *L,
811
- uint64_t &OrigExitWeight) {
807
+ struct DbgLoop {
808
+ const Loop *L;
809
+ explicit DbgLoop (const Loop *L) : L(L) {}
810
+ };
811
+
812
+ #ifndef NDEBUG
813
+ static inline raw_ostream &operator <<(raw_ostream &OS, DbgLoop D) {
814
+ OS << " function " ;
815
+ D.L ->getHeader ()->getParent ()->printAsOperand (OS, /* PrintType=*/ false );
816
+ return OS << " " << *D.L ;
817
+ }
818
+ #endif // NDEBUG
819
+
820
+ static std::optional<unsigned > estimateLoopTripCount (Loop *L) {
821
+ // Currently we take the estimate exit count only from the loop latch,
822
+ // ignoring other exiting blocks. This can overestimate the trip count
823
+ // if we exit through another exit, but can never underestimate it.
824
+ // TODO: incorporate information from other exits
825
+ BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch (L);
826
+ if (!ExitingBranch) {
827
+ LLVM_DEBUG (dbgs () << " estimateLoopTripCount: Failed to find exiting "
828
+ << " latch branch of required form in " << DbgLoop (L)
829
+ << " \n " );
830
+ return std::nullopt ;
831
+ }
832
+
812
833
// To estimate the number of times the loop body was executed, we want to
813
834
// know the number of times the backedge was taken, vs. the number of times
814
835
// we exited the loop.
815
836
uint64_t LoopWeight, ExitWeight;
816
- if (!extractBranchWeights (*ExitingBranch, LoopWeight, ExitWeight))
837
+ if (!extractBranchWeights (*ExitingBranch, LoopWeight, ExitWeight)) {
838
+ LLVM_DEBUG (dbgs () << " estimateLoopTripCount: Failed to extract branch "
839
+ << " weights for " << DbgLoop (L) << " \n " );
817
840
return std::nullopt ;
841
+ }
818
842
819
843
if (L->contains (ExitingBranch->getSuccessor (1 )))
820
844
std::swap (LoopWeight, ExitWeight);
821
845
822
- if (!ExitWeight)
846
+ if (!ExitWeight) {
823
847
// Don't have a way to return predicated infinite
848
+ LLVM_DEBUG (dbgs () << " estimateLoopTripCount: Failed because of zero exit "
849
+ << " probability for " << DbgLoop (L) << " \n " );
824
850
return std::nullopt ;
825
-
826
- OrigExitWeight = ExitWeight;
851
+ }
827
852
828
853
// Estimated exit count is a ratio of the loop weight by the weight of the
829
854
// edge exiting the loop, rounded to nearest.
@@ -834,43 +859,102 @@ static std::optional<unsigned> getEstimatedTripCount(BranchInst *ExitingBranch,
834
859
return std::numeric_limits<unsigned >::max ();
835
860
836
861
// Estimated trip count is one plus estimated exit count.
837
- return ExitCount + 1 ;
862
+ uint64_t TC = ExitCount + 1 ;
863
+ LLVM_DEBUG (dbgs () << " estimateLoopTripCount: Estimated trip count of " << TC
864
+ << " for " << DbgLoop (L) << " \n " );
865
+ return TC;
838
866
}
839
867
840
868
std::optional<unsigned >
841
869
llvm::getLoopEstimatedTripCount (Loop *L,
842
870
unsigned *EstimatedLoopInvocationWeight) {
843
- // Currently we take the estimate exit count only from the loop latch,
844
- // ignoring other exiting blocks. This can overestimate the trip count
845
- // if we exit through another exit, but can never underestimate it.
846
- // TODO: incorporate information from other exits
847
- if (BranchInst *LatchBranch = getExpectedExitLoopLatchBranch (L)) {
848
- uint64_t ExitWeight;
849
- if (std::optional<uint64_t > EstTripCount =
850
- getEstimatedTripCount (LatchBranch, L, ExitWeight)) {
851
- if (EstimatedLoopInvocationWeight)
852
- *EstimatedLoopInvocationWeight = ExitWeight;
853
- return *EstTripCount;
854
- }
871
+ // If EstimatedLoopInvocationWeight, we do not support this loop if
872
+ // getExpectedExitLoopLatchBranch returns nullptr.
873
+ //
874
+ // FIXME: Also, this is a stop-gap solution for nested loops. It avoids
875
+ // mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when
876
+ // it was created for an inner loop. The problem is that loop metadata is
877
+ // attached to the branch instruction in the loop latch block, but that can be
878
+ // shared by the loops. A solution is to attach loop metadata to loop headers
879
+ // instead, but that would be a large change to LLVM.
880
+ //
881
+ // Until that happens, we work around the problem as follows.
882
+ // getExpectedExitLoopLatchBranch (which also guards
883
+ // setLoopEstimatedTripCount) returns nullptr for a loop unless the loop has
884
+ // one latch and that latch has exactly two successors one of which is an exit
885
+ // from the loop. If the latch is shared by nested loops, then that condition
886
+ // might hold for the inner loop but cannot hold for the outer loop:
887
+ // - Because the latch is shared, it must have at least two successors: the
888
+ // inner loop header and the outer loop header, which is also an exit for
889
+ // the inner loop. That satisifies the condition for the inner loop.
890
+ // - To satsify the condition for the outer loop, the latch must have a third
891
+ // successor that is an exit for the outer loop. But that violates the
892
+ // condition for both loops.
893
+ BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch (L);
894
+ if (!ExitingBranch)
895
+ return std::nullopt ;
896
+
897
+ // If requested, either compute *EstimatedLoopInvocationWeight or return
898
+ // nullopt if cannot.
899
+ //
900
+ // TODO: Eventually, once all passes have migrated away from setting branch
901
+ // weights to indicate estimated trip counts, this function will drop the
902
+ // EstimatedLoopInvocationWeight parameter.
903
+ if (EstimatedLoopInvocationWeight) {
904
+ uint64_t LoopWeight = 0 , ExitWeight = 0 ; // Inits expected to be unused.
905
+ if (!extractBranchWeights (*ExitingBranch, LoopWeight, ExitWeight))
906
+ return std::nullopt ;
907
+ if (L->contains (ExitingBranch->getSuccessor (1 )))
908
+ std::swap (LoopWeight, ExitWeight);
909
+ if (!ExitWeight)
910
+ return std::nullopt ;
911
+ *EstimatedLoopInvocationWeight = ExitWeight;
855
912
}
856
- return std::nullopt ;
913
+
914
+ // Return the estimated trip count from metadata unless the metadata is
915
+ // missing or has no value.
916
+ if (auto TC = getOptionalIntLoopAttribute (L, LLVMLoopEstimatedTripCount)) {
917
+ LLVM_DEBUG (dbgs () << " getLoopEstimatedTripCount: "
918
+ << LLVMLoopEstimatedTripCount << " metadata has trip "
919
+ << " count of " << *TC << " for " << DbgLoop (L) << " \n " );
920
+ return TC;
921
+ }
922
+
923
+ // Estimate the trip count from latch branch weights.
924
+ return estimateLoopTripCount (L);
857
925
}
858
926
859
- bool llvm::setLoopEstimatedTripCount (Loop *L, unsigned EstimatedTripCount,
860
- unsigned EstimatedloopInvocationWeight) {
861
- // At the moment, we currently support changing the estimate trip count of
862
- // the latch branch only. We could extend this API to manipulate estimated
863
- // trip counts for any exit.
927
+ bool llvm::setLoopEstimatedTripCount (
928
+ Loop *L, unsigned EstimatedTripCount,
929
+ std::optional<unsigned > EstimatedloopInvocationWeight) {
930
+ // If EstimatedLoopInvocationWeight, we do not support this loop if
931
+ // getExpectedExitLoopLatchBranch returns nullptr.
932
+ //
933
+ // FIXME: See comments in getLoopEstimatedTripCount for why this is required
934
+ // here regardless of EstimatedLoopInvocationWeight.
864
935
BranchInst *LatchBranch = getExpectedExitLoopLatchBranch (L);
865
936
if (!LatchBranch)
866
937
return false ;
867
938
939
+ // Set the metadata.
940
+ addStringMetadataToLoop (L, LLVMLoopEstimatedTripCount, EstimatedTripCount);
941
+
942
+ // At the moment, we currently support changing the estimated trip count in
943
+ // the latch branch's branch weights only. We could extend this API to
944
+ // manipulate estimated trip counts for any exit.
945
+ //
946
+ // TODO: Eventually, once all passes have migrated away from setting branch
947
+ // weights to indicate estimated trip counts, we will not set branch weights
948
+ // here at all.
949
+ if (!EstimatedloopInvocationWeight)
950
+ return true ;
951
+
868
952
// Calculate taken and exit weights.
869
953
unsigned LatchExitWeight = 0 ;
870
954
unsigned BackedgeTakenWeight = 0 ;
871
955
872
- if (EstimatedTripCount > 0 ) {
873
- LatchExitWeight = EstimatedloopInvocationWeight;
956
+ if (EstimatedTripCount != 0 ) {
957
+ LatchExitWeight = * EstimatedloopInvocationWeight;
874
958
BackedgeTakenWeight = (EstimatedTripCount - 1 ) * LatchExitWeight;
875
959
}
876
960
0 commit comments