@@ -173,11 +173,11 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]:
173173
174174
175175class _ToolHandler [V ]:
176- def __init__ (self , tree : Worktree ) -> None :
176+ def __init__ (self , tree : Worktree , feedback : UserFeedback ) -> None :
177177 self ._tree = tree
178- self .question : str | None = None
178+ self ._feedback = feedback
179179
180- def _on_ask_user (self ) -> V :
180+ def _on_ask_user (self , response : str ) -> V :
181181 raise NotImplementedError ()
182182
183183 def _on_read_file (self , path : PurePosixPath , contents : str | None ) -> V :
@@ -202,9 +202,9 @@ def handle_function(self, function: Any) -> V:
202202 _logger .info ("Requested function: %s" , function )
203203 match function .name :
204204 case "ask_user" :
205- assert not self . question
206- self . question = inputs [ " question" ]
207- return self ._on_ask_user ()
205+ question = inputs [ " question" ]
206+ response = self . _feedback . ask ( question )
207+ return self ._on_ask_user (response )
208208 case "read_file" :
209209 path = PurePosixPath (inputs ["path" ])
210210 return self ._on_read_file (path , self ._tree .read_file (path ))
@@ -235,10 +235,10 @@ def __init__(self, client: openai.OpenAI, model: str) -> None:
235235 self ._model = model
236236
237237 async def act (
238- self , goal : Goal , tree : Worktree , _feedback : UserFeedback
238+ self , goal : Goal , tree : Worktree , feedback : UserFeedback
239239 ) -> Action :
240240 tools = _ToolsFactory (strict = False ).params ()
241- tool_handler = _CompletionsToolHandler (tree )
241+ tool_handler = _CompletionsToolHandler (tree , feedback )
242242
243243 messages : list [openai .types .chat .ChatCompletionMessageParam ] = [
244244 {"role" : "system" , "content" : reindent (_INSTRUCTIONS )},
@@ -266,15 +266,12 @@ async def act(
266266 if done :
267267 break
268268
269- return Action (
270- request_count = request_count ,
271- question = tool_handler .question ,
272- )
269+ return Action (request_count = request_count )
273270
274271
275272class _CompletionsToolHandler (_ToolHandler [str | None ]):
276- def _on_ask_user (self ) -> None :
277- return None
273+ def _on_ask_user (self , response : str ) -> str :
274+ return response
278275
279276 def _on_read_file (self , path : PurePosixPath , contents : str | None ) -> str :
280277 if contents is None :
@@ -321,7 +318,7 @@ def _load_assistant_id(self) -> str:
321318 return assistant_id
322319
323320 async def act (
324- self , goal : Goal , tree : Worktree , _feedback : UserFeedback
321+ self , goal : Goal , tree : Worktree , feedback : UserFeedback
325322 ) -> Action :
326323 assistant_id = self ._load_assistant_id ()
327324
@@ -338,24 +335,29 @@ async def act(
338335 with self ._client .beta .threads .runs .stream (
339336 thread_id = thread .id ,
340337 assistant_id = assistant_id ,
341- event_handler = _EventHandler (self ._client , tree , action ),
338+ event_handler = _EventHandler (self ._client , tree , feedback , action ),
342339 ) as stream :
343340 stream .until_done ()
344341 return action
345342
346343
347344class _EventHandler (openai .AssistantEventHandler ):
348345 def __init__ (
349- self , client : openai .Client , tree : Worktree , action : Action
346+ self , client : openai .Client , tree : Worktree ,
347+ feedback : UserFeedback ,
348+ action : Action ,
350349 ) -> None :
351350 super ().__init__ ()
352351 self ._client = client
353352 self ._tree = tree
353+ self ._feedback = feedback
354354 self ._action = action
355355 self ._action .increment_request_count ()
356356
357357 def _clone (self ) -> Self :
358- return self .__class__ (self ._client , self ._tree , self ._action )
358+ return self .__class__ (
359+ self ._client , self ._tree , self ._feedback , self ._action
360+ )
359361
360362 @override
361363 def on_event (self , event : openai .types .beta .AssistantStreamEvent ) -> None :
@@ -381,11 +383,8 @@ def on_run_step_done(
381383 def _handle_action (self , _run_id : str , data : Any ) -> None :
382384 tool_outputs = list [Any ]()
383385 for tool in data .required_action .submit_tool_outputs .tool_calls :
384- handler = _ThreadToolHandler (self ._tree , tool .id )
386+ handler = _ThreadToolHandler (self ._tree , self . _feedback , tool .id )
385387 tool_outputs .append (handler .handle_function (tool .function ))
386- if handler .question :
387- assert not self ._action .question
388- self ._action .question = handler .question
389388
390389 run = self .current_run
391390 assert run , "No ongoing run"
@@ -404,15 +403,17 @@ class _ToolOutput(TypedDict):
404403
405404
406405class _ThreadToolHandler (_ToolHandler [_ToolOutput ]):
407- def __init__ (self , tree : Worktree , call_id : str ) -> None :
408- super ().__init__ (tree )
406+ def __init__ (
407+ self , tree : Worktree , feedback : UserFeedback , call_id : str
408+ ) -> None :
409+ super ().__init__ (tree , feedback )
409410 self ._call_id = call_id
410411
411412 def _wrap (self , output : str ) -> _ToolOutput :
412413 return _ToolOutput (tool_call_id = self ._call_id , output = output )
413414
414- def _on_ask_user (self ) -> _ToolOutput :
415- return self ._wrap ("OK" )
415+ def _on_ask_user (self , response : str ) -> _ToolOutput :
416+ return self ._wrap (response )
416417
417418 def _on_read_file (
418419 self , _path : PurePosixPath , contents : str | None
0 commit comments