Skip to content

Commit ddbe7a6

Browse files
authored
refactor: add toolbox module (#44)
This will make it easier to extend (e.g. with diff support).
1 parent becd123 commit ddbe7a6

File tree

15 files changed

+245
-198
lines changed

15 files changed

+245
-198
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ ignore = ["E203", "E501", "E704", "W503"]
9797
[tool.isort]
9898
profile = "black"
9999
force_sort_within_sections = true
100+
lines_after_imports = 2
100101

101102
[tool.mypy]
102103
disable_error_code = "import-untyped"

src/git_draft/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .bots import Action, Bot, Toolbox
44

5+
56
__all__ = [
67
"Action",
78
"Bot",

src/git_draft/__main__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from .store import Store
1616

1717

18+
_logger = logging.getLogger(__name__)
19+
20+
1821
def new_parser() -> optparse.OptionParser:
1922
parser = optparse.OptionParser(
2023
prog=PROGRAM,
@@ -134,12 +137,19 @@ def main() -> None:
134137
prompt, bot, checkout=opts.checkout, reset=opts.reset
135138
)
136139
elif command == "finalize":
137-
drafter.finalize_draft(delete=opts.delete)
140+
name = drafter.finalize_draft(delete=opts.delete)
141+
print(f"Finalized {name}")
138142
elif command == "revert":
139-
drafter.revert_draft(delete=opts.delete)
143+
name = drafter.revert_draft(delete=opts.delete)
144+
print(f"Reverted {name}")
140145
else:
141146
raise UnreachableError()
142147

143148

144149
if __name__ == "__main__":
145-
main()
150+
try:
151+
main()
152+
except Exception as err:
153+
_logger.exception("Program failed.")
154+
print(f"Error: {err}", file=sys.stderr)
155+
sys.exit(1)

src/git_draft/bots/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
import sys
99

1010
from ..common import BotConfig, reindent
11-
from .common import Action, Bot, Goal, Operation, OperationHook, Toolbox
11+
from ..toolbox import Operation, OperationHook, Toolbox
12+
from .common import Action, Bot, Goal
13+
1214

1315
__all__ = [
1416
"Action",

src/git_draft/bots/common.py

Lines changed: 3 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,10 @@
11
from __future__ import annotations
22

33
import dataclasses
4-
from datetime import datetime
5-
from pathlib import Path, PurePosixPath
6-
from typing import Callable, Sequence
4+
from pathlib import Path
75

8-
from ..common import JSONObject, ensure_state_home
9-
10-
11-
class Toolbox:
12-
"""File-system intermediary
13-
14-
Note that the toolbox is not thread-safe. Concurrent operations should be
15-
serialized by the caller.
16-
"""
17-
18-
# TODO: Something similar to https://aider.chat/docs/repomap.html,
19-
# including inferring the most important files, and allowing returning
20-
# signature-only versions.
21-
22-
# TODO: Support a diff-based edit method.
23-
# https://gist.github.com/noporpoise/16e731849eb1231e86d78f9dfeca3abc
24-
25-
def __init__(self, hook: OperationHook | None = None) -> None:
26-
self.operations = list[Operation]()
27-
self._operation_hook = hook
28-
29-
def _record_operation(
30-
self, reason: str | None, tool: str, **kwargs
31-
) -> None:
32-
op = Operation(
33-
tool=tool, details=kwargs, reason=reason, start=datetime.now()
34-
)
35-
self.operations.append(op)
36-
if self._operation_hook:
37-
self._operation_hook(op)
38-
39-
def list_files(self, reason: str | None = None) -> Sequence[PurePosixPath]:
40-
self._record_operation(reason, "list_files")
41-
return self._list()
42-
43-
def read_file(
44-
self,
45-
path: PurePosixPath,
46-
reason: str | None = None,
47-
) -> str:
48-
self._record_operation(reason, "read_file", path=str(path))
49-
return self._read(path)
50-
51-
def write_file(
52-
self,
53-
path: PurePosixPath,
54-
contents: str,
55-
reason: str | None = None,
56-
) -> None:
57-
self._record_operation(
58-
reason, "write_file", path=str(path), size=len(contents)
59-
)
60-
return self._write(path, contents)
61-
62-
def delete_file(
63-
self,
64-
path: PurePosixPath,
65-
reason: str | None = None,
66-
) -> None:
67-
self._record_operation(reason, "delete_file", path=str(path))
68-
return self._delete(path)
69-
70-
def _list(self) -> Sequence[PurePosixPath]:
71-
raise NotImplementedError()
72-
73-
def _read(self, path: PurePosixPath) -> str:
74-
raise NotImplementedError()
75-
76-
def _write(self, path: PurePosixPath, contents: str) -> None:
77-
raise NotImplementedError()
78-
79-
def _delete(self, path: PurePosixPath) -> None:
80-
raise NotImplementedError()
81-
82-
83-
@dataclasses.dataclass(frozen=True)
84-
class Operation:
85-
tool: str
86-
details: JSONObject
87-
reason: str | None
88-
start: datetime
89-
90-
91-
type OperationHook = Callable[[Operation], None]
6+
from ..common import ensure_state_home
7+
from ..toolbox import Toolbox
928

939

9410
@dataclasses.dataclass(frozen=True)

src/git_draft/bots/openai.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ..common import JSONObject, reindent
2424
from .common import Action, Bot, Goal, Toolbox
2525

26+
2627
_logger = logging.getLogger(__name__)
2728

2829

@@ -132,7 +133,7 @@ class _ToolHandler[V]:
132133
def __init__(self, toolbox: Toolbox) -> None:
133134
self._toolbox = toolbox
134135

135-
def _on_read_file(self, path: PurePosixPath, contents: str) -> V:
136+
def _on_read_file(self, path: PurePosixPath, contents: str | None) -> V:
136137
raise NotImplementedError()
137138

138139
def _on_write_file(self, path: PurePosixPath) -> V:
@@ -196,10 +197,10 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
196197

197198

198199
class _CompletionsToolHandler(_ToolHandler[str | None]):
199-
def _on_read_file(self, path: PurePosixPath, contents: str) -> str:
200-
return (
201-
f"Here are the contents of `{path}`:\n\n```\n{contents}\n```\n" ""
202-
)
200+
def _on_read_file(self, path: PurePosixPath, contents: str | None) -> str:
201+
if contents is None:
202+
return f"`{path}` does not exist."
203+
return f"The contents of `{path}` are:\n\n```\n{contents}\n```\n"
203204

204205
def _on_write_file(self, path: PurePosixPath) -> None:
205206
return None
@@ -303,8 +304,10 @@ def __init__(self, toolbox: Toolbox, call_id: str) -> None:
303304
def _wrap(self, output: str) -> _ToolOutput:
304305
return _ToolOutput(tool_call_id=self._call_id, output=output)
305306

306-
def _on_read_file(self, path: PurePosixPath, contents: str) -> _ToolOutput:
307-
return self._wrap(contents)
307+
def _on_read_file(
308+
self, path: PurePosixPath, contents: str | None
309+
) -> _ToolOutput:
310+
return self._wrap(contents or "")
308311

309312
def _on_write_file(self, path: PurePosixPath) -> _ToolOutput:
310313
return self._wrap("OK")

src/git_draft/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import xdg_base_dirs
1616

17+
1718
PROGRAM = "git-draft"
1819

1920

src/git_draft/drafter.py

Lines changed: 16 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
import dataclasses
44
import json
55
import logging
6-
from pathlib import PurePosixPath
76
import re
8-
import tempfile
97
import textwrap
108
import time
11-
from typing import Match, Sequence, override
9+
from typing import Match, Sequence
1210

1311
import git
1412

15-
from .bots import Bot, Goal, OperationHook, Toolbox
13+
from .bots import Bot, Goal, OperationHook
1614
from .common import random_id
1715
from .prompt import PromptRenderer, TemplatedPrompt
1816
from .store import Store, sql
17+
from .toolbox import StagingToolbox
18+
1919

2020
_logger = logging.getLogger(__name__)
2121

@@ -49,53 +49,6 @@ def new_suffix():
4949
return random_id(9)
5050

5151

52-
class _Toolbox(Toolbox):
53-
"""Git-index backed toolbox
54-
55-
All files are directly read from and written to the index. This allows
56-
concurrent editing without interference.
57-
"""
58-
59-
def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None:
60-
super().__init__(hook)
61-
self._repo = repo
62-
self._written = set[str]()
63-
64-
@override
65-
def _list(self) -> Sequence[PurePosixPath]:
66-
# Show staged files.
67-
return self._repo.git.ls_files().splitlines()
68-
69-
@override
70-
def _read(self, path: PurePosixPath) -> str:
71-
# Read the file from the index.
72-
return self._repo.git.show(f":{path}")
73-
74-
@override
75-
def _write(self, path: PurePosixPath, contents: str) -> None:
76-
self._written.add(str(path))
77-
# Update the index without touching the worktree.
78-
# https://stackoverflow.com/a/25352119
79-
with tempfile.NamedTemporaryFile(delete_on_close=False) as temp:
80-
temp.write(contents.encode("utf8"))
81-
temp.close()
82-
sha = self._repo.git.hash_object("-w", temp.name, path=path)
83-
mode = 644 # TODO: Read from original file if it exists.
84-
self._repo.git.update_index(
85-
f"{mode},{sha},{path}", add=True, cacheinfo=True
86-
)
87-
88-
def trim_index(self) -> None:
89-
diff = self._repo.git.diff(name_only=True, cached=True)
90-
untouched = [
91-
path
92-
for path in diff.splitlines()
93-
if path and path not in self._written
94-
]
95-
if untouched:
96-
self._repo.git.reset("--", *untouched)
97-
98-
9952
class Drafter:
10053
"""Draft state orchestrator"""
10154

@@ -139,17 +92,19 @@ def generate_draft(
13992

14093
branch = _Branch.active(self._repo)
14194
if branch:
142-
_logger.debug("Reusing active branch %s.", branch)
14395
self._stage_changes(sync)
96+
_logger.debug("Reusing active branch %s.", branch)
14497
else:
14598
branch = self._create_branch(sync)
14699
_logger.debug("Created branch %s.", branch)
147100

101+
toolbox = StagingToolbox(self._repo, self._operation_hook)
148102
if isinstance(prompt, TemplatedPrompt):
149-
renderer = PromptRenderer.for_repo(self._repo)
103+
renderer = PromptRenderer.for_toolbox(toolbox)
150104
prompt_contents = renderer.render(prompt)
151105
else:
152106
prompt_contents = prompt
107+
153108
with self._store.cursor() as cursor:
154109
[(prompt_id,)] = cursor.execute(
155110
sql("add-prompt"),
@@ -161,7 +116,6 @@ def generate_draft(
161116

162117
start_time = time.perf_counter()
163118
goal = Goal(prompt_contents, timeout)
164-
toolbox = _Toolbox(self._repo, self._operation_hook)
165119
action = bot.act(goal, toolbox)
166120
end_time = time.perf_counter()
167121

@@ -201,11 +155,11 @@ def generate_draft(
201155
if checkout:
202156
self._repo.git.checkout("--", ".")
203157

204-
def finalize_draft(self, delete=False) -> None:
205-
self._exit_draft(revert=False, delete=delete)
158+
def finalize_draft(self, delete=False) -> str:
159+
return self._exit_draft(revert=False, delete=delete)
206160

207-
def revert_draft(self, delete=False) -> None:
208-
self._exit_draft(revert=True, delete=delete)
161+
def revert_draft(self, delete=False) -> str:
162+
return self._exit_draft(revert=True, delete=delete)
209163

210164
def _create_branch(self, sync: bool) -> _Branch:
211165
if self._repo.head.is_detached:
@@ -241,7 +195,7 @@ def _stage_changes(self, sync: bool) -> str | None:
241195
ref = self._repo.index.commit("draft! sync")
242196
return ref.hexsha
243197

244-
def _exit_draft(self, *, revert: bool, delete: bool) -> None:
198+
def _exit_draft(self, *, revert: bool, delete: bool) -> str:
245199
branch = _Branch.active(self._repo)
246200
if not branch:
247201
raise RuntimeError("Not currently on a draft branch")
@@ -268,7 +222,7 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> None:
268222
self._repo.git.reset("-N", origin_branch)
269223
self._repo.git.checkout(origin_branch)
270224

271-
# Finally, we revert the relevant files if needed. If a sync commit had
225+
# Next, we revert the relevant files if needed. If a sync commit had
272226
# been created, we simply revert to it. Otherwise we compute which
273227
# files have changed due to draft commits and revert only those.
274228
if revert:
@@ -283,6 +237,8 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> None:
283237
if delete:
284238
self._repo.git.branch("-D", branch.name)
285239

240+
return branch.name
241+
286242
def _changed_files(self, spec) -> Sequence[str]:
287243
return self._repo.git.diff(spec, name_only=True).splitlines()
288244

src/git_draft/editor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import tempfile
88

9+
910
_default_editors = ["vim", "emacs", "nano"]
1011

1112

0 commit comments

Comments
 (0)