Skip to content

Commit 718faa0

Browse files
authored
Add combined TreeManager stats endpoint (#816)
1 parent acaa56e commit 718faa0

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

backend/oasst_backend/api/v1/stats.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from oasst_backend.api import deps
33
from oasst_backend.models import ApiClient
44
from oasst_backend.prompt_repository import PromptRepository
5+
from oasst_backend.tree_manager import TreeManager, TreeManagerStats, TreeMessageCountStats
56
from oasst_shared.schemas import protocol
67
from sqlmodel import Session
78

@@ -15,3 +16,34 @@ def get_message_stats(
1516
):
1617
pr = PromptRepository(db, api_client)
1718
return pr.get_stats()
19+
20+
21+
@router.get("/tree_manager/state_counts", response_model=dict[str, int])
22+
def get_tree_manager__state_counts(
23+
db: Session = Depends(deps.get_db),
24+
api_client: ApiClient = Depends(deps.get_trusted_api_client),
25+
):
26+
pr = PromptRepository(db, api_client)
27+
tm = TreeManager(db, pr)
28+
return tm.tree_counts_by_state()
29+
30+
31+
@router.get("/tree_manager/message_counts", response_model=list[TreeMessageCountStats])
32+
def get_tree_manager__message_counts(
33+
only_active: bool = True,
34+
db: Session = Depends(deps.get_db),
35+
api_client: ApiClient = Depends(deps.get_trusted_api_client),
36+
):
37+
pr = PromptRepository(db, api_client)
38+
tm = TreeManager(db, pr)
39+
return tm.tree_message_count_stats(only_active=only_active)
40+
41+
42+
@router.get("/tree_manager", response_model=TreeManagerStats)
43+
def get_tree_manager__stats(
44+
db: Session = Depends(deps.get_db),
45+
api_client: ApiClient = Depends(deps.get_trusted_api_client),
46+
):
47+
pr = PromptRepository(db, api_client)
48+
tm = TreeManager(db, pr)
49+
return tm.stats()

backend/oasst_backend/tree_manager.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import random
2+
from datetime import datetime
23
from enum import Enum
34
from http import HTTPStatus
45
from typing import Any, Dict, List, Optional, Tuple
@@ -69,6 +70,25 @@ class Config:
6970
orm_mode = True
7071

7172

73+
class TreeMessageCountStats(pydantic.BaseModel):
74+
message_tree_id: UUID
75+
state: str
76+
depth: int
77+
oldest: datetime
78+
youngest: datetime
79+
count: int
80+
goal_tree_size: int
81+
82+
@property
83+
def completed(self) -> int:
84+
return self.count / self.goal_tree_size
85+
86+
87+
class TreeManagerStats(pydantic.BaseModel):
88+
state_counts: dict[str, int]
89+
message_counts: list[TreeMessageCountStats]
90+
91+
7292
class TreeManager:
7393
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
7494

@@ -924,6 +944,40 @@ def _insert_default_state(
924944
active=True,
925945
)
926946

947+
def tree_counts_by_state(self) -> dict[str, int]:
948+
qry = self.db.query(
949+
MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")
950+
).group_by(MessageTreeState.state)
951+
return {x["state"]: x["count"] for x in qry}
952+
953+
def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]:
954+
qry = (
955+
self.db.query(
956+
MessageTreeState.message_tree_id,
957+
func.max(Message.depth).label("depth"),
958+
func.min(Message.created_date).label("oldest"),
959+
func.max(Message.created_date).label("youngest"),
960+
func.count(Message.id).label("count"),
961+
MessageTreeState.goal_tree_size,
962+
MessageTreeState.state,
963+
)
964+
.select_from(MessageTreeState)
965+
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
966+
.filter(not_(Message.deleted))
967+
.group_by(MessageTreeState.message_tree_id)
968+
)
969+
970+
if only_active:
971+
qry.filter(MessageTreeState.active)
972+
973+
return [TreeMessageCountStats(**x) for x in qry]
974+
975+
def stats(self) -> TreeManagerStats:
976+
return TreeManagerStats(
977+
state_counts=self.tree_counts_by_state(),
978+
message_counts=self.tree_message_count_stats(only_active=True),
979+
)
980+
927981

928982
if __name__ == "__main__":
929983
from oasst_backend.api.deps import api_auth
@@ -942,7 +996,7 @@ def _insert_default_state(
942996

943997
# print("query_num_active_trees", tm.query_num_active_trees())
944998
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
945-
print("query_replies_need_review", tm.query_replies_need_review())
999+
# print("query_replies_need_review", tm.query_replies_need_review())
9461000
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
9471001
# print("query_extendible_trees", tm.query_extendible_trees())
9481002
# print("query_extendible_parents", tm.query_extendible_parents())

0 commit comments

Comments
 (0)