Skip to content

Commit 6ef1b77

Browse files
committed
fixup! 8c6613e
1 parent 8c6613e commit 6ef1b77

File tree

14 files changed

+309
-296
lines changed

14 files changed

+309
-296
lines changed

src/git_draft/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@
22

33
import logging
44

5-
from .bots import Action, Bot, Goal
6-
from .feedback import Feedback
7-
from .toolbox import Toolbox
5+
from .bots import Action, Bot, Goal, UserFeedback, WorkTree
86

97

108
__all__ = [
119
"Action",
1210
"Bot",
13-
"Feedback",
1411
"Goal",
15-
"Toolbox",
12+
"UserFeedback",
13+
"WorkTree",
1614
]
1715

1816
logging.getLogger(__name__).addHandler(logging.NullHandler())

src/git_draft/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
)
2121
from .drafter import Drafter, DraftMergeStrategy
2222
from .editor import open_editor
23-
from .feedback import LiveFeedback
2423
from .git import Repo
2524
from .prompt import (
2625
PromptMetadata,
@@ -29,6 +28,7 @@
2928
templates_table,
3029
)
3130
from .store import Store
31+
from .user_feedbacks import InteractiveUserFeedback
3232

3333

3434
_logger = logging.getLogger(__name__)
@@ -170,7 +170,7 @@ async def run() -> None: # noqa: PLR0912 PLR0915
170170

171171
if sys.stdin.isatty():
172172
progress = Progress.dynamic()
173-
feedback = LiveFeedback()
173+
feedback = InteractiveUserFeedback()
174174
else:
175175
progress = Progress.static()
176176
feedback = None

src/git_draft/bots/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
1-
"""Bot interfaces and built-in implementations
2-
3-
* https://aider.chat/docs/leaderboards/
4-
"""
1+
"""Bot interfaces and built-in implementations"""
52

63
import importlib
74
import os
85
import sys
96

107
from ..common import BotConfig, reindent
11-
from ..toolbox import Toolbox
12-
from .common import Action, Bot, Goal
8+
from .common import Action, Bot, Goal, UserFeedback, WorkTree
139

1410

1511
__all__ = [
1612
"Action",
1713
"Bot",
1814
"Goal",
19-
"Toolbox",
15+
"UserFeedback",
16+
"WorkTree",
2017
]
2118

2219

src/git_draft/bots/common.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,64 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Sequence
6+
import contextlib
57
import dataclasses
6-
from pathlib import Path
8+
from pathlib import Path, PurePosixPath
9+
from typing import Protocol
710

811
from ..common import ensure_state_home, qualified_class_name
9-
from ..feedback import Feedback
10-
from ..toolbox import Toolbox
1112

1213

1314
@dataclasses.dataclass(frozen=True)
1415
class Goal:
1516
"""Bot request"""
1617

1718
prompt: str
18-
# TODO: Add timeout here.
19+
# TODO: Add timeout.
20+
21+
22+
class WorkTree(Protocol):
23+
"""File operations
24+
25+
Implementations may not be thread-safe. Concurrent operations should be
26+
serialized by the caller.
27+
"""
28+
29+
def list_files(self) -> Sequence[PurePosixPath]:
30+
"""List all files"""
31+
32+
def read_file(self, path: PurePosixPath) -> str | None:
33+
"""Get a file's contents"""
34+
35+
def write_file(self, path: PurePosixPath, contents: str) -> None:
36+
"""Update a file's contents"""
37+
38+
def delete_file(self, path: PurePosixPath) -> None:
39+
"""Remove a file"""
40+
41+
def rename_file(
42+
self, src_path: PurePosixPath, dst_path: PurePosixPath
43+
) -> None:
44+
"""Move a file"""
45+
46+
def edit_files(self) -> contextlib.AbstractContextManager[Path]:
47+
"""Return path to a temporary folder with editable copies of all files
48+
49+
Any updates are synced back to the work tree when the context exits.
50+
Other operations should not be performed concurrently as they may be
51+
stale or lost.
52+
"""
53+
54+
55+
class UserFeedback(Protocol):
56+
"""User interactions"""
57+
58+
def notify(self, update: str) -> None:
59+
"""Report progress to the user"""
60+
61+
def ask(self, question: str) -> str:
62+
"""Request additional information from the user"""
1963

2064

2165
@dataclasses.dataclass
@@ -29,7 +73,6 @@ class Action:
2973
title: str | None = None
3074
request_count: int | None = None
3175
token_count: int | None = None
32-
question: str | None = None
3376

3477
def increment_request_count(self, n: int = 1, init: bool = False) -> None:
3578
self._increment("request_count", n, init)
@@ -68,7 +111,7 @@ def state_folder_path(cls, ensure_exists: bool = False) -> Path:
68111
return path
69112

70113
async def act(
71-
self, goal: Goal, toolbox: Toolbox, feedback: Feedback
114+
self, goal: Goal, tree: WorkTree, feedback: UserFeedback
72115
) -> Action:
73-
"""Runs the bot, striving to achieve the goal with the given toolbox"""
116+
"""Runs the bot, striving to achieve the goal"""
74117
raise NotImplementedError()

src/git_draft/bots/openai.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
import openai
2222

2323
from ..common import JSONObject, UnreachableError, config_string, reindent
24-
from ..feedback import Feedback
25-
from ..toolbox import Toolbox
26-
from .common import Action, Bot, Goal
24+
from .common import Action, Bot, Goal, UserFeedback, WorkTree
2725

2826

2927
_logger = logging.getLogger(__name__)
@@ -175,8 +173,8 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]:
175173

176174

177175
class _ToolHandler[V]:
178-
def __init__(self, toolbox: Toolbox) -> None:
179-
self._toolbox = toolbox
176+
def __init__(self, tree: WorkTree) -> None:
177+
self._tree = tree
180178
self.question: str | None = None
181179

182180
def _on_ask_user(self) -> V:
@@ -209,23 +207,23 @@ def handle_function(self, function: Any) -> V:
209207
return self._on_ask_user()
210208
case "read_file":
211209
path = PurePosixPath(inputs["path"])
212-
return self._on_read_file(path, self._toolbox.read_file(path))
210+
return self._on_read_file(path, self._tree.read_file(path))
213211
case "write_file":
214212
path = PurePosixPath(inputs["path"])
215213
contents = inputs["contents"]
216-
self._toolbox.write_file(path, contents)
214+
self._tree.write_file(path, contents)
217215
return self._on_write_file(path)
218216
case "delete_file":
219217
path = PurePosixPath(inputs["path"])
220-
self._toolbox.delete_file(path)
218+
self._tree.delete_file(path)
221219
return self._on_delete_file(path)
222220
case "rename_file":
223221
src_path = PurePosixPath(inputs["src_path"])
224222
dst_path = PurePosixPath(inputs["dst_path"])
225-
self._toolbox.rename_file(src_path, dst_path)
223+
self._tree.rename_file(src_path, dst_path)
226224
return self._on_rename_file(src_path, dst_path)
227225
case "list_files":
228-
paths = self._toolbox.list_files()
226+
paths = self._tree.list_files()
229227
return self._on_list_files(paths)
230228
case _ as name:
231229
raise UnreachableError(f"Unexpected function: {name}")
@@ -237,10 +235,10 @@ def __init__(self, client: openai.OpenAI, model: str) -> None:
237235
self._model = model
238236

239237
async def act(
240-
self, goal: Goal, toolbox: Toolbox, _feedback: Feedback
238+
self, goal: Goal, tree: WorkTree, _feedback: UserFeedback
241239
) -> Action:
242240
tools = _ToolsFactory(strict=False).params()
243-
tool_handler = _CompletionsToolHandler(toolbox)
241+
tool_handler = _CompletionsToolHandler(tree)
244242

245243
messages: list[openai.types.chat.ChatCompletionMessageParam] = [
246244
{"role": "system", "content": reindent(_INSTRUCTIONS)},
@@ -323,7 +321,7 @@ def _load_assistant_id(self) -> str:
323321
return assistant_id
324322

325323
async def act(
326-
self, goal: Goal, toolbox: Toolbox, _feedback: Feedback
324+
self, goal: Goal, tree: WorkTree, _feedback: UserFeedback
327325
) -> Action:
328326
assistant_id = self._load_assistant_id()
329327

@@ -340,24 +338,24 @@ async def act(
340338
with self._client.beta.threads.runs.stream(
341339
thread_id=thread.id,
342340
assistant_id=assistant_id,
343-
event_handler=_EventHandler(self._client, toolbox, action),
341+
event_handler=_EventHandler(self._client, tree, action),
344342
) as stream:
345343
stream.until_done()
346344
return action
347345

348346

349347
class _EventHandler(openai.AssistantEventHandler):
350348
def __init__(
351-
self, client: openai.Client, toolbox: Toolbox, action: Action
349+
self, client: openai.Client, tree: WorkTree, action: Action
352350
) -> None:
353351
super().__init__()
354352
self._client = client
355-
self._toolbox = toolbox
353+
self._tree = tree
356354
self._action = action
357355
self._action.increment_request_count()
358356

359357
def _clone(self) -> Self:
360-
return self.__class__(self._client, self._toolbox, self._action)
358+
return self.__class__(self._client, self._tree, self._action)
361359

362360
@override
363361
def on_event(self, event: openai.types.beta.AssistantStreamEvent) -> None:
@@ -383,7 +381,7 @@ def on_run_step_done(
383381
def _handle_action(self, _run_id: str, data: Any) -> None:
384382
tool_outputs = list[Any]()
385383
for tool in data.required_action.submit_tool_outputs.tool_calls:
386-
handler = _ThreadToolHandler(self._toolbox, tool.id)
384+
handler = _ThreadToolHandler(self._tree, tool.id)
387385
tool_outputs.append(handler.handle_function(tool.function))
388386
if handler.question:
389387
assert not self._action.question
@@ -406,8 +404,8 @@ class _ToolOutput(TypedDict):
406404

407405

408406
class _ThreadToolHandler(_ToolHandler[_ToolOutput]):
409-
def __init__(self, toolbox: Toolbox, call_id: str) -> None:
410-
super().__init__(toolbox)
407+
def __init__(self, tree: WorkTree, call_id: str) -> None:
408+
super().__init__(tree)
411409
self._call_id = call_id
412410

413411
def _wrap(self, output: str) -> _ToolOutput:

0 commit comments

Comments
 (0)