11import random
2+ from datetime import datetime
23from enum import Enum
34from http import HTTPStatus
45from typing import Any , Dict , List , Optional , Tuple
@@ -69,6 +70,25 @@ class Config:
6970 orm_mode = True
7071
7172
73+ class TreeMessageCountStats (pydantic .BaseModel ):
74+ message_tree_id : UUID
75+ state : str
76+ depth : int
77+ oldest : datetime
78+ youngest : datetime
79+ count : int
80+ goal_tree_size : int
81+
82+ @property
83+ def completed (self ) -> int :
84+ return self .count / self .goal_tree_size
85+
86+
87+ class TreeManagerStats (pydantic .BaseModel ):
88+ state_counts : dict [str , int ]
89+ message_counts : list [TreeMessageCountStats ]
90+
91+
7292class TreeManager :
7393 _all_text_labels = list (map (lambda x : x .value , protocol_schema .TextLabel ))
7494
@@ -924,6 +944,40 @@ def _insert_default_state(
924944 active = True ,
925945 )
926946
947+ def tree_counts_by_state (self ) -> dict [str , int ]:
948+ qry = self .db .query (
949+ MessageTreeState .state , func .count (MessageTreeState .message_tree_id ).label ("count" )
950+ ).group_by (MessageTreeState .state )
951+ return {x ["state" ]: x ["count" ] for x in qry }
952+
953+ def tree_message_count_stats (self , only_active : bool = True ) -> list [TreeMessageCountStats ]:
954+ qry = (
955+ self .db .query (
956+ MessageTreeState .message_tree_id ,
957+ func .max (Message .depth ).label ("depth" ),
958+ func .min (Message .created_date ).label ("oldest" ),
959+ func .max (Message .created_date ).label ("youngest" ),
960+ func .count (Message .id ).label ("count" ),
961+ MessageTreeState .goal_tree_size ,
962+ MessageTreeState .state ,
963+ )
964+ .select_from (MessageTreeState )
965+ .join (Message , MessageTreeState .message_tree_id == Message .message_tree_id )
966+ .filter (not_ (Message .deleted ))
967+ .group_by (MessageTreeState .message_tree_id )
968+ )
969+
970+ if only_active :
971+ qry .filter (MessageTreeState .active )
972+
973+ return [TreeMessageCountStats (** x ) for x in qry ]
974+
975+ def stats (self ) -> TreeManagerStats :
976+ return TreeManagerStats (
977+ state_counts = self .tree_counts_by_state (),
978+ message_counts = self .tree_message_count_stats (only_active = True ),
979+ )
980+
927981
928982if __name__ == "__main__" :
929983 from oasst_backend .api .deps import api_auth
@@ -942,7 +996,7 @@ def _insert_default_state(
942996
943997 # print("query_num_active_trees", tm.query_num_active_trees())
944998 # print("query_incomplete_rankings", tm.query_incomplete_rankings())
945- print ("query_replies_need_review" , tm .query_replies_need_review ())
999+ # print("query_replies_need_review", tm.query_replies_need_review())
9461000 # print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
9471001 # print("query_extendible_trees", tm.query_extendible_trees())
9481002 # print("query_extendible_parents", tm.query_extendible_parents())
0 commit comments