11"""Conversation support for OpenAI."""
22
3- from collections .abc import AsyncGenerator , Callable
4- import json
5- from typing import Any , Literal , cast
6-
7- import openai
8- from openai ._streaming import AsyncStream
9- from openai .types .responses import (
10- EasyInputMessageParam ,
11- FunctionToolParam ,
12- ResponseCompletedEvent ,
13- ResponseErrorEvent ,
14- ResponseFailedEvent ,
15- ResponseFunctionCallArgumentsDeltaEvent ,
16- ResponseFunctionCallArgumentsDoneEvent ,
17- ResponseFunctionToolCall ,
18- ResponseFunctionToolCallParam ,
19- ResponseIncompleteEvent ,
20- ResponseInputParam ,
21- ResponseOutputItemAddedEvent ,
22- ResponseOutputItemDoneEvent ,
23- ResponseOutputMessage ,
24- ResponseOutputMessageParam ,
25- ResponseReasoningItem ,
26- ResponseReasoningItemParam ,
27- ResponseStreamEvent ,
28- ResponseTextDeltaEvent ,
29- ToolParam ,
30- WebSearchToolParam ,
31- )
32- from openai .types .responses .response_input_param import FunctionCallOutput
33- from openai .types .responses .web_search_tool_param import UserLocation
34- from voluptuous_openapi import convert
3+ from typing import Literal
354
365from homeassistant .components import assist_pipeline , conversation
376from homeassistant .config_entries import ConfigEntry , ConfigSubentry
387from homeassistant .const import CONF_LLM_HASS_API , MATCH_ALL
398from homeassistant .core import HomeAssistant
40- from homeassistant .exceptions import HomeAssistantError
41- from homeassistant .helpers import device_registry as dr , intent , llm
9+ from homeassistant .helpers import intent
4210from homeassistant .helpers .entity_platform import AddConfigEntryEntitiesCallback
4311
4412from . import OpenAIConfigEntry
45- from .const import (
46- CONF_CHAT_MODEL ,
47- CONF_MAX_TOKENS ,
48- CONF_PROMPT ,
49- CONF_REASONING_EFFORT ,
50- CONF_TEMPERATURE ,
51- CONF_TOP_P ,
52- CONF_WEB_SEARCH ,
53- CONF_WEB_SEARCH_CITY ,
54- CONF_WEB_SEARCH_CONTEXT_SIZE ,
55- CONF_WEB_SEARCH_COUNTRY ,
56- CONF_WEB_SEARCH_REGION ,
57- CONF_WEB_SEARCH_TIMEZONE ,
58- CONF_WEB_SEARCH_USER_LOCATION ,
59- DOMAIN ,
60- LOGGER ,
61- RECOMMENDED_CHAT_MODEL ,
62- RECOMMENDED_MAX_TOKENS ,
63- RECOMMENDED_REASONING_EFFORT ,
64- RECOMMENDED_TEMPERATURE ,
65- RECOMMENDED_TOP_P ,
66- RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE ,
67- )
13+ from .const import CONF_PROMPT , DOMAIN
14+ from .entity import OpenAIBaseLLMEntity
6815
6916# Max number of back and forth with the LLM to generate a response
70- MAX_TOOL_ITERATIONS = 10
7117
7218
7319async def async_setup_entry (
@@ -86,170 +32,18 @@ async def async_setup_entry(
         )
 
 
-def _format_tool(
-    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
-) -> FunctionToolParam:
-    """Format tool specification."""
-    return FunctionToolParam(
-        type="function",
-        name=tool.name,
-        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
-        description=tool.description,
-        strict=False,
-    )
-
-
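
For reference, the deleted helper produced Responses API function-tool dicts of
this shape. The tool name and schema below are made up for illustration; the
real "parameters" value comes from voluptuous_openapi.convert:

    # Illustrative only: what _format_tool() returned for a hypothetical tool.
    FunctionToolParam(
        type="function",
        name="HassTurnOn",  # tool.name
        parameters={"type": "object", "properties": {"name": {"type": "string"}}},
        description="Turns on a device or entity",  # tool.description
        strict=False,
    )
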
-def _convert_content_to_param(
-    content: conversation.Content,
-) -> ResponseInputParam:
-    """Convert any native chat message for this agent to the native format."""
-    messages: ResponseInputParam = []
-    if isinstance(content, conversation.ToolResultContent):
-        return [
-            FunctionCallOutput(
-                type="function_call_output",
-                call_id=content.tool_call_id,
-                output=json.dumps(content.tool_result),
-            )
-        ]
-
-    if content.content:
-        role: Literal["user", "assistant", "system", "developer"] = content.role
-        if role == "system":
-            role = "developer"
-        messages.append(
-            EasyInputMessageParam(type="message", role=role, content=content.content)
-        )
-
-    if isinstance(content, conversation.AssistantContent) and content.tool_calls:
-        messages.extend(
-            ResponseFunctionToolCallParam(
-                type="function_call",
-                name=tool_call.tool_name,
-                arguments=json.dumps(tool_call.tool_args),
-                call_id=tool_call.id,
-            )
-            for tool_call in content.tool_calls
-        )
-    return messages
-
-
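
A sketch of the mapping the deleted converter implemented, with hypothetical
values: a tool result became a single "function_call_output" item, plain
content became an input message (with the "system" role re-labelled
"developer" for the Responses API), and assistant tool calls became
"function_call" items:

    # ToolResultContent -> one FunctionCallOutput (values illustrative):
    {
        "type": "function_call_output",
        "call_id": "call_abc123",       # content.tool_call_id
        "output": '{"success": true}',  # json.dumps(content.tool_result)
    }
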
-async def _transform_stream(
-    chat_log: conversation.ChatLog,
-    result: AsyncStream[ResponseStreamEvent],
-    messages: ResponseInputParam,
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
-    """Transform an OpenAI delta stream into HA format."""
-    async for event in result:
-        LOGGER.debug("Received event: %s", event)
-
-        if isinstance(event, ResponseOutputItemAddedEvent):
-            if isinstance(event.item, ResponseOutputMessage):
-                yield {"role": event.item.role}
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                # OpenAI has tool calls as individual events
-                # while HA puts tool calls inside the assistant message.
-                # We turn them into individual assistant content for HA
-                # to ensure that tools are called as soon as possible.
-                yield {"role": "assistant"}
-                current_tool_call = event.item
-        elif isinstance(event, ResponseOutputItemDoneEvent):
-            item = event.item.model_dump()
-            item.pop("status", None)
-            if isinstance(event.item, ResponseReasoningItem):
-                messages.append(cast(ResponseReasoningItemParam, item))
-            elif isinstance(event.item, ResponseOutputMessage):
-                messages.append(cast(ResponseOutputMessageParam, item))
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                messages.append(cast(ResponseFunctionToolCallParam, item))
-        elif isinstance(event, ResponseTextDeltaEvent):
-            yield {"content": event.delta}
-        elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
-            current_tool_call.arguments += event.delta
-        elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
-            current_tool_call.status = "completed"
-            yield {
-                "tool_calls": [
-                    llm.ToolInput(
-                        id=current_tool_call.call_id,
-                        tool_name=current_tool_call.name,
-                        tool_args=json.loads(current_tool_call.arguments),
-                    )
-                ]
-            }
-        elif isinstance(event, ResponseCompletedEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-        elif isinstance(event, ResponseIncompleteEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-
-            if (
-                event.response.incomplete_details
-                and event.response.incomplete_details.reason
-            ):
-                reason: str = event.response.incomplete_details.reason
-            else:
-                reason = "unknown reason"
-
-            if reason == "max_output_tokens":
-                reason = "max output tokens reached"
-            elif reason == "content_filter":
-                reason = "content filter triggered"
-
-            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
-        elif isinstance(event, ResponseFailedEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-            reason = "unknown reason"
-            if event.response.error is not None:
-                reason = event.response.error.message
-            raise HomeAssistantError(f"OpenAI response failed: {reason}")
-        elif isinstance(event, ResponseErrorEvent):
-            raise HomeAssistantError(f"OpenAI response error: {event.message}")
-
-
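
The deleted generator yielded AssistantContentDeltaDict chunks for the chat
log. A hypothetical stream for one response that speaks and then calls a tool:

    {"role": "assistant"}            # ResponseOutputItemAddedEvent
    {"content": "Turning on the "}   # ResponseTextDeltaEvent deltas
    {"content": "kitchen light."}
    {"tool_calls": [llm.ToolInput(   # emitted once the arguments are complete
        id="call_abc123",
        tool_name="HassTurnOn",
        tool_args={"name": "kitchen light"},
    )]}
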
 class OpenAIConversationEntity(
-    conversation.ConversationEntity, conversation.AbstractConversationAgent
+    conversation.ConversationEntity,
+    conversation.AbstractConversationAgent,
+    OpenAIBaseLLMEntity,
 ):
     """OpenAI conversation agent."""
 
     _attr_supports_streaming = True
 
     def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
         """Initialize the agent."""
-        self.entry = entry
-        self.subentry = subentry
-        self._attr_name = subentry.title
-        self._attr_unique_id = subentry.subentry_id
-        self._attr_device_info = dr.DeviceInfo(
-            identifiers={(DOMAIN, subentry.subentry_id)},
-            name=subentry.title,
-            manufacturer="OpenAI",
-            model=subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-            entry_type=dr.DeviceEntryType.SERVICE,
-        )
+        super().__init__(entry, subentry)
         if self.subentry.data.get(CONF_LLM_HASS_API):
             self._attr_supported_features = (
                 conversation.ConversationEntityFeature.CONTROL
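
The device-registry boilerplate deleted from __init__ presumably moves into
the new OpenAIBaseLLMEntity in entity.py, which is not shown in this diff. A
minimal sketch of what that base class would need to provide for the
super().__init__(entry, subentry) call to work:

    # Hypothetical sketch; the real class lives in entity.py.
    class OpenAIBaseLLMEntity(Entity):
        def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
            self.entry = entry
            self.subentry = subentry
            self._attr_name = subentry.title
            self._attr_unique_id = subentry.subentry_id
            # ...plus the dr.DeviceInfo registration removed above.
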
@@ -305,91 +99,6 @@ async def _async_handle_message(
             continue_conversation=chat_log.continue_conversation,
         )
 
-    async def _async_handle_chat_log(
-        self,
-        chat_log: conversation.ChatLog,
-    ) -> None:
-        """Generate an answer for the chat log."""
-        options = self.subentry.data
-
-        tools: list[ToolParam] | None = None
-        if chat_log.llm_api:
-            tools = [
-                _format_tool(tool, chat_log.llm_api.custom_serializer)
-                for tool in chat_log.llm_api.tools
-            ]
-
-        if options.get(CONF_WEB_SEARCH):
-            web_search = WebSearchToolParam(
-                type="web_search_preview",
-                search_context_size=options.get(
-                    CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
-                ),
-            )
-            if options.get(CONF_WEB_SEARCH_USER_LOCATION):
-                web_search["user_location"] = UserLocation(
-                    type="approximate",
-                    city=options.get(CONF_WEB_SEARCH_CITY, ""),
-                    region=options.get(CONF_WEB_SEARCH_REGION, ""),
-                    country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
-                    timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
-                )
-            if tools is None:
-                tools = []
-            tools.append(web_search)
-
-        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
-        messages = [
-            m
-            for content in chat_log.content
-            for m in _convert_content_to_param(content)
-        ]
-
-        client = self.entry.runtime_data
-
-        # To prevent infinite loops, we limit the number of iterations
-        for _iteration in range(MAX_TOOL_ITERATIONS):
-            model_args = {
-                "model": model,
-                "input": messages,
-                "max_output_tokens": options.get(
-                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
-                ),
-                "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-                "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-                "user": chat_log.conversation_id,
-                "stream": True,
-            }
-            if tools:
-                model_args["tools"] = tools
-
-            if model.startswith("o"):
-                model_args["reasoning"] = {
-                    "effort": options.get(
-                        CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
-                    )
-                }
-            else:
-                model_args["store"] = False
-
-            try:
-                result = await client.responses.create(**model_args)
-            except openai.RateLimitError as err:
-                LOGGER.error("Rate limited by OpenAI: %s", err)
-                raise HomeAssistantError("Rate limited or insufficient funds") from err
-            except openai.OpenAIError as err:
-                LOGGER.error("Error talking to OpenAI: %s", err)
-                raise HomeAssistantError("Error talking to OpenAI") from err
-
-            async for content in chat_log.async_add_delta_content_stream(
-                self.entity_id, _transform_stream(chat_log, result, messages)
-            ):
-                if not isinstance(content, conversation.AssistantContent):
-                    messages.extend(_convert_content_to_param(content))
-
-            if not chat_log.unresponded_tool_results:
-                break
-
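
The deleted method, presumably now inherited from OpenAIBaseLLMEntity,
implemented a bounded tool-calling loop. Reduced to a sketch:

    # Outline of the deleted control flow (not runnable as-is):
    for _iteration in range(MAX_TOOL_ITERATIONS):  # bound avoids infinite loops
        result = await client.responses.create(**model_args)  # streamed response
        async for content in chat_log.async_add_delta_content_stream(
            self.entity_id, _transform_stream(chat_log, result, messages)
        ):
            if not isinstance(content, conversation.AssistantContent):
                messages.extend(_convert_content_to_param(content))  # tool results
        if not chat_log.unresponded_tool_results:
            break  # the model produced a final answer
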
     async def _async_entry_update_listener(
         self, hass: HomeAssistant, entry: ConfigEntry
     ) -> None: