1313
1414from .bots import Bot , OperationHook , Toolbox
1515from .common import Store , random_id , sql
16+ from .prompt import PromptRenderer , TemplatedPrompt
1617
1718
1819_logger = logging .getLogger (__name__ )
@@ -135,13 +136,13 @@ def _create_branch(self, sync: bool) -> _Branch:
135136
136137 def generate_draft (
137138 self ,
138- prompt : str ,
139+ prompt : str | TemplatedPrompt ,
139140 bot : Bot ,
140141 checkout = False ,
141142 reset = False ,
142143 sync = False ,
143144 ) -> None :
144- if not prompt .strip ():
145+ if isinstance ( prompt , str ) and not prompt .strip ():
145146 raise ValueError ("Empty prompt" )
146147 if self ._repo .is_dirty (working_tree = False ):
147148 if not reset :
@@ -157,24 +158,31 @@ def generate_draft(
157158 branch = self ._create_branch (sync )
158159 _logger .debug ("Created branch %s." , branch )
159160
161+ if isinstance (prompt , TemplatedPrompt ):
162+ renderer = PromptRenderer .for_repo (self ._repo )
163+ prompt_contents = renderer .render (prompt )
164+ else :
165+ prompt_contents = prompt
160166 with self ._store .cursor () as cursor :
161167 [(prompt_id ,)] = cursor .execute (
162168 sql ("add-prompt" ),
163169 {
164170 "branch_suffix" : branch .suffix ,
165- "contents" : prompt ,
171+ "contents" : prompt_contents ,
166172 },
167173 )
168174
169175 start_time = time .perf_counter ()
170176 toolbox = _Toolbox (self ._repo , self ._operation_hook )
171- action = bot .act (prompt , toolbox )
177+ action = bot .act (prompt_contents , toolbox )
172178 end_time = time .perf_counter ()
173179
174180 title = action .title
175181 if not title :
176- title = textwrap .shorten (prompt , break_on_hyphens = False , width = 72 )
177- commit = self ._repo .index .commit (f"draft! { title } \n \n { prompt } " )
182+ title = _default_title (prompt_contents )
183+ commit = self ._repo .index .commit (
184+ f"draft! { title } \n \n { prompt_contents } "
185+ )
178186
179187 with self ._store .cursor () as cursor :
180188 cursor .execute (
@@ -247,3 +255,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
247255 self ._repo .git .checkout (sync_sha , "--" , "." )
248256 if delete :
249257 self ._repo .git .branch ("-D" , branch .name )
258+
259+
260+ def _default_title (prompt : str ) -> str :
261+ return textwrap .shorten (prompt , break_on_hyphens = False , width = 72 )
0 commit comments