Skip to content

Commit f103980

Browse files
committed
Use responses API
1 parent f10fb26 commit f103980

File tree

6 files changed

+102
-94
lines changed

6 files changed

+102
-94
lines changed

discovery/agent_support/agent.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from typing import List
55

66
from openai import OpenAI
7+
from openai.types.responses import EasyInputMessageParam, Response, ResponseFunctionToolCallParam, \
8+
ResponseFunctionToolCall
9+
from openai.types.responses.response_input_param import Message, ResponseInputParam, FunctionCallOutput
710

811
from discovery.agent_support.tool import Tool
912

@@ -23,51 +26,55 @@ class AgentResult:
2326

2427

2528
class Agent:
26-
def __init__(self, client: OpenAI, model: str, instructions: str, tools: List[Tool]):
29+
def __init__(self, client: OpenAI, model: str, system_instructions: str, tools: List[Tool]):
2730
self.client = client
31+
self.model = model
32+
self.instructions = system_instructions
2833
self.tools = tools
29-
self.assistant_id = client.beta.assistants.create(
30-
instructions=instructions,
31-
model=model,
32-
tools=[tool.schema() for tool in tools]
33-
).id
34+
self.tool_params = [tool.tool_param() for tool in tools]
3435

3536
def answer(self, question: str) -> AgentResult:
36-
thread = self.client.beta.threads.create()
37-
tool_calls = []
38-
self.client.beta.threads.messages.create(
39-
thread_id=thread.id,
40-
role="user",
41-
content=question,
42-
)
43-
44-
run = self.client.beta.threads.runs.create_and_poll(thread_id=thread.id, assistant_id=self.assistant_id)
45-
while run.status != "completed":
46-
logger.debug(f"status %s", run.status)
47-
tool_outputs = []
48-
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
49-
tool_name = tool_call.function.name
50-
arguments = json.loads(tool_call.function.arguments)
51-
logger.debug(f"calling %s with args %s", tool_name, arguments)
52-
53-
tool = next((tool for tool in self.tools if tool.name == tool_name), None)
54-
if tool is None:
55-
raise Exception(f"No tool found with name {tool_name}")
56-
57-
tool_calls.append(ToolCall(name=tool_name, arguments=arguments))
58-
tool_outputs.append({
59-
"tool_call_id": tool_call.id,
60-
"output": tool.action(**arguments),
61-
})
62-
63-
run = self.client.beta.threads.runs.submit_tool_outputs_and_poll(
64-
thread_id=thread.id,
65-
run_id=run.id,
66-
tool_outputs=tool_outputs,
67-
)
37+
messages: ResponseInputParam = [
38+
EasyInputMessageParam(role="system", content=self.instructions),
39+
EasyInputMessageParam(role="user", content=question)
40+
]
41+
42+
response: Response = self.client.responses.create(model=self.model, input=messages, tools=self.tool_params)
43+
44+
while response.output_text == "":
45+
for tool_call in response.output:
46+
if not isinstance(tool_call, ResponseFunctionToolCall):
47+
continue
48+
new_messages = self.invoke_tool(tool_call)
49+
messages.extend(new_messages)
50+
51+
response = self.client.responses.create(model=self.model, input=messages, tools=self.tool_params)
6852

69-
messages = self.client.beta.threads.messages.list(thread_id=thread.id)
70-
return AgentResult(
71-
response=messages.data[0].content[0].text.value,
72-
tool_calls=tool_calls,
73-
)
53+
tool_calls = [
54+
ToolCall(name=message["name"], arguments=json.loads(message["arguments"]))
55+
for message in messages if "type" in message and message["type"] == "function_call"
56+
]
57+
58+
return AgentResult(response=response.output_text, tool_calls=tool_calls)
59+
60+
def invoke_tool(self, tool_call: ResponseFunctionToolCall) -> ResponseInputParam:
61+
arguments = json.loads(tool_call.arguments)
62+
logger.debug(f"calling %s with args %s", tool_call.name, arguments)
63+
tool = next((tool for tool in self.tools if tool.name == tool_call.name), None)
64+
if tool is None:
65+
raise Exception(f"No tool found with name {tool_call.name}")
66+
67+
return [
68+
ResponseFunctionToolCallParam(
69+
id=tool_call.id,
70+
arguments=tool_call.arguments,
71+
call_id=tool_call.call_id,
72+
name=tool_call.name,
73+
type="function_call",
74+
),
75+
FunctionCallOutput(
76+
call_id=tool_call.call_id,
77+
output=tool.invoke(**arguments),
78+
type="function_call_output",
79+
)
80+
]

discovery/agent_support/tool.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from typing import Callable, List
33
from inspect import signature, Parameter
44

5+
from openai.types.responses import FunctionToolParam
6+
57

68
@dataclass
79
class Argument:
@@ -14,25 +16,25 @@ class Argument:
1416
class Tool:
1517
name: str
1618
description: str
17-
action: Callable
19+
invoke: Callable
1820
arguments: List[Argument]
1921

20-
def schema(self) -> dict:
21-
return {
22-
"type": "function",
23-
"function": {
24-
"name": self.name,
25-
"description": self.description,
26-
"parameters": {
27-
"type": "object",
28-
"properties": {
29-
argument.name: {"type": argument.type}
30-
for argument in self.arguments
31-
},
32-
"required": [argument.name for argument in self.arguments if argument.required],
33-
}
34-
}
35-
}
22+
def tool_param(self) -> FunctionToolParam:
23+
return FunctionToolParam(
24+
name=self.name,
25+
parameters={
26+
"type": "object",
27+
"properties": {
28+
argument.name: {"type": argument.type}
29+
for argument in self.arguments
30+
},
31+
"required": [argument.name for argument in self.arguments if argument.required],
32+
"additionalProperties": False,
33+
},
34+
strict=False,
35+
type="function",
36+
description=self.description,
37+
)
3638

3739

3840
def json_type(parameter: Parameter) -> str:
@@ -67,7 +69,7 @@ def wrapper(action: Callable) -> Tool:
6769
return Tool(
6870
name=action.__name__,
6971
description=action.__doc__,
70-
action=action,
72+
invoke=action,
7173
arguments=[argument_from_parameter(parameter) for parameter in parameters]
7274
)
7375

discovery/repository_agent/repository_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def repository_agent_creator(open_ai_client: OpenAI) -> Callable[[GithubClient],
1111
return lambda github_client: Agent(
1212
client=open_ai_client,
1313
model="gpt-4o-mini",
14-
instructions="""
14+
system_instructions="""
1515
You are a helpful assistant that can answer a user's questions about GitHub Repositories.
1616
Use the provided functions to answer the user's questions.
1717
When possible, prefer to use search functions over list functions to find lists of repositories matching certain

tests/agent_support/test_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
@tool()
12-
def get_temperature(city: str) -> str:
12+
def get_temperature(city: str, unrelated: str = "") -> str:
1313
"""Gets the temperature for a given city"""
1414
return "86"
1515

@@ -19,7 +19,7 @@ def test_answer(self):
1919
agent = Agent(
2020
client=OpenAI(api_key=require_env("OPEN_AI_KEY")),
2121
model="gpt-4o",
22-
instructions="You are a helpful assistant that can answer questions about weather. "
22+
system_instructions="You are a helpful assistant that can answer questions about weather. "
2323
"Use the only the functions provided to answer the user's question."
2424
"You must always use the provided function.",
2525
tools=[get_temperature],

tests/agent_support/test_tool.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import unittest
22

3+
from openai.types.responses import FunctionToolParam
4+
35
from discovery.agent_support.tool import tool, Argument
46

57

@@ -8,31 +10,32 @@ def some_tool(username: str, count: int, done: bool = False) -> str:
810
"""A helpful description"""
911
return f"Hello {username}, {count}, {done}"
1012

13+
1114
class TestTool(unittest.TestCase):
1215
def test_tool_attributes(self):
1316
self.assertEqual("some_tool", some_tool.name)
1417
self.assertEqual("A helpful description", some_tool.description)
15-
self.assertEqual("Hello fred, 3, True", some_tool.action("fred", 3, True))
18+
self.assertEqual("Hello fred, 3, True", some_tool.invoke("fred", 3, True))
1619
self.assertEqual([
1720
Argument(name="username", type="string", required=True),
1821
Argument(name="count", type="number", required=True),
1922
Argument(name="done", type="boolean", required=False),
2023
], some_tool.arguments)
2124

2225
def test_tool_schema(self):
23-
self.assertEqual({
24-
"type": "function",
25-
"function": {
26-
"name": "some_tool",
27-
"description": "A helpful description",
28-
"parameters": {
29-
"type": "object",
30-
"properties": {
31-
"username": {"type": "string"},
32-
"count": {"type": "number"},
33-
"done": {"type": "boolean"},
34-
},
35-
"required": ["username", "count"],
36-
}
26+
self.assertEqual(FunctionToolParam(
27+
type="function",
28+
name="some_tool",
29+
strict=False,
30+
description="A helpful description",
31+
parameters={
32+
"type": "object",
33+
"properties": {
34+
"username": {"type": "string"},
35+
"count": {"type": "number"},
36+
"done": {"type": "boolean"},
37+
},
38+
"required": ["username", "count"],
39+
'additionalProperties': False,
3740
}
38-
}, some_tool.schema())
41+
), some_tool.tool_param())

tests/repository_agent/test_repository_agent.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ def test_list_repository_languages(self):
115115
responses.GET,
116116
"https://api.github.com/repos/pickles_org/pickles_repo/languages",
117117
json={
118-
"python": 100,
119-
"html": 50,
120-
"css": 50,
118+
"Python": 100,
119+
"HTML": 50,
120+
"CSS": 50,
121121
},
122122
status=200,
123123
)
@@ -129,10 +129,7 @@ def test_list_repository_languages(self):
129129
self.assertIn("CSS", result.response)
130130
self.assertNotIn("java", result.response)
131131
self.assertEqual(
132-
[ToolCall(
133-
name='list_repository_languages',
134-
arguments={'full_name': 'pickles_org/pickles_repo'}
135-
)],
132+
[ToolCall(name='list_repository_languages', arguments={'full_name': 'pickles_org/pickles_repo'})],
136133
result.tool_calls
137134
)
138135

@@ -143,19 +140,18 @@ def test_list_repository_contributors(self):
143140
responses.GET,
144141
"https://api.github.com/repos/pickles_org/pickles_repo/contributors",
145142
json=[
146-
{"login": "fred"},
147-
{"login": "mary"},
148-
{"login": "kate"},
143+
{"login": "Fred"},
144+
{"login": "Mary"},
145+
{"login": "Kate"},
149146
],
150147
status=200,
151148
)
152149

153150
result = self.agent.answer("Who contributes to the pickles_repo within the pickles_org?")
154151

155-
self.assertIn("fred", result.response)
156-
self.assertIn("mary", result.response)
157-
self.assertIn("kate", result.response)
158-
self.assertNotIn("chuck", result.response)
152+
self.assertIn("Fred", result.response)
153+
self.assertIn("Mary", result.response)
154+
self.assertIn("Kate", result.response)
159155
self.assertEqual(
160156
[ToolCall(
161157
name='list_repository_contributors',

0 commit comments

Comments
 (0)