2121import openai
2222
2323from ..common import JSONObject , UnreachableError , config_string , reindent
24- from ..feedback import Feedback
25- from ..toolbox import Toolbox
26- from .common import Action , Bot , Goal
24+ from .common import Action , Bot , Goal , UserFeedback , WorkTree
2725
2826
2927_logger = logging .getLogger (__name__ )
@@ -175,8 +173,8 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]:
175173
176174
177175class _ToolHandler [V ]:
178- def __init__ (self , toolbox : Toolbox ) -> None :
179- self ._toolbox = toolbox
176+ def __init__ (self , tree : WorkTree ) -> None :
177+ self ._tree = tree
180178 self .question : str | None = None
181179
182180 def _on_ask_user (self ) -> V :
@@ -209,23 +207,23 @@ def handle_function(self, function: Any) -> V:
209207 return self ._on_ask_user ()
210208 case "read_file" :
211209 path = PurePosixPath (inputs ["path" ])
212- return self ._on_read_file (path , self ._toolbox .read_file (path ))
210+ return self ._on_read_file (path , self ._tree .read_file (path ))
213211 case "write_file" :
214212 path = PurePosixPath (inputs ["path" ])
215213 contents = inputs ["contents" ]
216- self ._toolbox .write_file (path , contents )
214+ self ._tree .write_file (path , contents )
217215 return self ._on_write_file (path )
218216 case "delete_file" :
219217 path = PurePosixPath (inputs ["path" ])
220- self ._toolbox .delete_file (path )
218+ self ._tree .delete_file (path )
221219 return self ._on_delete_file (path )
222220 case "rename_file" :
223221 src_path = PurePosixPath (inputs ["src_path" ])
224222 dst_path = PurePosixPath (inputs ["dst_path" ])
225- self ._toolbox .rename_file (src_path , dst_path )
223+ self ._tree .rename_file (src_path , dst_path )
226224 return self ._on_rename_file (src_path , dst_path )
227225 case "list_files" :
228- paths = self ._toolbox .list_files ()
226+ paths = self ._tree .list_files ()
229227 return self ._on_list_files (paths )
230228 case _ as name :
231229 raise UnreachableError (f"Unexpected function: { name } " )
@@ -237,10 +235,10 @@ def __init__(self, client: openai.OpenAI, model: str) -> None:
237235 self ._model = model
238236
239237 async def act (
240- self , goal : Goal , toolbox : Toolbox , _feedback : Feedback
238+ self , goal : Goal , tree : WorkTree , _feedback : UserFeedback
241239 ) -> Action :
242240 tools = _ToolsFactory (strict = False ).params ()
243- tool_handler = _CompletionsToolHandler (toolbox )
241+ tool_handler = _CompletionsToolHandler (tree )
244242
245243 messages : list [openai .types .chat .ChatCompletionMessageParam ] = [
246244 {"role" : "system" , "content" : reindent (_INSTRUCTIONS )},
@@ -323,7 +321,7 @@ def _load_assistant_id(self) -> str:
323321 return assistant_id
324322
325323 async def act (
326- self , goal : Goal , toolbox : Toolbox , _feedback : Feedback
324+ self , goal : Goal , tree : WorkTree , _feedback : UserFeedback
327325 ) -> Action :
328326 assistant_id = self ._load_assistant_id ()
329327
@@ -340,24 +338,24 @@ async def act(
340338 with self ._client .beta .threads .runs .stream (
341339 thread_id = thread .id ,
342340 assistant_id = assistant_id ,
343- event_handler = _EventHandler (self ._client , toolbox , action ),
341+ event_handler = _EventHandler (self ._client , tree , action ),
344342 ) as stream :
345343 stream .until_done ()
346344 return action
347345
348346
349347class _EventHandler (openai .AssistantEventHandler ):
350348 def __init__ (
351- self , client : openai .Client , toolbox : Toolbox , action : Action
349+ self , client : openai .Client , tree : WorkTree , action : Action
352350 ) -> None :
353351 super ().__init__ ()
354352 self ._client = client
355- self ._toolbox = toolbox
353+ self ._tree = tree
356354 self ._action = action
357355 self ._action .increment_request_count ()
358356
359357 def _clone (self ) -> Self :
360- return self .__class__ (self ._client , self ._toolbox , self ._action )
358+ return self .__class__ (self ._client , self ._tree , self ._action )
361359
362360 @override
363361 def on_event (self , event : openai .types .beta .AssistantStreamEvent ) -> None :
@@ -383,7 +381,7 @@ def on_run_step_done(
383381 def _handle_action (self , _run_id : str , data : Any ) -> None :
384382 tool_outputs = list [Any ]()
385383 for tool in data .required_action .submit_tool_outputs .tool_calls :
386- handler = _ThreadToolHandler (self ._toolbox , tool .id )
384+ handler = _ThreadToolHandler (self ._tree , tool .id )
387385 tool_outputs .append (handler .handle_function (tool .function ))
388386 if handler .question :
389387 assert not self ._action .question
@@ -406,8 +404,8 @@ class _ToolOutput(TypedDict):
406404
407405
408406class _ThreadToolHandler (_ToolHandler [_ToolOutput ]):
409- def __init__ (self , toolbox : Toolbox , call_id : str ) -> None :
410- super ().__init__ (toolbox )
407+ def __init__ (self , tree : WorkTree , call_id : str ) -> None :
408+ super ().__init__ (tree )
411409 self ._call_id = call_id
412410
413411 def _wrap (self , output : str ) -> _ToolOutput :
0 commit comments