44
55import dataclasses
66from datetime import datetime
7+ import enum
78import json
89import logging
910import os
@@ -57,6 +58,14 @@ def new_suffix():
5758 return random_id (9 )
5859
5960
61+ class Accept (enum .Enum ):
62+ """Valid accept modes"""
63+
64+ MANUAL = enum .auto ()
65+ CHECKOUT = enum .auto ()
66+ FINALIZE = enum .auto ()
67+
68+
6069class Drafter :
6170 """Draft state orchestrator"""
6271
@@ -77,12 +86,13 @@ def generate_draft( # noqa: PLR0913
7786 self ,
7887 prompt : str | TemplatedPrompt ,
7988 bot : Bot ,
89+ accept : Accept = Accept .MANUAL ,
8090 bot_name : str | None = None ,
81- tool_visitors : Sequence [ToolVisitor ] | None = None ,
8291 prompt_transform : Callable [[str ], str ] | None = None ,
8392 reset : bool = False ,
8493 sync : bool = False ,
8594 timeout : float | None = None ,
95+ tool_visitors : Sequence [ToolVisitor ] | None = None ,
8696 ) -> str :
8797 if timeout is not None :
8898 raise NotImplementedError () # TODO
@@ -172,7 +182,17 @@ def generate_draft( # noqa: PLR0913
172182 )
173183
174184 _logger .info ("Completed generation for %s." , branch )
175- return str (branch )
185+ if accept == Accept .MANUAL :
186+ return str (branch )
187+
188+ # Check out files from the index. Since we assume that users do not
189+ # manually update the index in draft branches, this is equivalent to
190+ # checking out the files from the latest (generated, here) commit.
191+ # delta = self._delta(
192+ self ._repo .git .checkout ("." , theirs = True )
193+ if accept == Accept .CHECKOUT :
194+ return str (branch )
195+ return self .exit_draft (revert = False , clean = accept == Accept .CLEAN )
176196
177197 def exit_draft (self , * , revert : bool , clean = False , delete = False ) -> str :
178198 branch = _Branch .active (self ._repo )
@@ -195,9 +215,10 @@ def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str:
195215 raise RuntimeError ("Parent branch has moved, please rebase first" )
196216
197217 if clean and not revert :
198- # We delete files which have been deleted in the draft manually,
218+ _logger .debug ("Cleaning up files." )
219+ # We manually delete files which have been deleted in the draft,
199220 # otherwise they would still show up as untracked.
200- origin_delta = self ._delta (f" { origin_branch } .. { branch } " )
221+ origin_delta = self ._delta (start = origin_branch , end = str ( branch ) )
201222 deleted = self ._untracked () & origin_delta .deleted
202223 for path in deleted :
203224 os .remove (osp .join (self ._repo .working_dir , path ))
@@ -211,17 +232,18 @@ def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str:
211232 self ._repo .git .checkout (origin_branch )
212233
213234 if revert :
235+ _logger .debug ("Reverting changes... [sync_sha=%s]" , sync_sha )
214236 # We revert the relevant files if needed. If a sync commit had been
215237 # created, we simply revert to it. Otherwise we compute which files
216238 # have changed due to draft commits and revert only those.
217239 if sync_sha :
218- delta = self ._delta (sync_sha )
219- if delta .changed :
220- self ._repo .git .checkout (sync_sha , "--" , "." )
240+ self ._repo .git .checkout ("-f" , sync_sha )
221241 _logger .info ("Reverted to sync commit. [sha=%s]" , sync_sha )
222242 else :
223- origin_delta = self ._delta (f"{ origin_branch } ..{ branch } " )
224- head_delta = self ._delta ("HEAD" )
243+ origin_delta = self ._delta (
244+ start = origin_branch , end = str (branch )
245+ )
246+ head_delta = self ._delta (end = "HEAD" )
225247 changed = head_delta .touched & origin_delta .changed
226248 if changed :
227249 self ._repo .git .checkout ("--" , * changed )
@@ -304,15 +326,18 @@ def _create_branch(self, sync: bool) -> _Branch:
304326 def _stage_changes (self , sync : bool ) -> str | None :
305327 self ._repo .git .add (all = True )
306328 if not sync or not self ._repo .is_dirty (untracked_files = True ):
329+ _logger .debug ("Skipped sync commit creation. [sync=%s]" , sync )
307330 return None
308331 ref = self ._repo .index .commit ("draft! sync" )
332+ _logger .debug ("Created sync commit. [sha=%s]" , ref .hexsha )
309333 return ref .hexsha
310334
311335 def _untracked (self ) -> frozenset [str ]:
312336 text = self ._repo .git .ls_files (exclude_standard = True , others = True )
313337 return frozenset (text .splitlines ())
314338
315- def _delta (self , spec ) -> _Delta :
339+ def _delta (self , * , start : str | None = None , end : str ) -> _Delta :
340+ spec = f"{ start } ..{ end } " if start else end
316341 changed = list [str ]()
317342 deleted = list [str ]()
318343 for line in self ._repo .git .diff (spec , name_status = True ).splitlines ():
@@ -321,7 +346,9 @@ def _delta(self, spec) -> _Delta:
321346 deleted .append (name )
322347 else :
323348 changed .append (name )
324- return _Delta (changed = frozenset (changed ), deleted = frozenset (deleted ))
349+ delta = _Delta (changed = frozenset (changed ), deleted = frozenset (deleted ))
350+ _logger .debug ("Computed delta for %s: %s" , spec , delta )
351+ return delta
325352
326353
327354@dataclasses .dataclass (frozen = True )
0 commit comments