Skip to content

Commit 841c5a3

Browse files
authored
feat: expose request and token counts (#51)
1 parent 359566d commit 841c5a3

File tree

7 files changed

+117
-17
lines changed

7 files changed

+117
-17
lines changed

src/git_draft/bots/common.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,59 @@
99

1010
@dataclasses.dataclass(frozen=True)
1111
class Goal:
12+
"""Bot request"""
13+
1214
prompt: str
1315
timeout: float | None
1416

1517

16-
@dataclasses.dataclass(frozen=True)
18+
@dataclasses.dataclass
1719
class Action:
20+
"""End-of-action statistics
21+
22+
This dataclass is not frozen to allow bot implementors to populate its
23+
fields incrementally.
24+
"""
25+
1826
title: str | None = None
27+
request_count: int | None = None
28+
token_count: int | None = None
29+
30+
def increment_request_count(self, n=1, init=False) -> None:
31+
self._increment("request_count", n, init)
32+
33+
def increment_token_count(self, n, init=False) -> None:
34+
self._increment("token_count", n, init)
35+
36+
def _increment(self, attr: str, count: int, init: bool) -> None:
37+
if (value := getattr(self, attr)) is None:
38+
if not init:
39+
raise ValueError(f"Uninitialized action {attr}")
40+
setattr(self, attr, count)
41+
else:
42+
setattr(self, attr, value + count)
1943

2044

2145
class Bot:
46+
"""Code assistant bot"""
47+
2248
@classmethod
2349
def state_folder_path(cls, ensure_exists=False) -> Path:
50+
"""Returns a path unique to this bot class
51+
52+
The path can be used to store data specific to this bot implementation.
53+
For example a bot interacting with a stateful API may wish to store IDs
54+
between runs, and use this folder to do so.
55+
56+
Args:
57+
ensure_exists: Create the folder if it does not exist.
58+
"""
2459
name = qualified_class_name(cls)
2560
path = ensure_state_home() / "bots" / name
2661
if ensure_exists:
2762
path.mkdir(parents=True, exist_ok=True)
2863
return path
2964

3065
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
66+
"""Runs the bot, striving to achieve the goal with the given toolbox"""
3167
raise NotImplementedError()

src/git_draft/bots/openai.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
195195
{"role": "user", "content": goal.prompt},
196196
]
197197

198+
request_count = 0
198199
while True:
199200
response = self._client.chat.completions.create(
200201
model=self._model,
@@ -203,6 +204,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
203204
tool_choice="required",
204205
)
205206
assert len(response.choices) == 1
207+
request_count += 1
206208

207209
done = True
208210
calls = response.choices[0].message.tool_calls
@@ -214,7 +216,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
214216
if done:
215217
break
216218

217-
return Action()
219+
return Action(request_count=request_count)
218220

219221

220222
class _CompletionsToolHandler(_ToolHandler[str | None]):
@@ -262,41 +264,58 @@ def create(cls, client: openai.OpenAI, model: str) -> Self:
262264
return cls(client, assistant_id)
263265

264266
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
265-
# TODO: Use timeout.
266267
thread = self._client.beta.threads.create()
267-
268268
self._client.beta.threads.messages.create(
269269
thread_id=thread.id,
270270
role="user",
271271
content=goal.prompt,
272272
)
273273

274+
# We intentionally do not count the two requests above, to focus on
275+
# "data requests" only.
276+
action = Action(request_count=0, token_count=0)
274277
with self._client.beta.threads.runs.stream(
275278
thread_id=thread.id,
276279
assistant_id=self._assistant_id,
277-
event_handler=_EventHandler(self._client, toolbox),
280+
event_handler=_EventHandler(self._client, toolbox, action),
278281
) as stream:
279282
stream.until_done()
280-
281-
return Action()
283+
return action
282284

283285

284286
class _EventHandler(openai.AssistantEventHandler):
285-
def __init__(self, client: openai.Client, toolbox: Toolbox) -> None:
287+
def __init__(
288+
self, client: openai.Client, toolbox: Toolbox, action: Action
289+
) -> None:
286290
super().__init__()
287291
self._client = client
288292
self._toolbox = toolbox
293+
self._action = action
294+
self._action.increment_request_count()
289295

290-
def clone(self) -> Self:
291-
return self.__class__(self._client, self._toolbox)
296+
def _clone(self) -> Self:
297+
return self.__class__(self._client, self._toolbox, self._action)
292298

293299
@override
294-
def on_event(self, event: Any) -> None:
295-
_logger.debug("Event: %s", event)
300+
def on_event(self, event: openai.types.beta.AssistantStreamEvent) -> None:
296301
if event.event == "thread.run.requires_action":
297302
run_id = event.data.id # Retrieve the run ID from the event data
298303
self._handle_action(run_id, event.data)
299-
# TODO: Handle (log?) other events.
304+
elif event.event == "thread.run.completed":
305+
_logger.info("Threads run completed. [usage=%s]", event.data.usage)
306+
else:
307+
_logger.debug("Threads event: %s", event)
308+
309+
@override
310+
def on_run_step_done(
311+
self, run_step: openai.types.beta.threads.runs.RunStep
312+
) -> None:
313+
usage = run_step.usage
314+
if usage:
315+
_logger.debug("Threads run step usage: %s", usage)
316+
self._action.increment_token_count(usage.total_tokens)
317+
else:
318+
_logger.warning("Missing usage in threads run step")
300319

301320
def _handle_action(self, run_id: str, data: Any) -> None:
302321
tool_outputs = list[Any]()
@@ -310,7 +329,7 @@ def _handle_action(self, run_id: str, data: Any) -> None:
310329
thread_id=run.thread_id,
311330
run_id=run.id,
312331
tool_outputs=tool_outputs,
313-
event_handler=self.clone(),
332+
event_handler=self._clone(),
314333
) as stream:
315334
stream.until_done()
316335

src/git_draft/drafter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def generate_draft(
8282
sync: bool = False,
8383
timeout: float | None = None,
8484
) -> str:
85+
if timeout is not None:
86+
raise NotImplementedError() # TODO
87+
8588
if self._repo.is_dirty(working_tree=False):
8689
if not reset:
8790
raise ValueError("Please commit or reset any staged changes")
@@ -148,6 +151,8 @@ def generate_draft(
148151
"bot_name": bot_name,
149152
"bot_class": qualified_class_name(bot.__class__),
150153
"walltime": walltime,
154+
"request_count": action.request_count,
155+
"token_count": action.token_count,
151156
},
152157
)
153158
cursor.executemany(
Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,16 @@
1-
insert into actions (commit_sha, prompt_id, bot_name, bot_class, walltime)
2-
values (:commit_sha, :prompt_id, :bot_name, :bot_class, :walltime);
1+
insert into actions (
2+
commit_sha,
3+
prompt_id,
4+
bot_name,
5+
bot_class,
6+
walltime,
7+
request_count,
8+
token_count)
9+
values (
10+
:commit_sha,
11+
:prompt_id,
12+
:bot_name,
13+
:bot_class,
14+
:walltime,
15+
:request_count,
16+
:token_count);

src/git_draft/queries/create-tables.sql

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ create table if not exists actions (
2323
bot_name text,
2424
bot_class text not null,
2525
walltime real not null,
26+
request_count int,
27+
token_count int,
2628
foreign key (prompt_id) references prompts(id) on delete cascade
2729
) without rowid;
2830

src/git_draft/queries/list-drafts.sql

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ select
22
datetime(min(p.created_at), 'localtime') as created,
33
'draft/' || b.suffix as branch,
44
min(b.origin_branch) as origin,
5-
count(p.id) as prompts
5+
count(p.id) as prompts,
6+
sum(a.token_count) as tokens
67
from branches as b
78
join prompts as p on b.suffix = p.branch_suffix
9+
join actions as a on p.id = a.prompt_id
810
where b.repo_path = :repo_path
911
group by b.suffix
1012
order by created desc;

tests/git_draft/bots/common_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
import git_draft.bots.common as sut
24

35

@@ -8,3 +10,23 @@ class FakeBot(sut.Bot):
810
class TestBot:
911
def test_state_folder_path(self) -> None:
1012
assert "bots.common_test.FakeBot" in str(FakeBot.state_folder_path())
13+
14+
15+
class TestAction:
16+
def test_increment_noinit(self) -> None:
17+
action = sut.Action()
18+
with pytest.raises(ValueError):
19+
action.increment_request_count()
20+
21+
def test_increment_request_count(self) -> None:
22+
action = sut.Action()
23+
action.increment_request_count(init=True)
24+
assert action.request_count == 1
25+
action.increment_request_count()
26+
assert action.request_count == 2
27+
28+
def test_increment_token_count(self) -> None:
29+
action = sut.Action()
30+
action.increment_token_count(5, init=True)
31+
action.increment_token_count(3)
32+
assert action.token_count == 8

0 commit comments

Comments
 (0)