44from typing import List
55
66from openai import OpenAI
7+ from openai .types .responses import EasyInputMessageParam , Response , ResponseFunctionToolCallParam , \
8+ ResponseFunctionToolCall
9+ from openai .types .responses .response_input_param import Message , ResponseInputParam , FunctionCallOutput
710
811from discovery .agent_support .tool import Tool
912
@@ -23,51 +26,55 @@ class AgentResult:
2326
2427
2528class Agent :
26- def __init__ (self , client : OpenAI , model : str , instructions : str , tools : List [Tool ]):
29+ def __init__ (self , client : OpenAI , model : str , system_instructions : str , tools : List [Tool ]):
2730 self .client = client
31+ self .model = model
32+ self .instructions = system_instructions
2833 self .tools = tools
29- self .assistant_id = client .beta .assistants .create (
30- instructions = instructions ,
31- model = model ,
32- tools = [tool .schema () for tool in tools ]
33- ).id
34+ self .tool_params = [tool .tool_param () for tool in tools ]
3435
3536 def answer (self , question : str ) -> AgentResult :
36- thread = self .client .beta .threads .create ()
37- tool_calls = []
38- self .client .beta .threads .messages .create (
39- thread_id = thread .id ,
40- role = "user" ,
41- content = question ,
42- )
43-
44- run = self .client .beta .threads .runs .create_and_poll (thread_id = thread .id , assistant_id = self .assistant_id )
45- while run .status != "completed" :
46- logger .debug (f"status %s" , run .status )
47- tool_outputs = []
48- for tool_call in run .required_action .submit_tool_outputs .tool_calls :
49- tool_name = tool_call .function .name
50- arguments = json .loads (tool_call .function .arguments )
51- logger .debug (f"calling %s with args %s" , tool_name , arguments )
52-
53- tool = next ((tool for tool in self .tools if tool .name == tool_name ), None )
54- if tool is None :
55- raise Exception (f"No tool found with name { tool_name } " )
56-
57- tool_calls .append (ToolCall (name = tool_name , arguments = arguments ))
58- tool_outputs .append ({
59- "tool_call_id" : tool_call .id ,
60- "output" : tool .action (** arguments ),
61- })
62-
63- run = self .client .beta .threads .runs .submit_tool_outputs_and_poll (
64- thread_id = thread .id ,
65- run_id = run .id ,
66- tool_outputs = tool_outputs ,
67- )
37+ messages : ResponseInputParam = [
38+ EasyInputMessageParam (role = "system" , content = self .instructions ),
39+ EasyInputMessageParam (role = "user" , content = question )
40+ ]
41+
42+ response : Response = self .client .responses .create (model = self .model , input = messages , tools = self .tool_params )
43+
44+ while response .output_text == "" :
45+ for tool_call in response .output :
46+ if not isinstance (tool_call , ResponseFunctionToolCall ):
47+ continue
48+ new_messages = self .invoke_tool (tool_call )
49+ messages .extend (new_messages )
50+
51+ response = self .client .responses .create (model = self .model , input = messages , tools = self .tool_params )
6852
69- messages = self .client .beta .threads .messages .list (thread_id = thread .id )
70- return AgentResult (
71- response = messages .data [0 ].content [0 ].text .value ,
72- tool_calls = tool_calls ,
73- )
53+ tool_calls = [
54+ ToolCall (name = message ["name" ], arguments = json .loads (message ["arguments" ]))
55+ for message in messages if "type" in message and message ["type" ] == "function_call"
56+ ]
57+
58+ return AgentResult (response = response .output_text , tool_calls = tool_calls )
59+
60+ def invoke_tool (self , tool_call : ResponseFunctionToolCall ) -> ResponseInputParam :
61+ arguments = json .loads (tool_call .arguments )
62+ logger .debug (f"calling %s with args %s" , tool_call .name , arguments )
63+ tool = next ((tool for tool in self .tools if tool .name == tool_call .name ), None )
64+ if tool is None :
65+ raise Exception (f"No tool found with name { tool_call .name } " )
66+
67+ return [
68+ ResponseFunctionToolCallParam (
69+ id = tool_call .id ,
70+ arguments = tool_call .arguments ,
71+ call_id = tool_call .call_id ,
72+ name = tool_call .name ,
73+ type = "function_call" ,
74+ ),
75+ FunctionCallOutput (
76+ call_id = tool_call .call_id ,
77+ output = tool .invoke (** arguments ),
78+ type = "function_call_output" ,
79+ )
80+ ]
0 commit comments