Commit 9797d39: OpenAI external tools (home-assistant#150599)
1 parent: e68df66

File tree: 4 files changed, +237 −48 lines changed

homeassistant/components/openai_conversation/entity.py
Lines changed: 133 additions & 46 deletions
```diff
@@ -14,13 +14,16 @@
 from openai.types.responses import (
     EasyInputMessageParam,
     FunctionToolParam,
+    ResponseCodeInterpreterToolCall,
     ResponseCompletedEvent,
     ResponseErrorEvent,
     ResponseFailedEvent,
     ResponseFunctionCallArgumentsDeltaEvent,
     ResponseFunctionCallArgumentsDoneEvent,
     ResponseFunctionToolCall,
     ResponseFunctionToolCallParam,
+    ResponseFunctionWebSearch,
+    ResponseFunctionWebSearchParam,
     ResponseIncompleteEvent,
     ResponseInputFileParam,
     ResponseInputImageParam,
@@ -149,16 +152,27 @@ def _convert_content_to_param(
     """Convert any native chat message for this agent to the native format."""
     messages: ResponseInputParam = []
     reasoning_summary: list[str] = []
+    web_search_calls: dict[str, ResponseFunctionWebSearchParam] = {}
 
     for content in chat_content:
         if isinstance(content, conversation.ToolResultContent):
-            messages.append(
-                FunctionCallOutput(
-                    type="function_call_output",
-                    call_id=content.tool_call_id,
-                    output=json.dumps(content.tool_result),
+            if (
+                content.tool_name == "web_search_call"
+                and content.tool_call_id in web_search_calls
+            ):
+                web_search_call = web_search_calls.pop(content.tool_call_id)
+                web_search_call["status"] = content.tool_result.get(  # type: ignore[typeddict-item]
+                    "status", "completed"
+                )
+                messages.append(web_search_call)
+            else:
+                messages.append(
+                    FunctionCallOutput(
+                        type="function_call_output",
+                        call_id=content.tool_call_id,
+                        output=json.dumps(content.tool_result),
+                    )
                 )
-            )
             continue
 
         if content.content:
@@ -173,15 +187,27 @@ def _convert_content_to_param(
 
         if isinstance(content, conversation.AssistantContent):
             if content.tool_calls:
-                messages.extend(
-                    ResponseFunctionToolCallParam(
-                        type="function_call",
-                        name=tool_call.tool_name,
-                        arguments=json.dumps(tool_call.tool_args),
-                        call_id=tool_call.id,
-                    )
-                    for tool_call in content.tool_calls
-                )
+                for tool_call in content.tool_calls:
+                    if (
+                        tool_call.external
+                        and tool_call.tool_name == "web_search_call"
+                        and "action" in tool_call.tool_args
+                    ):
+                        web_search_calls[tool_call.id] = ResponseFunctionWebSearchParam(
+                            type="web_search_call",
+                            id=tool_call.id,
+                            action=tool_call.tool_args["action"],
+                            status="completed",
+                        )
+                    else:
+                        messages.append(
+                            ResponseFunctionToolCallParam(
+                                type="function_call",
+                                name=tool_call.tool_name,
+                                arguments=json.dumps(tool_call.tool_args),
+                                call_id=tool_call.id,
+                            )
+                        )
 
             if content.thinking_content:
                 reasoning_summary.append(content.thinking_content)
```
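To make the web-search replay concrete, here is a minimal, self-contained sketch of the two branches above, using plain dicts in place of the openai TypedDicts (ResponseFunctionWebSearchParam, FunctionCallOutput); the helper names and sample values are illustrative, not part of this commit:

```python
import json

# Buffer of external web searches keyed by call id, mirroring web_search_calls.
web_search_calls: dict[str, dict] = {}
messages: list[dict] = []


def add_assistant_tool_call(call_id: str, name: str, args: dict, external: bool) -> None:
    """Mirror of the AssistantContent branch: external web searches are
    buffered instead of being serialized as a regular function_call."""
    if external and name == "web_search_call" and "action" in args:
        web_search_calls[call_id] = {
            "type": "web_search_call",
            "id": call_id,
            "action": args["action"],
            "status": "completed",
        }
    else:
        messages.append({
            "type": "function_call",
            "name": name,
            "arguments": json.dumps(args),
            "call_id": call_id,
        })


def add_tool_result(call_id: str, name: str, result: dict) -> None:
    """Mirror of the ToolResultContent branch: a buffered web search is
    replayed as a single web_search_call item carrying the final status."""
    if name == "web_search_call" and call_id in web_search_calls:
        call = web_search_calls.pop(call_id)
        call["status"] = result.get("status", "completed")
        messages.append(call)
    else:
        messages.append({
            "type": "function_call_output",
            "call_id": call_id,
            "output": json.dumps(result),
        })


add_assistant_tool_call(
    "ws_A", "web_search_call", {"action": {"type": "search", "query": "query"}}, True
)
add_tool_result("ws_A", "web_search_call", {"status": "completed"})
print(messages)
# [{'type': 'web_search_call', 'id': 'ws_A',
#   'action': {'type': 'search', 'query': 'query'}, 'status': 'completed'}]
```

The round trip produces exactly one web_search_call item per search, which is what the test_web_search snapshot further down asserts.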
```diff
@@ -211,25 +237,37 @@ def _convert_content_to_param(
 async def _transform_stream(
     chat_log: conversation.ChatLog,
     stream: AsyncStream[ResponseStreamEvent],
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+) -> AsyncGenerator[
+    conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict
+]:
     """Transform an OpenAI delta stream into HA format."""
     last_summary_index = None
+    last_role: Literal["assistant", "tool_result"] | None = None
 
     async for event in stream:
         LOGGER.debug("Received event: %s", event)
 
         if isinstance(event, ResponseOutputItemAddedEvent):
-            if isinstance(event.item, ResponseOutputMessage):
-                yield {"role": event.item.role}
-                last_summary_index = None
-            elif isinstance(event.item, ResponseFunctionToolCall):
+            if isinstance(event.item, ResponseFunctionToolCall):
                 # OpenAI has tool calls as individual events
                 # while HA puts tool calls inside the assistant message.
                 # We turn them into individual assistant content for HA
                 # to ensure that tools are called as soon as possible.
                 yield {"role": "assistant"}
+                last_role = "assistant"
                 last_summary_index = None
                 current_tool_call = event.item
+            elif (
+                isinstance(event.item, ResponseOutputMessage)
+                or (
+                    isinstance(event.item, ResponseReasoningItem)
+                    and last_summary_index is not None
+                )  # Subsequent ResponseReasoningItem
+                or last_role != "assistant"
+            ):
+                yield {"role": "assistant"}
+                last_role = "assistant"
+                last_summary_index = None
         elif isinstance(event, ResponseOutputItemDoneEvent):
             if isinstance(event.item, ResponseReasoningItem):
                 yield {
@@ -240,6 +278,52 @@ async def _transform_stream(
                         encrypted_content=event.item.encrypted_content,
                     )
                 }
+                last_summary_index = len(event.item.summary) - 1
+            elif isinstance(event.item, ResponseCodeInterpreterToolCall):
+                yield {
+                    "tool_calls": [
+                        llm.ToolInput(
+                            id=event.item.id,
+                            tool_name="code_interpreter",
+                            tool_args={
+                                "code": event.item.code,
+                                "container": event.item.container_id,
+                            },
+                            external=True,
+                        )
+                    ]
+                }
+                yield {
+                    "role": "tool_result",
+                    "tool_call_id": event.item.id,
+                    "tool_name": "code_interpreter",
+                    "tool_result": {
+                        "output": [output.to_dict() for output in event.item.outputs]  # type: ignore[misc]
+                        if event.item.outputs is not None
+                        else None
+                    },
+                }
+                last_role = "tool_result"
+            elif isinstance(event.item, ResponseFunctionWebSearch):
+                yield {
+                    "tool_calls": [
+                        llm.ToolInput(
+                            id=event.item.id,
+                            tool_name="web_search_call",
+                            tool_args={
+                                "action": event.item.action.to_dict(),
+                            },
+                            external=True,
+                        )
+                    ]
+                }
+                yield {
+                    "role": "tool_result",
+                    "tool_call_id": event.item.id,
+                    "tool_name": "web_search_call",
+                    "tool_result": {"status": event.item.status},
+                }
+                last_role = "tool_result"
         elif isinstance(event, ResponseTextDeltaEvent):
             yield {"content": event.delta}
         elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
@@ -252,6 +336,7 @@
                 and event.summary_index != last_summary_index
             ):
                 yield {"role": "assistant"}
+                last_role = "assistant"
                 last_summary_index = event.summary_index
             yield {"thinking_content": event.delta}
         elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
```
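External tools change the shape of what _transform_stream yields: because OpenAI runs these tools server-side, a tool_calls delta is followed immediately by a tool_result delta in the same stream, and last_role tracks when a fresh assistant message must be opened. A toy consumer with hand-written deltas (the values are illustrative, not actual API output):

```python
import asyncio
from collections.abc import AsyncGenerator


async def fake_transformed_stream() -> AsyncGenerator[dict, None]:
    """Hand-written stand-in for _transform_stream output during a web search."""
    yield {"role": "assistant"}
    # External tool call, emitted on ResponseOutputItemDoneEvent.
    yield {"tool_calls": [{
        "id": "ws_A",
        "tool_name": "web_search_call",
        "tool_args": {"action": {"type": "search", "query": "query"}},
        "external": True,
    }]}
    # Unlike HA-executed tools, the result arrives in the same stream,
    # so a tool_result delta follows immediately.
    yield {"role": "tool_result",
           "tool_call_id": "ws_A",
           "tool_name": "web_search_call",
           "tool_result": {"status": "completed"}}
    # The final answer opens a new assistant message (last_role != "assistant").
    yield {"role": "assistant"}
    yield {"content": "Home Assistant now supports ChatGPT Search in Assist"}


async def main() -> None:
    async for delta in fake_transformed_stream():
        print(delta)


asyncio.run(main())
```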
```diff
@@ -348,6 +433,33 @@ async def _async_handle_chat_log(
         """Generate an answer for the chat log."""
         options = self.subentry.data
 
+        messages = _convert_content_to_param(chat_log.content)
+
+        model_args = ResponseCreateParamsStreaming(
+            model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
+            input=messages,
+            max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+            top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+            temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
+            user=chat_log.conversation_id,
+            store=False,
+            stream=True,
+        )
+
+        if model_args["model"].startswith(("o", "gpt-5")):
+            model_args["reasoning"] = {
+                "effort": options.get(
+                    CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
+                ),
+                "summary": "auto",
+            }
+            model_args["include"] = ["reasoning.encrypted_content"]
+
+        if model_args["model"].startswith("gpt-5"):
+            model_args["text"] = {
+                "verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
+            }
+
         tools: list[ToolParam] = []
         if chat_log.llm_api:
             tools = [
@@ -381,36 +493,11 @@
                     ),
                 )
             )
+            model_args.setdefault("include", []).append("code_interpreter_call.outputs")  # type: ignore[union-attr]
 
-        messages = _convert_content_to_param(chat_log.content)
-
-        model_args = ResponseCreateParamsStreaming(
-            model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-            input=messages,
-            max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
-            top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-            temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-            user=chat_log.conversation_id,
-            store=False,
-            stream=True,
-        )
         if tools:
             model_args["tools"] = tools
 
-        if model_args["model"].startswith(("o", "gpt-5")):
-            model_args["reasoning"] = {
-                "effort": options.get(
-                    CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
-                ),
-                "summary": "auto",
-            }
-            model_args["include"] = ["reasoning.encrypted_content"]
-
-        if model_args["model"].startswith("gpt-5"):
-            model_args["text"] = {
-                "verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
-            }
-
         last_content = chat_log.content[-1]
 
         # Handle attachments by adding them to the last user message
```
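Moving the model_args construction ahead of the tool setup is what lets the new code-interpreter line extend the include list in place: setdefault merges with the reasoning entry when one exists and creates the list otherwise. A standalone illustration, with made-up model and option values:

```python
# Standalone illustration of the "include" merging; values are made up.
model_args: dict = {"model": "gpt-5-mini"}

if model_args["model"].startswith(("o", "gpt-5")):
    # Reasoning models request encrypted reasoning content back.
    model_args["include"] = ["reasoning.encrypted_content"]

code_interpreter_enabled = True  # stand-in for the subentry option
if code_interpreter_enabled:
    # setdefault reuses the reasoning list when present, and creates an
    # empty one for models that did not take the reasoning branch.
    model_args.setdefault("include", []).append("code_interpreter_call.outputs")

print(model_args["include"])
# ['reasoning.encrypted_content', 'code_interpreter_call.outputs']
```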

tests/components/openai_conversation/__init__.py
Lines changed: 3 additions & 2 deletions
```diff
@@ -29,6 +29,7 @@
     ResponseWebSearchCallInProgressEvent,
     ResponseWebSearchCallSearchingEvent,
 )
+from openai.types.responses.response_code_interpreter_tool_call import OutputLogs
 from openai.types.responses.response_function_web_search import ActionSearch
 from openai.types.responses.response_reasoning_item import Summary
 
@@ -320,7 +321,7 @@ def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
 
 
 def create_code_interpreter_item(
-    id: str, code: str | list[str], output_index: int
+    id: str, code: str | list[str], output_index: int, logs: str | None = None
 ) -> list[ResponseStreamEvent]:
     """Create a message item."""
     if isinstance(code, str):
@@ -388,7 +389,7 @@ def create_code_interpreter_item(
             id=id,
             code=code,
             container_id=container_id,
-            outputs=None,
+            outputs=[OutputLogs(type="logs", logs=logs)] if logs else None,
             status="completed",
             type="code_interpreter_call",
         ),
```
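With the new logs parameter, a test can fake a code-interpreter run that actually returns output instead of outputs=None. A hypothetical call (the import path follows the usual HA test layout; the id, code, output_index, and logs values are arbitrary, chosen to line up with the snapshot below):

```python
# Hypothetical usage of the extended helper in a test.
from tests.components.openai_conversation import create_code_interpreter_item

events = create_code_interpreter_item(
    id="ci_A",
    code=["import math\n", "math.sqrt(55555)"],
    output_index=1,
    logs="235.70108188126758\n",
)
# The final ResponseOutputItemDoneEvent now carries
# outputs=[OutputLogs(type="logs", logs="235.70108188126758\n")].
```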

tests/components/openai_conversation/snapshots/test_conversation.ambr
Lines changed: 68 additions & 0 deletions
```diff
@@ -1,4 +1,39 @@
 # serializer version: 1
+# name: test_code_interpreter
+  list([
+    dict({
+      'content': 'Please use the python tool to calculate square root of 55555',
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'arguments': '{"code": "import math\\nmath.sqrt(55555)", "container": "cntr_A"}',
+      'call_id': 'ci_A',
+      'name': 'code_interpreter',
+      'type': 'function_call',
+    }),
+    dict({
+      'call_id': 'ci_A',
+      'output': '{"output": [{"logs": "235.70108188126758\\n", "type": "logs"}]}',
+      'type': 'function_call_output',
+    }),
+    dict({
+      'content': 'I’ve calculated it with Python: the square root of 55555 is approximately 235.70108188126758.',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+    dict({
+      'content': 'Thank you!',
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'content': 'You are welcome!',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+  ])
+# ---
 # name: test_function_call
   list([
     dict({
@@ -172,3 +207,36 @@
     }),
   ])
 # ---
+# name: test_web_search
+  list([
+    dict({
+      'content': "What's on the latest news?",
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'action': dict({
+        'query': 'query',
+        'type': 'search',
+      }),
+      'id': 'ws_A',
+      'status': 'completed',
+      'type': 'web_search_call',
+    }),
+    dict({
+      'content': 'Home Assistant now supports ChatGPT Search in Assist',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+    dict({
+      'content': 'Thank you!',
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'content': 'You are welcome!',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+  ])
+# ---
```
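The 'output' string in the test_code_interpreter snapshot is just the JSON serialization of the tool_result dict that _transform_stream yields and _convert_content_to_param later passes through json.dumps; a quick sanity check:

```python
import json

# The tool_result delta yielded for the code_interpreter call above.
tool_result = {"output": [{"logs": "235.70108188126758\n", "type": "logs"}]}

print(json.dumps(tool_result))
# {"output": [{"logs": "235.70108188126758\n", "type": "logs"}]}
```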
