1010import re
1111import textwrap
1212import time
13- from typing import Match , Sequence
13+ from typing import Callable , Match , Sequence
1414
1515import git
1616
@@ -77,17 +77,17 @@ def generate_draft(
7777 bot : Bot ,
7878 bot_name : str | None = None ,
7979 tool_visitors : Sequence [ToolVisitor ] | None = None ,
80+ prompt_transform : Callable [[str ], str ] | None = None ,
8081 reset : bool = False ,
8182 sync : bool = False ,
8283 timeout : float | None = None ,
8384 ) -> str :
84- if isinstance (prompt , str ) and not prompt .strip ():
85- raise ValueError ("Empty prompt" )
8685 if self ._repo .is_dirty (working_tree = False ):
8786 if not reset :
8887 raise ValueError ("Please commit or reset any staged changes" )
8988 self ._repo .index .reset ()
9089
90+ # Ensure that we are on a draft branch.
9191 branch = _Branch .active (self ._repo )
9292 if branch :
9393 self ._stage_changes (sync )
@@ -96,17 +96,18 @@ def generate_draft(
9696 branch = self ._create_branch (sync )
9797 _logger .debug ("Created branch %s." , branch )
9898
99- operation_recorder = _OperationRecorder ()
100- tool_visitors = [operation_recorder ] + list (tool_visitors or [])
101- toolbox = StagingToolbox (self ._repo , tool_visitors )
99+ # Handle prompt templating and editing.
102100 if isinstance (prompt , TemplatedPrompt ):
103101 template : str | None = prompt .template
104- renderer = PromptRenderer .for_toolbox (toolbox )
102+ renderer = PromptRenderer .for_toolbox (StagingToolbox ( self . _repo ) )
105103 prompt_contents = renderer .render (prompt )
106104 else :
107105 template = None
108106 prompt_contents = prompt
109-
107+ if prompt_transform :
108+ prompt_contents = prompt_transform (prompt_contents )
109+ if not prompt_contents .strip ():
110+ raise ValueError ("Aborting: empty prompt" )
110111 with self ._store .cursor () as cursor :
111112 [(prompt_id ,)] = cursor .execute (
112113 sql ("add-prompt" ),
@@ -117,14 +118,19 @@ def generate_draft(
117118 },
118119 )
119120
121+ # Trigger code generation.
120122 _logger .debug ("Running bot... [bot=%s]" , bot )
123+ operation_recorder = _OperationRecorder ()
124+ tool_visitors = [operation_recorder ] + list (tool_visitors or [])
125+ toolbox = StagingToolbox (self ._repo , tool_visitors )
121126 start_time = time .perf_counter ()
122127 goal = Goal (prompt_contents , timeout )
123128 action = bot .act (goal , toolbox )
124129 end_time = time .perf_counter ()
125130 walltime = end_time - start_time
126131 _logger .info ("Completed bot action. [action=%s]" , action )
127132
133+ # Generate an appropriate commit and update our database.
128134 toolbox .trim_index ()
129135 title = action .title
130136 if not title :
@@ -133,7 +139,6 @@ def generate_draft(
133139 f"draft! { title } \n \n { prompt_contents } " ,
134140 skip_hooks = True ,
135141 )
136-
137142 with self ._store .cursor () as cursor :
138143 cursor .execute (
139144 sql ("add-action" ),
@@ -159,7 +164,7 @@ def generate_draft(
159164 ],
160165 )
161166
162- _logger .info ("Generated %s." , branch )
167+ _logger .info ("Completed generation for %s." , branch )
163168 return str (branch )
164169
165170 def exit_draft (self , * , revert : bool , clean = False , delete = False ) -> str :
@@ -232,22 +237,35 @@ def exit_draft(self, *, revert: bool, clean=False, delete=False) -> str:
232237 def history_table (self , branch_name : str | None = None ) -> Table :
233238 path = self ._repo .working_dir
234239 branch = _Branch .active (self ._repo , branch_name )
235- if branch :
236- with self . _store . cursor () as cursor :
240+ with self . _store . cursor () as cursor :
241+ if branch :
237242 results = cursor .execute (
238243 sql ("list-prompts" ),
239244 {
240245 "repo_path" : path ,
241246 "branch_suffix" : branch .suffix ,
242247 },
243248 )
244- return Table .from_cursor (results )
245- else :
246- with self ._store .cursor () as cursor :
249+ else :
247250 results = cursor .execute (
248251 sql ("list-drafts" ), {"repo_path" : path }
249252 )
250- return Table .from_cursor (results )
253+ return Table .from_cursor (results )
254+
255+ def latest_draft_prompt (self ) -> str | None :
256+ """Returns the latest prompt for the current draft"""
257+ branch = _Branch .active (self ._repo )
258+ if not branch :
259+ return None
260+ with self ._store .cursor () as cursor :
261+ result = cursor .execute (
262+ sql ("get-latest-prompt" ),
263+ {
264+ "repo_path" : self ._repo .working_dir ,
265+ "branch_suffix" : branch .suffix ,
266+ },
267+ ).fetchone ()
268+ return result [0 ] if result else None
251269
252270 def _create_branch (self , sync : bool ) -> _Branch :
253271 if self ._repo .head .is_detached :
0 commit comments