
Commit 70856bd

Split OpenAI entity (home-assistant#147771)
1 parent be6b624 commit 70856bd

File tree

2 files changed: +322 −299 lines

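Note: this page shows the diff for conversation.py only. The second changed file is the module that the new import `from .entity import OpenAIBaseLLMEntity` points at, presumably a new entity.py. As a rough sketch of what that base class likely contains, reconstructed from the code this commit deletes from conversation.py (the `Entity` parent and the exact file layout are assumptions, not shown in this excerpt):

# Hypothetical sketch of the new base class in entity.py, inferred from the diff.
# The Entity parent and module layout are assumptions; the __init__ body is taken
# verbatim from what this commit removes from OpenAIConversationEntity.__init__.
from homeassistant.config_entries import ConfigSubentry
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity import Entity

from . import OpenAIConfigEntry
from .const import CONF_CHAT_MODEL, DOMAIN, RECOMMENDED_CHAT_MODEL


class OpenAIBaseLLMEntity(Entity):
    """Base entity for OpenAI-backed platforms (reconstructed sketch)."""

    def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the entity with the device info that left conversation.py."""
        self.entry = entry
        self.subentry = subentry
        self._attr_name = subentry.title
        self._attr_unique_id = subentry.subentry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, subentry.subentry_id)},
            name=subentry.title,
            manufacturer="OpenAI",
            model=subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
            entry_type=dr.DeviceEntryType.SERVICE,
        )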

homeassistant/components/openai_conversation/conversation.py

Lines changed: 8 additions & 299 deletions
@@ -1,73 +1,19 @@
 """Conversation support for OpenAI."""
 
-from collections.abc import AsyncGenerator, Callable
-import json
-from typing import Any, Literal, cast
-
-import openai
-from openai._streaming import AsyncStream
-from openai.types.responses import (
-    EasyInputMessageParam,
-    FunctionToolParam,
-    ResponseCompletedEvent,
-    ResponseErrorEvent,
-    ResponseFailedEvent,
-    ResponseFunctionCallArgumentsDeltaEvent,
-    ResponseFunctionCallArgumentsDoneEvent,
-    ResponseFunctionToolCall,
-    ResponseFunctionToolCallParam,
-    ResponseIncompleteEvent,
-    ResponseInputParam,
-    ResponseOutputItemAddedEvent,
-    ResponseOutputItemDoneEvent,
-    ResponseOutputMessage,
-    ResponseOutputMessageParam,
-    ResponseReasoningItem,
-    ResponseReasoningItemParam,
-    ResponseStreamEvent,
-    ResponseTextDeltaEvent,
-    ToolParam,
-    WebSearchToolParam,
-)
-from openai.types.responses.response_input_param import FunctionCallOutput
-from openai.types.responses.web_search_tool_param import UserLocation
-from voluptuous_openapi import convert
+from typing import Literal
 
 from homeassistant.components import assist_pipeline, conversation
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import device_registry as dr, intent, llm
+from homeassistant.helpers import intent
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
 
 from . import OpenAIConfigEntry
-from .const import (
-    CONF_CHAT_MODEL,
-    CONF_MAX_TOKENS,
-    CONF_PROMPT,
-    CONF_REASONING_EFFORT,
-    CONF_TEMPERATURE,
-    CONF_TOP_P,
-    CONF_WEB_SEARCH,
-    CONF_WEB_SEARCH_CITY,
-    CONF_WEB_SEARCH_CONTEXT_SIZE,
-    CONF_WEB_SEARCH_COUNTRY,
-    CONF_WEB_SEARCH_REGION,
-    CONF_WEB_SEARCH_TIMEZONE,
-    CONF_WEB_SEARCH_USER_LOCATION,
-    DOMAIN,
-    LOGGER,
-    RECOMMENDED_CHAT_MODEL,
-    RECOMMENDED_MAX_TOKENS,
-    RECOMMENDED_REASONING_EFFORT,
-    RECOMMENDED_TEMPERATURE,
-    RECOMMENDED_TOP_P,
-    RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
-)
+from .const import CONF_PROMPT, DOMAIN
+from .entity import OpenAIBaseLLMEntity
 
 # Max number of back and forth with the LLM to generate a response
-MAX_TOOL_ITERATIONS = 10
 
 
 async def async_setup_entry(
@@ -86,170 +32,18 @@ async def async_setup_entry(
     )
 
 
-def _format_tool(
-    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
-) -> FunctionToolParam:
-    """Format tool specification."""
-    return FunctionToolParam(
-        type="function",
-        name=tool.name,
-        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
-        description=tool.description,
-        strict=False,
-    )
-
-
-def _convert_content_to_param(
-    content: conversation.Content,
-) -> ResponseInputParam:
-    """Convert any native chat message for this agent to the native format."""
-    messages: ResponseInputParam = []
-    if isinstance(content, conversation.ToolResultContent):
-        return [
-            FunctionCallOutput(
-                type="function_call_output",
-                call_id=content.tool_call_id,
-                output=json.dumps(content.tool_result),
-            )
-        ]
-
-    if content.content:
-        role: Literal["user", "assistant", "system", "developer"] = content.role
-        if role == "system":
-            role = "developer"
-        messages.append(
-            EasyInputMessageParam(type="message", role=role, content=content.content)
-        )
-
-    if isinstance(content, conversation.AssistantContent) and content.tool_calls:
-        messages.extend(
-            ResponseFunctionToolCallParam(
-                type="function_call",
-                name=tool_call.tool_name,
-                arguments=json.dumps(tool_call.tool_args),
-                call_id=tool_call.id,
-            )
-            for tool_call in content.tool_calls
-        )
-    return messages
-
-
-async def _transform_stream(
-    chat_log: conversation.ChatLog,
-    result: AsyncStream[ResponseStreamEvent],
-    messages: ResponseInputParam,
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
-    """Transform an OpenAI delta stream into HA format."""
-    async for event in result:
-        LOGGER.debug("Received event: %s", event)
-
-        if isinstance(event, ResponseOutputItemAddedEvent):
-            if isinstance(event.item, ResponseOutputMessage):
-                yield {"role": event.item.role}
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                # OpenAI has tool calls as individual events
-                # while HA puts tool calls inside the assistant message.
-                # We turn them into individual assistant content for HA
-                # to ensure that tools are called as soon as possible.
-                yield {"role": "assistant"}
-                current_tool_call = event.item
-        elif isinstance(event, ResponseOutputItemDoneEvent):
-            item = event.item.model_dump()
-            item.pop("status", None)
-            if isinstance(event.item, ResponseReasoningItem):
-                messages.append(cast(ResponseReasoningItemParam, item))
-            elif isinstance(event.item, ResponseOutputMessage):
-                messages.append(cast(ResponseOutputMessageParam, item))
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                messages.append(cast(ResponseFunctionToolCallParam, item))
-        elif isinstance(event, ResponseTextDeltaEvent):
-            yield {"content": event.delta}
-        elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
-            current_tool_call.arguments += event.delta
-        elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
-            current_tool_call.status = "completed"
-            yield {
-                "tool_calls": [
-                    llm.ToolInput(
-                        id=current_tool_call.call_id,
-                        tool_name=current_tool_call.name,
-                        tool_args=json.loads(current_tool_call.arguments),
-                    )
-                ]
-            }
-        elif isinstance(event, ResponseCompletedEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-        elif isinstance(event, ResponseIncompleteEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-
-            if (
-                event.response.incomplete_details
-                and event.response.incomplete_details.reason
-            ):
-                reason: str = event.response.incomplete_details.reason
-            else:
-                reason = "unknown reason"
-
-            if reason == "max_output_tokens":
-                reason = "max output tokens reached"
-            elif reason == "content_filter":
-                reason = "content filter triggered"
-
-            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
-        elif isinstance(event, ResponseFailedEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-            reason = "unknown reason"
-            if event.response.error is not None:
-                reason = event.response.error.message
-            raise HomeAssistantError(f"OpenAI response failed: {reason}")
-        elif isinstance(event, ResponseErrorEvent):
-            raise HomeAssistantError(f"OpenAI response error: {event.message}")
-
-
 class OpenAIConversationEntity(
-    conversation.ConversationEntity, conversation.AbstractConversationAgent
+    conversation.ConversationEntity,
+    conversation.AbstractConversationAgent,
+    OpenAIBaseLLMEntity,
 ):
     """OpenAI conversation agent."""
 
     _attr_supports_streaming = True
 
     def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
         """Initialize the agent."""
-        self.entry = entry
-        self.subentry = subentry
-        self._attr_name = subentry.title
-        self._attr_unique_id = subentry.subentry_id
-        self._attr_device_info = dr.DeviceInfo(
-            identifiers={(DOMAIN, subentry.subentry_id)},
-            name=subentry.title,
-            manufacturer="OpenAI",
-            model=subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-            entry_type=dr.DeviceEntryType.SERVICE,
-        )
+        super().__init__(entry, subentry)
         if self.subentry.data.get(CONF_LLM_HASS_API):
             self._attr_supported_features = (
                 conversation.ConversationEntityFeature.CONTROL
@@ -305,91 +99,6 @@ async def _async_handle_message(
             continue_conversation=chat_log.continue_conversation,
         )
 
-    async def _async_handle_chat_log(
-        self,
-        chat_log: conversation.ChatLog,
-    ) -> None:
-        """Generate an answer for the chat log."""
-        options = self.subentry.data
-
-        tools: list[ToolParam] | None = None
-        if chat_log.llm_api:
-            tools = [
-                _format_tool(tool, chat_log.llm_api.custom_serializer)
-                for tool in chat_log.llm_api.tools
-            ]
-
-        if options.get(CONF_WEB_SEARCH):
-            web_search = WebSearchToolParam(
-                type="web_search_preview",
-                search_context_size=options.get(
-                    CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
-                ),
-            )
-            if options.get(CONF_WEB_SEARCH_USER_LOCATION):
-                web_search["user_location"] = UserLocation(
-                    type="approximate",
-                    city=options.get(CONF_WEB_SEARCH_CITY, ""),
-                    region=options.get(CONF_WEB_SEARCH_REGION, ""),
-                    country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
-                    timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
-                )
-            if tools is None:
-                tools = []
-            tools.append(web_search)
-
-        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
-        messages = [
-            m
-            for content in chat_log.content
-            for m in _convert_content_to_param(content)
-        ]
-
-        client = self.entry.runtime_data
-
-        # To prevent infinite loops, we limit the number of iterations
-        for _iteration in range(MAX_TOOL_ITERATIONS):
-            model_args = {
-                "model": model,
-                "input": messages,
-                "max_output_tokens": options.get(
-                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
-                ),
-                "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-                "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-                "user": chat_log.conversation_id,
-                "stream": True,
-            }
-            if tools:
-                model_args["tools"] = tools
-
-            if model.startswith("o"):
-                model_args["reasoning"] = {
-                    "effort": options.get(
-                        CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
-                    )
-                }
-            else:
-                model_args["store"] = False
-
-            try:
-                result = await client.responses.create(**model_args)
-            except openai.RateLimitError as err:
-                LOGGER.error("Rate limited by OpenAI: %s", err)
-                raise HomeAssistantError("Rate limited or insufficient funds") from err
-            except openai.OpenAIError as err:
-                LOGGER.error("Error talking to OpenAI: %s", err)
-                raise HomeAssistantError("Error talking to OpenAI") from err
-
-            async for content in chat_log.async_add_delta_content_stream(
-                self.entity_id, _transform_stream(chat_log, result, messages)
-            ):
-                if not isinstance(content, conversation.AssistantContent):
-                    messages.extend(_convert_content_to_param(content))
-
-            if not chat_log.unresponded_tool_results:
-                break
-
     async def _async_entry_update_listener(
         self, hass: HomeAssistant, entry: ConfigEntry
     ) -> None:
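The rewritten __init__ in the diff above leans on Python's method resolution order: `super().__init__(entry, subentry)` is expected to resolve to `OpenAIBaseLLMEntity.__init__`, assuming the two conversation base classes define no competing `__init__` of their own (which the diff implies, since the removed inline setup is replaced by that single call). A minimal standalone sketch of the delegation pattern, using stand-in classes rather than the real Home Assistant ones:

# Stand-in classes only; the names mirror the diff but none of this is HA code.
class ConversationEntityMixin:
    """Mixin without its own __init__ (stand-in for the conversation bases)."""


class BaseLLMEntity:
    """Owns the shared setup (stand-in for OpenAIBaseLLMEntity)."""

    def __init__(self, entry: str, subentry: str) -> None:
        self.entry = entry
        self.subentry = subentry


class ConversationAgent(ConversationEntityMixin, BaseLLMEntity):
    """Stand-in for OpenAIConversationEntity."""

    def __init__(self, entry: str, subentry: str) -> None:
        # The mixin defines no __init__, so super() walks the MRO
        # and lands on BaseLLMEntity.__init__.
        super().__init__(entry, subentry)


agent = ConversationAgent("entry", "subentry")
assert agent.subentry == "subentry"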
