Skip to content

Commit 69e9c62

Browse files
committed
fixup! 3e59ac4
1 parent 3e59ac4 commit 69e9c62

File tree

8 files changed

+53
-39
lines changed

8 files changed

+53
-39
lines changed

docs/git-draft.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ Otherwise if no template is specified and stdin is a TTY, `$EDITOR` will be open
8989
--quit::
9090
Go back to the draft's origin branch, keeping the working directory's current state.
9191
This will delete the draft branch and its upstream.
92-
Generated commits and the draft branch's final state remain available via `ref/drafts`.
92+
Generated commits and the draft branch's final state remain available via `refs/drafts`.
9393

9494
-T::
9595
--templates::

src/git_draft/bots/openai.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,11 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]:
173173

174174

175175
class _ToolHandler[V]:
176-
def __init__(self, tree: Worktree) -> None:
176+
def __init__(self, tree: Worktree, feedback: UserFeedback) -> None:
177177
self._tree = tree
178-
self.question: str | None = None
178+
self._feedback = feedback
179179

180-
def _on_ask_user(self) -> V:
180+
def _on_ask_user(self, response: str) -> V:
181181
raise NotImplementedError()
182182

183183
def _on_read_file(self, path: PurePosixPath, contents: str | None) -> V:
@@ -202,9 +202,9 @@ def handle_function(self, function: Any) -> V:
202202
_logger.info("Requested function: %s", function)
203203
match function.name:
204204
case "ask_user":
205-
assert not self.question
206-
self.question = inputs["question"]
207-
return self._on_ask_user()
205+
question = inputs["question"]
206+
response = self._feedback.ask(question)
207+
return self._on_ask_user(response)
208208
case "read_file":
209209
path = PurePosixPath(inputs["path"])
210210
return self._on_read_file(path, self._tree.read_file(path))
@@ -235,10 +235,10 @@ def __init__(self, client: openai.OpenAI, model: str) -> None:
235235
self._model = model
236236

237237
async def act(
238-
self, goal: Goal, tree: Worktree, _feedback: UserFeedback
238+
self, goal: Goal, tree: Worktree, feedback: UserFeedback
239239
) -> Action:
240240
tools = _ToolsFactory(strict=False).params()
241-
tool_handler = _CompletionsToolHandler(tree)
241+
tool_handler = _CompletionsToolHandler(tree, feedback)
242242

243243
messages: list[openai.types.chat.ChatCompletionMessageParam] = [
244244
{"role": "system", "content": reindent(_INSTRUCTIONS)},
@@ -266,15 +266,12 @@ async def act(
266266
if done:
267267
break
268268

269-
return Action(
270-
request_count=request_count,
271-
question=tool_handler.question,
272-
)
269+
return Action(request_count=request_count)
273270

274271

275272
class _CompletionsToolHandler(_ToolHandler[str | None]):
276-
def _on_ask_user(self) -> None:
277-
return None
273+
def _on_ask_user(self, response: str) -> str:
274+
return response
278275

279276
def _on_read_file(self, path: PurePosixPath, contents: str | None) -> str:
280277
if contents is None:
@@ -321,7 +318,7 @@ def _load_assistant_id(self) -> str:
321318
return assistant_id
322319

323320
async def act(
324-
self, goal: Goal, tree: Worktree, _feedback: UserFeedback
321+
self, goal: Goal, tree: Worktree, feedback: UserFeedback
325322
) -> Action:
326323
assistant_id = self._load_assistant_id()
327324

@@ -338,24 +335,29 @@ async def act(
338335
with self._client.beta.threads.runs.stream(
339336
thread_id=thread.id,
340337
assistant_id=assistant_id,
341-
event_handler=_EventHandler(self._client, tree, action),
338+
event_handler=_EventHandler(self._client, tree, feedback, action),
342339
) as stream:
343340
stream.until_done()
344341
return action
345342

346343

347344
class _EventHandler(openai.AssistantEventHandler):
348345
def __init__(
349-
self, client: openai.Client, tree: Worktree, action: Action
346+
self, client: openai.Client, tree: Worktree,
347+
feedback: UserFeedback,
348+
action: Action,
350349
) -> None:
351350
super().__init__()
352351
self._client = client
353352
self._tree = tree
353+
self._feedback = feedback
354354
self._action = action
355355
self._action.increment_request_count()
356356

357357
def _clone(self) -> Self:
358-
return self.__class__(self._client, self._tree, self._action)
358+
return self.__class__(
359+
self._client, self._tree, self._feedback, self._action
360+
)
359361

360362
@override
361363
def on_event(self, event: openai.types.beta.AssistantStreamEvent) -> None:
@@ -381,11 +383,8 @@ def on_run_step_done(
381383
def _handle_action(self, _run_id: str, data: Any) -> None:
382384
tool_outputs = list[Any]()
383385
for tool in data.required_action.submit_tool_outputs.tool_calls:
384-
handler = _ThreadToolHandler(self._tree, tool.id)
386+
handler = _ThreadToolHandler(self._tree, self._feedback, tool.id)
385387
tool_outputs.append(handler.handle_function(tool.function))
386-
if handler.question:
387-
assert not self._action.question
388-
self._action.question = handler.question
389388

390389
run = self.current_run
391390
assert run, "No ongoing run"
@@ -404,15 +403,17 @@ class _ToolOutput(TypedDict):
404403

405404

406405
class _ThreadToolHandler(_ToolHandler[_ToolOutput]):
407-
def __init__(self, tree: Worktree, call_id: str) -> None:
408-
super().__init__(tree)
406+
def __init__(
407+
self, tree: Worktree, feedback: UserFeedback, call_id: str
408+
) -> None:
409+
super().__init__(tree, feedback)
409410
self._call_id = call_id
410411

411412
def _wrap(self, output: str) -> _ToolOutput:
412413
return _ToolOutput(tool_call_id=self._call_id, output=output)
413414

414-
def _on_ask_user(self) -> _ToolOutput:
415-
return self._wrap("OK")
415+
def _on_ask_user(self, response: str) -> _ToolOutput:
416+
return self._wrap(response)
416417

417418
def _on_read_file(
418419
self, _path: PurePosixPath, contents: str | None

src/git_draft/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,12 @@ def from_cursor(cls, cursor: sqlite3.Cursor) -> Self:
140140

141141

142142
def _tagged(text: str, /, **kwargs) -> str:
143-
tags = [f"{key}={val}" for key, val in kwargs.items() if val is not None]
144-
return f"{text} [{', '.join(tags)}]" if tags else text
143+
if kwargs:
144+
tags = [
145+
f"{key}={val}" for key, val in kwargs.items() if val is not None
146+
]
147+
text = f"{text} [{', '.join(tags)}]" if tags else text
148+
return reindent(text)
145149

146150

147151
class Progress:

src/git_draft/drafter.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Draft:
3232
folio: Folio
3333
seqno: int
3434
is_noop: bool
35-
has_question: bool
35+
has_pending_question: bool
3636
walltime: timedelta
3737
token_count: int | None
3838

@@ -154,8 +154,6 @@ async def generate_draft(
154154
tree.with_hooks(operation_recorder),
155155
feedback,
156156
)
157-
if change.action.question:
158-
self._progress.report("Requested feedback.")
159157
spinner.update(
160158
"Completed bot run.",
161159
runtime=round(change.walltime.total_seconds(), 1),
@@ -167,7 +165,7 @@ async def generate_draft(
167165
folio=folio,
168166
seqno=seqno,
169167
is_noop=change.is_noop,
170-
has_question=change.action.question is not None,
168+
has_pending_question=change.action.question is not None,
171169
walltime=change.walltime,
172170
token_count=change.action.token_count,
173171
)
@@ -197,7 +195,7 @@ async def generate_draft(
197195
"walltime_seconds": change.walltime.total_seconds(),
198196
"request_count": change.action.request_count,
199197
"token_count": change.action.token_count,
200-
"question": change.action.question,
198+
"pending_question": change.action.question,
201199
},
202200
)
203201
cursor.executemany(

src/git_draft/queries/add-action.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ insert into actions (
44
walltime_seconds,
55
request_count,
66
token_count,
7-
question)
7+
pending_question)
88
values (
99
:prompt_id,
1010
:bot_class,
1111
:walltime_seconds,
1212
:request_count,
1313
:token_count,
14-
:question);
14+
:pending_question);

src/git_draft/queries/create-tables.sql

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,21 @@ create table if not exists actions (
2626
walltime_seconds real not null,
2727
request_count int,
2828
token_count int,
29-
question text,
29+
pending_question text,
3030
foreign key (prompt_id) references prompts (id) on delete cascade
3131
) without rowid;
3232

33+
create table if not exists notifications (
34+
id integer primary key,
35+
prompt_id integer not null,
36+
created_at timestamp default current_timestamp,
37+
status text,
38+
question text,
39+
answer text,
40+
foreign key (prompt_id) references actions (prompt_id) on delete cascade,
41+
check ((status is null != question is null) and (answer is null or question is not null))
42+
);
43+
3344
create table if not exists operations (
3445
id integer primary key,
3546
prompt_id integer not null,

src/git_draft/queries/get-latest-folio-prompt.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
select p.contents, a.question
1+
select p.contents, a.pending_question
22
from prompts as p
33
join folios as f on p.folio_id = f.id
44
left join actions as a on p.id = a.prompt_id

src/git_draft/store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class Store:
2020
"""Lightweight sqlite wrapper"""
2121

22-
_name = "v4.sqlite3"
22+
_name = "v5.sqlite3"
2323

2424
def __init__(self, conn: sqlite3.Connection) -> None:
2525
self._connection = conn

0 commit comments

Comments
 (0)