11from __future__ import annotations
22
33import dataclasses
4+ from datetime import datetime
45import json
56import logging
7+ from pathlib import PurePosixPath
68import re
79import textwrap
810import time
911from typing import Match , Sequence
1012
1113import git
1214
13- from .bots import Bot , Goal , OperationHook
14- from .common import random_id
15+ from .bots import Bot , Goal
16+ from .common import JSONObject , random_id
1517from .prompt import PromptRenderer , TemplatedPrompt
1618from .store import Store , sql
17- from .toolbox import StagingToolbox
19+ from .toolbox import StagingToolbox , ToolVisitor
1820
1921
2022_logger = logging .getLogger (__name__ )
@@ -52,37 +54,26 @@ def new_suffix():
5254class Drafter :
5355 """Draft state orchestrator"""
5456
55- def __init__ (
56- self , store : Store , repo : git .Repo , hook : OperationHook | None = None
57- ) -> None :
57+ def __init__ (self , store : Store , repo : git .Repo ) -> None :
5858 with store .cursor () as cursor :
5959 cursor .executescript (sql ("create-tables" ))
6060 self ._store = store
6161 self ._repo = repo
62- self ._operation_hook = hook
6362
6463 @classmethod
65- def create (
66- cls ,
67- store : Store ,
68- path : str | None = None ,
69- operation_hook : OperationHook | None = None ,
70- ) -> Drafter :
71- return cls (
72- store ,
73- git .Repo (path , search_parent_directories = True ),
74- operation_hook ,
75- )
64+ def create (cls , store : Store , path : str | None = None ) -> Drafter :
65+ return cls (store , git .Repo (path , search_parent_directories = True ))
7666
7767 def generate_draft (
7868 self ,
7969 prompt : str | TemplatedPrompt ,
8070 bot : Bot ,
71+ tool_visitors : Sequence [ToolVisitor ] | None = None ,
8172 checkout : bool = False ,
8273 reset : bool = False ,
8374 sync : bool = False ,
8475 timeout : float | None = None ,
85- ) -> None :
76+ ) -> str :
8677 if isinstance (prompt , str ) and not prompt .strip ():
8778 raise ValueError ("Empty prompt" )
8879 if self ._repo .is_dirty (working_tree = False ):
@@ -98,7 +89,9 @@ def generate_draft(
9889 branch = self ._create_branch (sync )
9990 _logger .debug ("Created branch %s." , branch )
10091
101- toolbox = StagingToolbox (self ._repo , self ._operation_hook )
92+ operation_recorder = _OperationRecorder ()
93+ tool_visitors = [operation_recorder ] + list (tool_visitors or [])
94+ toolbox = StagingToolbox (self ._repo , tool_visitors )
10295 if isinstance (prompt , TemplatedPrompt ):
10396 renderer = PromptRenderer .for_toolbox (toolbox )
10497 prompt_contents = renderer .render (prompt )
@@ -118,6 +111,7 @@ def generate_draft(
118111 goal = Goal (prompt_contents , timeout )
119112 action = bot .act (goal , toolbox )
120113 end_time = time .perf_counter ()
114+ walltime = end_time - start_time
121115
122116 toolbox .trim_index ()
123117 title = action .title
@@ -134,7 +128,7 @@ def generate_draft(
134128 {
135129 "commit_sha" : commit .hexsha ,
136130 "prompt_id" : prompt_id ,
137- "walltime" : end_time - start_time ,
131+ "walltime" : walltime ,
138132 },
139133 )
140134 cursor .executemany (
@@ -147,13 +141,14 @@ def generate_draft(
147141 "details" : json .dumps (o .details ),
148142 "started_at" : o .start ,
149143 }
150- for o in toolbox .operations
144+ for o in operation_recorder .operations
151145 ],
152146 )
153- _logger .info ("Generated draft." )
154147
148+ _logger .info ("Generated draft." )
155149 if checkout :
156150 self ._repo .git .checkout ("--" , "." )
151+ return str (branch )
157152
158153 def finalize_draft (self , delete = False ) -> str :
159154 return self ._exit_draft (revert = False , delete = delete )
@@ -243,5 +238,48 @@ def _changed_files(self, spec) -> Sequence[str]:
243238 return self ._repo .git .diff (spec , name_only = True ).splitlines ()
244239
245240
241+ class _OperationRecorder (ToolVisitor ):
242+ def __init__ (self ) -> None :
243+ self .operations = list [_Operation ]()
244+
245+ def on_list_files (
246+ self , paths : Sequence [PurePosixPath ], reason : str | None
247+ ) -> None :
248+ self ._record (reason , "list_files" , count = len (paths ))
249+
250+ def on_read_file (
251+ self , path : PurePosixPath , contents : str | None , reason : str | None
252+ ) -> None :
253+ self ._record (
254+ reason ,
255+ "read_file" ,
256+ path = str (path ),
257+ size = - 1 if contents is None else len (contents ),
258+ )
259+
260+ def on_write_file (
261+ self , path : PurePosixPath , contents : str , reason : str | None
262+ ) -> None :
263+ self ._record (reason , "write_file" , path = str (path ), size = len (contents ))
264+
265+ def on_delete_file (self , path : PurePosixPath , reason : str | None ) -> None :
266+ self ._record (reason , "delete_file" , path = str (path ))
267+
268+ def _record (self , reason : str | None , tool : str , ** kwargs ) -> None :
269+ self .operations .append (
270+ _Operation (
271+ tool = tool , details = kwargs , reason = reason , start = datetime .now ()
272+ )
273+ )
274+
275+
276+ @dataclasses .dataclass (frozen = True )
277+ class _Operation :
278+ tool : str
279+ details : JSONObject
280+ reason : str | None
281+ start : datetime
282+
283+
246284def _default_title (prompt : str ) -> str :
247285 return textwrap .shorten (prompt , break_on_hyphens = False , width = 72 )
0 commit comments