33import dataclasses
44import json
55import logging
6- from pathlib import PurePosixPath
76import re
8- import tempfile
97import textwrap
108import time
11- from typing import Match , Sequence , override
9+ from typing import Match , Sequence
1210
1311import git
1412
15- from .bots import Bot , Goal , OperationHook , Toolbox
13+ from .bots import Bot , Goal , OperationHook
1614from .common import random_id
1715from .prompt import PromptRenderer , TemplatedPrompt
1816from .store import Store , sql
17+ from .toolbox import StagingToolbox
18+
1919
2020_logger = logging .getLogger (__name__ )
2121
@@ -49,53 +49,6 @@ def new_suffix():
4949 return random_id (9 )
5050
5151
52- class _Toolbox (Toolbox ):
53- """Git-index backed toolbox
54-
55- All files are directly read from and written to the index. This allows
56- concurrent editing without interference.
57- """
58-
59- def __init__ (self , repo : git .Repo , hook : OperationHook | None ) -> None :
60- super ().__init__ (hook )
61- self ._repo = repo
62- self ._written = set [str ]()
63-
64- @override
65- def _list (self ) -> Sequence [PurePosixPath ]:
66- # Show staged files.
67- return self ._repo .git .ls_files ().splitlines ()
68-
69- @override
70- def _read (self , path : PurePosixPath ) -> str :
71- # Read the file from the index.
72- return self ._repo .git .show (f":{ path } " )
73-
74- @override
75- def _write (self , path : PurePosixPath , contents : str ) -> None :
76- self ._written .add (str (path ))
77- # Update the index without touching the worktree.
78- # https://stackoverflow.com/a/25352119
79- with tempfile .NamedTemporaryFile (delete_on_close = False ) as temp :
80- temp .write (contents .encode ("utf8" ))
81- temp .close ()
82- sha = self ._repo .git .hash_object ("-w" , temp .name , path = path )
83- mode = 644 # TODO: Read from original file if it exists.
84- self ._repo .git .update_index (
85- f"{ mode } ,{ sha } ,{ path } " , add = True , cacheinfo = True
86- )
87-
88- def trim_index (self ) -> None :
89- diff = self ._repo .git .diff (name_only = True , cached = True )
90- untouched = [
91- path
92- for path in diff .splitlines ()
93- if path and path not in self ._written
94- ]
95- if untouched :
96- self ._repo .git .reset ("--" , * untouched )
97-
98-
9952class Drafter :
10053 """Draft state orchestrator"""
10154
@@ -139,17 +92,19 @@ def generate_draft(
13992
14093 branch = _Branch .active (self ._repo )
14194 if branch :
142- _logger .debug ("Reusing active branch %s." , branch )
14395 self ._stage_changes (sync )
96+ _logger .debug ("Reusing active branch %s." , branch )
14497 else :
14598 branch = self ._create_branch (sync )
14699 _logger .debug ("Created branch %s." , branch )
147100
101+ toolbox = StagingToolbox (self ._repo , self ._operation_hook )
148102 if isinstance (prompt , TemplatedPrompt ):
149- renderer = PromptRenderer .for_repo ( self . _repo )
103+ renderer = PromptRenderer .for_toolbox ( toolbox )
150104 prompt_contents = renderer .render (prompt )
151105 else :
152106 prompt_contents = prompt
107+
153108 with self ._store .cursor () as cursor :
154109 [(prompt_id ,)] = cursor .execute (
155110 sql ("add-prompt" ),
@@ -161,7 +116,6 @@ def generate_draft(
161116
162117 start_time = time .perf_counter ()
163118 goal = Goal (prompt_contents , timeout )
164- toolbox = _Toolbox (self ._repo , self ._operation_hook )
165119 action = bot .act (goal , toolbox )
166120 end_time = time .perf_counter ()
167121
@@ -201,11 +155,11 @@ def generate_draft(
201155 if checkout :
202156 self ._repo .git .checkout ("--" , "." )
203157
204- def finalize_draft (self , delete = False ) -> None :
205- self ._exit_draft (revert = False , delete = delete )
158+ def finalize_draft (self , delete = False ) -> str :
159+ return self ._exit_draft (revert = False , delete = delete )
206160
207- def revert_draft (self , delete = False ) -> None :
208- self ._exit_draft (revert = True , delete = delete )
161+ def revert_draft (self , delete = False ) -> str :
162+ return self ._exit_draft (revert = True , delete = delete )
209163
210164 def _create_branch (self , sync : bool ) -> _Branch :
211165 if self ._repo .head .is_detached :
@@ -241,7 +195,7 @@ def _stage_changes(self, sync: bool) -> str | None:
241195 ref = self ._repo .index .commit ("draft! sync" )
242196 return ref .hexsha
243197
244- def _exit_draft (self , * , revert : bool , delete : bool ) -> None :
198+ def _exit_draft (self , * , revert : bool , delete : bool ) -> str :
245199 branch = _Branch .active (self ._repo )
246200 if not branch :
247201 raise RuntimeError ("Not currently on a draft branch" )
@@ -268,7 +222,7 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> None:
268222 self ._repo .git .reset ("-N" , origin_branch )
269223 self ._repo .git .checkout (origin_branch )
270224
271- # Finally , we revert the relevant files if needed. If a sync commit had
225+ # Next , we revert the relevant files if needed. If a sync commit had
272226 # been created, we simply revert to it. Otherwise we compute which
273227 # files have changed due to draft commits and revert only those.
274228 if revert :
@@ -283,6 +237,8 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> None:
283237 if delete :
284238 self ._repo .git .branch ("-D" , branch .name )
285239
240+ return branch .name
241+
286242 def _changed_files (self , spec ) -> Sequence [str ]:
287243 return self ._repo .git .diff (spec , name_only = True ).splitlines ()
288244
0 commit comments