diff --git a/src/git_draft/bots/common.py b/src/git_draft/bots/common.py index 119324a..a85e334 100644 --- a/src/git_draft/bots/common.py +++ b/src/git_draft/bots/common.py @@ -9,18 +9,53 @@ @dataclasses.dataclass(frozen=True) class Goal: + """Bot request""" + prompt: str timeout: float | None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass class Action: + """End-of-action statistics + + This dataclass is not frozen to allow bot implementors to populate its + fields incrementally. + """ + title: str | None = None + request_count: int | None = None + token_count: int | None = None + + def increment_request_count(self, n=1, init=False) -> None: + self._increment("request_count", n, init) + + def increment_token_count(self, n, init=False) -> None: + self._increment("token_count", n, init) + + def _increment(self, attr: str, count: int, init: bool) -> None: + if (value := getattr(self, attr)) is None: + if not init: + raise ValueError(f"Uninitialized action {attr}") + setattr(self, attr, count) + else: + setattr(self, attr, value + count) class Bot: + """Code assistant bot""" + @classmethod def state_folder_path(cls, ensure_exists=False) -> Path: + """Returns a path unique to this bot class + + The path can be used to store data specific to this bot implementation. + For example a bot interacting with a stateful API may wish to store IDs + between runs, and use this folder to do so. + + Args: + ensure_exists: Create the folder if it does not exist. + """ name = qualified_class_name(cls) path = ensure_state_home() / "bots" / name if ensure_exists: @@ -28,4 +63,5 @@ def state_folder_path(cls, ensure_exists=False) -> Path: return path def act(self, goal: Goal, toolbox: Toolbox) -> Action: + """Runs the bot, striving to achieve the goal with the given toolbox""" raise NotImplementedError() diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 318de13..3b8f4db 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -195,6 +195,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: {"role": "user", "content": goal.prompt}, ] + request_count = 0 while True: response = self._client.chat.completions.create( model=self._model, @@ -203,6 +204,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: tool_choice="required", ) assert len(response.choices) == 1 + request_count += 1 done = True calls = response.choices[0].message.tool_calls @@ -214,7 +216,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: if done: break - return Action() + return Action(request_count=request_count) class _CompletionsToolHandler(_ToolHandler[str | None]): @@ -262,41 +264,58 @@ def create(cls, client: openai.OpenAI, model: str) -> Self: return cls(client, assistant_id) def act(self, goal: Goal, toolbox: Toolbox) -> Action: - # TODO: Use timeout. thread = self._client.beta.threads.create() - self._client.beta.threads.messages.create( thread_id=thread.id, role="user", content=goal.prompt, ) + # We intentionally do not count the two requests above, to focus on + # "data requests" only. + action = Action(request_count=0, token_count=0) with self._client.beta.threads.runs.stream( thread_id=thread.id, assistant_id=self._assistant_id, - event_handler=_EventHandler(self._client, toolbox), + event_handler=_EventHandler(self._client, toolbox, action), ) as stream: stream.until_done() - - return Action() + return action class _EventHandler(openai.AssistantEventHandler): - def __init__(self, client: openai.Client, toolbox: Toolbox) -> None: + def __init__( + self, client: openai.Client, toolbox: Toolbox, action: Action + ) -> None: super().__init__() self._client = client self._toolbox = toolbox + self._action = action + self._action.increment_request_count() - def clone(self) -> Self: - return self.__class__(self._client, self._toolbox) + def _clone(self) -> Self: + return self.__class__(self._client, self._toolbox, self._action) @override - def on_event(self, event: Any) -> None: - _logger.debug("Event: %s", event) + def on_event(self, event: openai.types.beta.AssistantStreamEvent) -> None: if event.event == "thread.run.requires_action": run_id = event.data.id # Retrieve the run ID from the event data self._handle_action(run_id, event.data) - # TODO: Handle (log?) other events. + elif event.event == "thread.run.completed": + _logger.info("Threads run completed. [usage=%s]", event.data.usage) + else: + _logger.debug("Threads event: %s", event) + + @override + def on_run_step_done( + self, run_step: openai.types.beta.threads.runs.RunStep + ) -> None: + usage = run_step.usage + if usage: + _logger.debug("Threads run step usage: %s", usage) + self._action.increment_token_count(usage.total_tokens) + else: + _logger.warning("Missing usage in threads run step") def _handle_action(self, run_id: str, data: Any) -> None: tool_outputs = list[Any]() @@ -310,7 +329,7 @@ def _handle_action(self, run_id: str, data: Any) -> None: thread_id=run.thread_id, run_id=run.id, tool_outputs=tool_outputs, - event_handler=self.clone(), + event_handler=self._clone(), ) as stream: stream.until_done() diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 7cd78d9..bfd9660 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -82,6 +82,9 @@ def generate_draft( sync: bool = False, timeout: float | None = None, ) -> str: + if timeout is not None: + raise NotImplementedError() # TODO + if self._repo.is_dirty(working_tree=False): if not reset: raise ValueError("Please commit or reset any staged changes") @@ -148,6 +151,8 @@ def generate_draft( "bot_name": bot_name, "bot_class": qualified_class_name(bot.__class__), "walltime": walltime, + "request_count": action.request_count, + "token_count": action.token_count, }, ) cursor.executemany( diff --git a/src/git_draft/queries/add-action.sql b/src/git_draft/queries/add-action.sql index 3e58859..2614c9e 100644 --- a/src/git_draft/queries/add-action.sql +++ b/src/git_draft/queries/add-action.sql @@ -1,2 +1,16 @@ -insert into actions (commit_sha, prompt_id, bot_name, bot_class, walltime) - values (:commit_sha, :prompt_id, :bot_name, :bot_class, :walltime); +insert into actions ( + commit_sha, + prompt_id, + bot_name, + bot_class, + walltime, + request_count, + token_count) + values ( + :commit_sha, + :prompt_id, + :bot_name, + :bot_class, + :walltime, + :request_count, + :token_count); diff --git a/src/git_draft/queries/create-tables.sql b/src/git_draft/queries/create-tables.sql index c0cf2da..8c85ae9 100644 --- a/src/git_draft/queries/create-tables.sql +++ b/src/git_draft/queries/create-tables.sql @@ -23,6 +23,8 @@ create table if not exists actions ( bot_name text, bot_class text not null, walltime real not null, + request_count int, + token_count int, foreign key (prompt_id) references prompts(id) on delete cascade ) without rowid; diff --git a/src/git_draft/queries/list-drafts.sql b/src/git_draft/queries/list-drafts.sql index 3e188fb..1b0747a 100644 --- a/src/git_draft/queries/list-drafts.sql +++ b/src/git_draft/queries/list-drafts.sql @@ -2,9 +2,11 @@ select datetime(min(p.created_at), 'localtime') as created, 'draft/' || b.suffix as branch, min(b.origin_branch) as origin, - count(p.id) as prompts + count(p.id) as prompts, + sum(a.token_count) as tokens from branches as b join prompts as p on b.suffix = p.branch_suffix + join actions as a on p.id = a.prompt_id where b.repo_path = :repo_path group by b.suffix order by created desc; diff --git a/tests/git_draft/bots/common_test.py b/tests/git_draft/bots/common_test.py index 3652388..69ac56c 100644 --- a/tests/git_draft/bots/common_test.py +++ b/tests/git_draft/bots/common_test.py @@ -1,3 +1,5 @@ +import pytest + import git_draft.bots.common as sut @@ -8,3 +10,23 @@ class FakeBot(sut.Bot): class TestBot: def test_state_folder_path(self) -> None: assert "bots.common_test.FakeBot" in str(FakeBot.state_folder_path()) + + +class TestAction: + def test_increment_noinit(self) -> None: + action = sut.Action() + with pytest.raises(ValueError): + action.increment_request_count() + + def test_increment_request_count(self) -> None: + action = sut.Action() + action.increment_request_count(init=True) + assert action.request_count == 1 + action.increment_request_count() + assert action.request_count == 2 + + def test_increment_token_count(self) -> None: + action = sut.Action() + action.increment_token_count(5, init=True) + action.increment_token_count(3) + assert action.token_count == 8