Skip to content

Commit a3a1fa6

Browse files
authored
feat: improve CLI operation logging (#45)
1 parent ddbe7a6 commit a3a1fa6

File tree

5 files changed

+134
-121
lines changed

5 files changed

+134
-121
lines changed

src/git_draft/__main__.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
import importlib.metadata
66
import logging
77
import optparse
8+
from pathlib import PurePosixPath
89
import sys
10+
from typing import Sequence
911

10-
from .bots import Operation, load_bot
12+
from .bots import load_bot
1113
from .common import PROGRAM, Config, UnreachableError, ensure_state_home
1214
from .drafter import Drafter
1315
from .editor import open_editor
1416
from .prompt import TemplatedPrompt
1517
from .store import Store
18+
from .toolbox import ToolVisitor
1619

1720

1821
_logger = logging.getLogger(__name__)
@@ -93,8 +96,24 @@ def callback(_option, _opt, _value, parser) -> None:
9396
return parser
9497

9598

96-
def print_operation(op: Operation) -> None:
97-
print(op)
99+
class _ToolPrinter(ToolVisitor):
100+
def on_list_files(
101+
self, _paths: Sequence[PurePosixPath], _reason: str | None
102+
) -> None:
103+
print("Listing available files...")
104+
105+
def on_read_file(
106+
self, path: PurePosixPath, _contents: str | None, _reason: str | None
107+
) -> None:
108+
print(f"Reading {path}...")
109+
110+
def on_write_file(
111+
self, path: PurePosixPath, _contents: str, _reason: str | None
112+
) -> None:
113+
print(f"Updated {path}.")
114+
115+
def on_delete_file(self, path: PurePosixPath, _reason: str | None) -> None:
116+
print(f"Deleted {path}.")
98117

99118

100119
def main() -> None:
@@ -110,7 +129,6 @@ def main() -> None:
110129
drafter = Drafter.create(
111130
store=Store.persistent(),
112131
path=opts.root,
113-
operation_hook=print_operation,
114132
)
115133
command = getattr(opts, "command", "generate")
116134
if command == "generate":
@@ -133,15 +151,20 @@ def main() -> None:
133151
else:
134152
prompt = sys.stdin.read()
135153

136-
drafter.generate_draft(
137-
prompt, bot, checkout=opts.checkout, reset=opts.reset
154+
name = drafter.generate_draft(
155+
prompt,
156+
bot,
157+
tool_visitors=[_ToolPrinter()],
158+
checkout=opts.checkout,
159+
reset=opts.reset,
138160
)
161+
print(f"Generated {name}.")
139162
elif command == "finalize":
140163
name = drafter.finalize_draft(delete=opts.delete)
141-
print(f"Finalized {name}")
164+
print(f"Finalized {name}.")
142165
elif command == "revert":
143166
name = drafter.revert_draft(delete=opts.delete)
144-
print(f"Reverted {name}")
167+
print(f"Reverted {name}.")
145168
else:
146169
raise UnreachableError()
147170

src/git_draft/bots/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,14 @@
88
import sys
99

1010
from ..common import BotConfig, reindent
11-
from ..toolbox import Operation, OperationHook, Toolbox
11+
from ..toolbox import Toolbox
1212
from .common import Action, Bot, Goal
1313

1414

1515
__all__ = [
1616
"Action",
1717
"Bot",
1818
"Goal",
19-
"Operation",
20-
"OperationHook",
2119
"Toolbox",
2220
]
2321

src/git_draft/drafter.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
from datetime import datetime
45
import json
56
import logging
7+
from pathlib import PurePosixPath
68
import re
79
import textwrap
810
import time
911
from typing import Match, Sequence
1012

1113
import git
1214

13-
from .bots import Bot, Goal, OperationHook
14-
from .common import random_id
15+
from .bots import Bot, Goal
16+
from .common import JSONObject, random_id
1517
from .prompt import PromptRenderer, TemplatedPrompt
1618
from .store import Store, sql
17-
from .toolbox import StagingToolbox
19+
from .toolbox import StagingToolbox, ToolVisitor
1820

1921

2022
_logger = logging.getLogger(__name__)
@@ -52,37 +54,26 @@ def new_suffix():
5254
class Drafter:
5355
"""Draft state orchestrator"""
5456

55-
def __init__(
56-
self, store: Store, repo: git.Repo, hook: OperationHook | None = None
57-
) -> None:
57+
def __init__(self, store: Store, repo: git.Repo) -> None:
5858
with store.cursor() as cursor:
5959
cursor.executescript(sql("create-tables"))
6060
self._store = store
6161
self._repo = repo
62-
self._operation_hook = hook
6362

6463
@classmethod
65-
def create(
66-
cls,
67-
store: Store,
68-
path: str | None = None,
69-
operation_hook: OperationHook | None = None,
70-
) -> Drafter:
71-
return cls(
72-
store,
73-
git.Repo(path, search_parent_directories=True),
74-
operation_hook,
75-
)
64+
def create(cls, store: Store, path: str | None = None) -> Drafter:
65+
return cls(store, git.Repo(path, search_parent_directories=True))
7666

7767
def generate_draft(
7868
self,
7969
prompt: str | TemplatedPrompt,
8070
bot: Bot,
71+
tool_visitors: Sequence[ToolVisitor] | None = None,
8172
checkout: bool = False,
8273
reset: bool = False,
8374
sync: bool = False,
8475
timeout: float | None = None,
85-
) -> None:
76+
) -> str:
8677
if isinstance(prompt, str) and not prompt.strip():
8778
raise ValueError("Empty prompt")
8879
if self._repo.is_dirty(working_tree=False):
@@ -98,7 +89,9 @@ def generate_draft(
9889
branch = self._create_branch(sync)
9990
_logger.debug("Created branch %s.", branch)
10091

101-
toolbox = StagingToolbox(self._repo, self._operation_hook)
92+
operation_recorder = _OperationRecorder()
93+
tool_visitors = [operation_recorder] + list(tool_visitors or [])
94+
toolbox = StagingToolbox(self._repo, tool_visitors)
10295
if isinstance(prompt, TemplatedPrompt):
10396
renderer = PromptRenderer.for_toolbox(toolbox)
10497
prompt_contents = renderer.render(prompt)
@@ -118,6 +111,7 @@ def generate_draft(
118111
goal = Goal(prompt_contents, timeout)
119112
action = bot.act(goal, toolbox)
120113
end_time = time.perf_counter()
114+
walltime = end_time - start_time
121115

122116
toolbox.trim_index()
123117
title = action.title
@@ -134,7 +128,7 @@ def generate_draft(
134128
{
135129
"commit_sha": commit.hexsha,
136130
"prompt_id": prompt_id,
137-
"walltime": end_time - start_time,
131+
"walltime": walltime,
138132
},
139133
)
140134
cursor.executemany(
@@ -147,13 +141,14 @@ def generate_draft(
147141
"details": json.dumps(o.details),
148142
"started_at": o.start,
149143
}
150-
for o in toolbox.operations
144+
for o in operation_recorder.operations
151145
],
152146
)
153-
_logger.info("Generated draft.")
154147

148+
_logger.info("Generated draft.")
155149
if checkout:
156150
self._repo.git.checkout("--", ".")
151+
return str(branch)
157152

158153
def finalize_draft(self, delete=False) -> str:
159154
return self._exit_draft(revert=False, delete=delete)
@@ -243,5 +238,48 @@ def _changed_files(self, spec) -> Sequence[str]:
243238
return self._repo.git.diff(spec, name_only=True).splitlines()
244239

245240

241+
class _OperationRecorder(ToolVisitor):
242+
def __init__(self) -> None:
243+
self.operations = list[_Operation]()
244+
245+
def on_list_files(
246+
self, paths: Sequence[PurePosixPath], reason: str | None
247+
) -> None:
248+
self._record(reason, "list_files", count=len(paths))
249+
250+
def on_read_file(
251+
self, path: PurePosixPath, contents: str | None, reason: str | None
252+
) -> None:
253+
self._record(
254+
reason,
255+
"read_file",
256+
path=str(path),
257+
size=-1 if contents is None else len(contents),
258+
)
259+
260+
def on_write_file(
261+
self, path: PurePosixPath, contents: str, reason: str | None
262+
) -> None:
263+
self._record(reason, "write_file", path=str(path), size=len(contents))
264+
265+
def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None:
266+
self._record(reason, "delete_file", path=str(path))
267+
268+
def _record(self, reason: str | None, tool: str, **kwargs) -> None:
269+
self.operations.append(
270+
_Operation(
271+
tool=tool, details=kwargs, reason=reason, start=datetime.now()
272+
)
273+
)
274+
275+
276+
@dataclasses.dataclass(frozen=True)
277+
class _Operation:
278+
tool: str
279+
details: JSONObject
280+
reason: str | None
281+
start: datetime
282+
283+
246284
def _default_title(prompt: str) -> str:
247285
return textwrap.shorten(prompt, break_on_hyphens=False, width=72)

0 commit comments

Comments
 (0)