Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/git_draft/bots/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,59 @@

@dataclasses.dataclass(frozen=True)
class Goal:
"""Bot request"""

prompt: str
timeout: float | None


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class Action:
"""End-of-action statistics

This dataclass is not frozen to allow bot implementors to populate its
fields incrementally.
"""

title: str | None = None
request_count: int | None = None
token_count: int | None = None

def increment_request_count(self, n=1, init=False) -> None:
self._increment("request_count", n, init)

def increment_token_count(self, n, init=False) -> None:
self._increment("token_count", n, init)

def _increment(self, attr: str, count: int, init: bool) -> None:
if (value := getattr(self, attr)) is None:
if not init:
raise ValueError(f"Uninitialized action {attr}")
setattr(self, attr, count)
else:
setattr(self, attr, value + count)


class Bot:
"""Code assistant bot"""

@classmethod
def state_folder_path(cls, ensure_exists=False) -> Path:
"""Returns a path unique to this bot class

The path can be used to store data specific to this bot implementation.
For example a bot interacting with a stateful API may wish to store IDs
between runs, and use this folder to do so.

Args:
ensure_exists: Create the folder if it does not exist.
"""
name = qualified_class_name(cls)
path = ensure_state_home() / "bots" / name
if ensure_exists:
path.mkdir(parents=True, exist_ok=True)
return path

def act(self, goal: Goal, toolbox: Toolbox) -> Action:
"""Runs the bot, striving to achieve the goal with the given toolbox"""
raise NotImplementedError()
45 changes: 32 additions & 13 deletions src/git_draft/bots/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
{"role": "user", "content": goal.prompt},
]

request_count = 0
while True:
response = self._client.chat.completions.create(
model=self._model,
Expand All @@ -203,6 +204,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
tool_choice="required",
)
assert len(response.choices) == 1
request_count += 1

done = True
calls = response.choices[0].message.tool_calls
Expand All @@ -214,7 +216,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
if done:
break

return Action()
return Action(request_count=request_count)


class _CompletionsToolHandler(_ToolHandler[str | None]):
Expand Down Expand Up @@ -262,41 +264,58 @@ def create(cls, client: openai.OpenAI, model: str) -> Self:
return cls(client, assistant_id)

def act(self, goal: Goal, toolbox: Toolbox) -> Action:
# TODO: Use timeout.
thread = self._client.beta.threads.create()

self._client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content=goal.prompt,
)

# We intentionally do not count the two requests above, to focus on
# "data requests" only.
action = Action(request_count=0, token_count=0)
with self._client.beta.threads.runs.stream(
thread_id=thread.id,
assistant_id=self._assistant_id,
event_handler=_EventHandler(self._client, toolbox),
event_handler=_EventHandler(self._client, toolbox, action),
) as stream:
stream.until_done()

return Action()
return action


class _EventHandler(openai.AssistantEventHandler):
def __init__(self, client: openai.Client, toolbox: Toolbox) -> None:
def __init__(
self, client: openai.Client, toolbox: Toolbox, action: Action
) -> None:
super().__init__()
self._client = client
self._toolbox = toolbox
self._action = action
self._action.increment_request_count()

def clone(self) -> Self:
return self.__class__(self._client, self._toolbox)
def _clone(self) -> Self:
return self.__class__(self._client, self._toolbox, self._action)

@override
def on_event(self, event: Any) -> None:
_logger.debug("Event: %s", event)
def on_event(self, event: openai.types.beta.AssistantStreamEvent) -> None:
if event.event == "thread.run.requires_action":
run_id = event.data.id # Retrieve the run ID from the event data
self._handle_action(run_id, event.data)
# TODO: Handle (log?) other events.
elif event.event == "thread.run.completed":
_logger.info("Threads run completed. [usage=%s]", event.data.usage)
else:
_logger.debug("Threads event: %s", event)

@override
def on_run_step_done(
self, run_step: openai.types.beta.threads.runs.RunStep
) -> None:
usage = run_step.usage
if usage:
_logger.debug("Threads run step usage: %s", usage)
self._action.increment_token_count(usage.total_tokens)
else:
_logger.warning("Missing usage in threads run step")

def _handle_action(self, run_id: str, data: Any) -> None:
tool_outputs = list[Any]()
Expand All @@ -310,7 +329,7 @@ def _handle_action(self, run_id: str, data: Any) -> None:
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=tool_outputs,
event_handler=self.clone(),
event_handler=self._clone(),
) as stream:
stream.until_done()

Expand Down
5 changes: 5 additions & 0 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def generate_draft(
sync: bool = False,
timeout: float | None = None,
) -> str:
if timeout is not None:
raise NotImplementedError() # TODO

if self._repo.is_dirty(working_tree=False):
if not reset:
raise ValueError("Please commit or reset any staged changes")
Expand Down Expand Up @@ -148,6 +151,8 @@ def generate_draft(
"bot_name": bot_name,
"bot_class": qualified_class_name(bot.__class__),
"walltime": walltime,
"request_count": action.request_count,
"token_count": action.token_count,
},
)
cursor.executemany(
Expand Down
18 changes: 16 additions & 2 deletions src/git_draft/queries/add-action.sql
Original file line number Diff line number Diff line change
@@ -1,2 +1,16 @@
insert into actions (commit_sha, prompt_id, bot_name, bot_class, walltime)
values (:commit_sha, :prompt_id, :bot_name, :bot_class, :walltime);
insert into actions (
commit_sha,
prompt_id,
bot_name,
bot_class,
walltime,
request_count,
token_count)
values (
:commit_sha,
:prompt_id,
:bot_name,
:bot_class,
:walltime,
:request_count,
:token_count);
2 changes: 2 additions & 0 deletions src/git_draft/queries/create-tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ create table if not exists actions (
bot_name text,
bot_class text not null,
walltime real not null,
request_count int,
token_count int,
foreign key (prompt_id) references prompts(id) on delete cascade
) without rowid;

Expand Down
4 changes: 3 additions & 1 deletion src/git_draft/queries/list-drafts.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ select
datetime(min(p.created_at), 'localtime') as created,
'draft/' || b.suffix as branch,
min(b.origin_branch) as origin,
count(p.id) as prompts
count(p.id) as prompts,
sum(a.token_count) as tokens
from branches as b
join prompts as p on b.suffix = p.branch_suffix
join actions as a on p.id = a.prompt_id
where b.repo_path = :repo_path
group by b.suffix
order by created desc;
22 changes: 22 additions & 0 deletions tests/git_draft/bots/common_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

import git_draft.bots.common as sut


Expand All @@ -8,3 +10,23 @@ class FakeBot(sut.Bot):
class TestBot:
def test_state_folder_path(self) -> None:
assert "bots.common_test.FakeBot" in str(FakeBot.state_folder_path())


class TestAction:
def test_increment_noinit(self) -> None:
action = sut.Action()
with pytest.raises(ValueError):
action.increment_request_count()

def test_increment_request_count(self) -> None:
action = sut.Action()
action.increment_request_count(init=True)
assert action.request_count == 1
action.increment_request_count()
assert action.request_count == 2

def test_increment_token_count(self) -> None:
action = sut.Action()
action.increment_token_count(5, init=True)
action.increment_token_count(3)
assert action.token_count == 8