44from datetime import datetime
55import json
66import logging
7+ import os
8+ import os .path as osp
79from pathlib import PurePosixPath
810import re
911import textwrap
@@ -69,7 +71,6 @@ def generate_draft(
6971 prompt : str | TemplatedPrompt ,
7072 bot : Bot ,
7173 tool_visitors : Sequence [ToolVisitor ] | None = None ,
72- checkout : bool = False ,
7374 reset : bool = False ,
7475 sync : bool = False ,
7576 timeout : float | None = None ,
@@ -107,11 +108,13 @@ def generate_draft(
107108 },
108109 )
109110
111+ _logger .debug ("Running bot... [bot=%s]" , bot )
110112 start_time = time .perf_counter ()
111113 goal = Goal (prompt_contents , timeout )
112114 action = bot .act (goal , toolbox )
113115 end_time = time .perf_counter ()
114116 walltime = end_time - start_time
117+ _logger .info ("Completed bot action. [action=%s]" , action )
115118
116119 toolbox .trim_index ()
117120 title = action .title
@@ -145,16 +148,18 @@ def generate_draft(
145148 ],
146149 )
147150
148- _logger .info ("Generated draft." )
149- if checkout :
150- self ._repo .git .checkout ("--" , "." )
151+ _logger .info ("Generated %s." , branch )
151152 return str (branch )
152153
153- def finalize_draft (self , delete = False ) -> str :
154- return self ._exit_draft (revert = False , delete = delete )
154+ def finalize_draft (self , clean = False , delete = False ) -> str :
155+ name = self ._exit_draft (revert = False , clean = clean , delete = delete )
156+ _logger .info ("Finalized %s." , name )
157+ return name
155158
156159 def revert_draft (self , delete = False ) -> str :
157- return self ._exit_draft (revert = True , delete = delete )
160+ name = self ._exit_draft (revert = True , clean = False , delete = delete )
161+ _logger .info ("Reverted %s." , name )
162+ return name
158163
159164 def _create_branch (self , sync : bool ) -> _Branch :
160165 if self ._repo .head .is_detached :
@@ -190,7 +195,7 @@ def _stage_changes(self, sync: bool) -> str | None:
190195 ref = self ._repo .index .commit ("draft! sync" )
191196 return ref .hexsha
192197
193- def _exit_draft (self , * , revert : bool , delete : bool ) -> str :
198+ def _exit_draft (self , * , revert : bool , clean : bool , delete : bool ) -> str :
194199 branch = _Branch .active (self ._repo )
195200 if not branch :
196201 raise RuntimeError ("Not currently on a draft branch" )
@@ -200,15 +205,24 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str:
200205 sql ("get-branch-by-suffix" ), {"suffix" : branch .suffix }
201206 )
202207 if not rows :
203- raise RuntimeError ("Unrecognized branch" )
208+ raise RuntimeError ("Unrecognized draft branch" )
204209 [(origin_branch , origin_sha , sync_sha )] = rows
205210
206211 if (
207212 revert
208213 and sync_sha
209214 and self ._repo .commit (origin_branch ).hexsha != origin_sha
210215 ):
211- raise RuntimeError ("Parent branch has moved, please rebase" )
216+ raise RuntimeError ("Parent branch has moved, please rebase first" )
217+
218+ if clean :
219+ # We delete files which have been deleted in the draft manually,
220+ # otherwise they would still show up as untracked.
221+ origin_delta = self ._delta (f"{ origin_branch } ..{ branch } " )
222+ deleted = self ._untracked () & origin_delta .deleted
223+ for path in deleted :
224+ os .remove (osp .join (self ._repo .working_dir , path ))
225+ _logger .info ("Cleaned up files. [deleted=%s]" , deleted )
212226
213227 # We do a small dance to move back to the original branch, keeping the
214228 # draft branch untouched. See https://stackoverflow.com/a/15993574 for
@@ -217,25 +231,60 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str:
217231 self ._repo .git .reset ("-N" , origin_branch )
218232 self ._repo .git .checkout (origin_branch )
219233
220- # Next, we revert the relevant files if needed. If a sync commit had
221- # been created, we simply revert to it. Otherwise we compute which
222- # files have changed due to draft commits and revert only those.
223234 if revert :
235+ # We revert the relevant files if needed. If a sync commit had been
236+ # created, we simply revert to it. Otherwise we compute which files
237+ # have changed due to draft commits and revert only those.
224238 if sync_sha :
225- self ._repo .git .checkout (sync_sha , "--" , "." )
239+ delta = self ._delta (sync_sha )
240+ if delta .changed :
241+ self ._repo .git .checkout (sync_sha , "--" , "." )
242+ _logger .info ("Reverted to sync commit. [sha=%s]" , sync_sha )
226243 else :
227- diffed = set (self ._changed_files (f"{ origin_branch } ..{ branch } " ))
228- dirty = [p for p in self ._changed_files ("HEAD" ) if p in diffed ]
229- if dirty :
230- self ._repo .git .checkout ("--" , * dirty )
244+ origin_delta = self ._delta (f"{ origin_branch } ..{ branch } " )
245+ head_delta = self ._delta ("HEAD" )
246+ changed = head_delta .touched & origin_delta .changed
247+ if changed :
248+ self ._repo .git .checkout ("--" , * changed )
249+ deleted = head_delta .touched & origin_delta .deleted
250+ if deleted :
251+ self ._repo .git .rm ("--" , * deleted )
252+ _logger .info (
253+ "Reverted touched files. [changed=%s, deleted=%s]" ,
254+ changed ,
255+ deleted ,
256+ )
231257
232258 if delete :
233259 self ._repo .git .branch ("-D" , branch .name )
260+ _logger .debug ("Deleted branch %s." , branch )
234261
235262 return branch .name
236263
237- def _changed_files (self , spec ) -> Sequence [str ]:
238- return self ._repo .git .diff (spec , name_only = True ).splitlines ()
264+ def _untracked (self ) -> frozenset [str ]:
265+ text = self ._repo .git .ls_files (exclude_standard = True , others = True )
266+ return frozenset (text .splitlines ())
267+
268+ def _delta (self , spec ) -> _Delta :
269+ changed = list [str ]()
270+ deleted = list [str ]()
271+ for line in self ._repo .git .diff (spec , name_status = True ).splitlines ():
272+ state , name = line .split (None , 1 )
273+ if state == "D" :
274+ deleted .append (name )
275+ else :
276+ changed .append (name )
277+ return _Delta (changed = frozenset (changed ), deleted = frozenset (deleted ))
278+
279+
280+ @dataclasses .dataclass (frozen = True )
281+ class _Delta :
282+ changed : frozenset [str ]
283+ deleted : frozenset [str ]
284+
285+ @property
286+ def touched (self ) -> frozenset [str ]:
287+ return self .changed | self .deleted
239288
240289
241290class _OperationRecorder (ToolVisitor ):
@@ -266,11 +315,11 @@ def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None:
266315 self ._record (reason , "delete_file" , path = str (path ))
267316
268317 def _record (self , reason : str | None , tool : str , ** kwargs ) -> None :
269- self .operations .append (
270- _Operation (
271- tool = tool , details = kwargs , reason = reason , start = datetime .now ()
272- )
318+ op = _Operation (
319+ tool = tool , details = kwargs , reason = reason , start = datetime .now ()
273320 )
321+ _logger .debug ("Recorded operation. [op=%s]" , op )
322+ self .operations .append (op )
274323
275324
276325@dataclasses .dataclass (frozen = True )
0 commit comments