11"""Conversation support for OpenRouter."""
22
3- from typing import Literal
3+ from collections .abc import AsyncGenerator , Callable
4+ import json
5+ from typing import Any , Literal
46
57import openai
8+ from openai import NOT_GIVEN
69from openai .types .chat import (
710 ChatCompletionAssistantMessageParam ,
11+ ChatCompletionMessage ,
812 ChatCompletionMessageParam ,
13+ ChatCompletionMessageToolCallParam ,
914 ChatCompletionSystemMessageParam ,
15+ ChatCompletionToolMessageParam ,
16+ ChatCompletionToolParam ,
1017 ChatCompletionUserMessageParam ,
1118)
19+ from openai .types .chat .chat_completion_message_tool_call_param import Function
20+ from openai .types .shared_params import FunctionDefinition
21+ from voluptuous_openapi import convert
1222
1323from homeassistant .components import conversation
1424from homeassistant .config_entries import ConfigSubentry
1525from homeassistant .const import CONF_LLM_HASS_API , CONF_MODEL , MATCH_ALL
1626from homeassistant .core import HomeAssistant
1727from homeassistant .exceptions import HomeAssistantError
28+ from homeassistant .helpers import llm
1829from homeassistant .helpers .device_registry import DeviceEntryType , DeviceInfo
1930from homeassistant .helpers .entity_platform import AddConfigEntryEntitiesCallback
2031
2132from . import OpenRouterConfigEntry
22- from .const import DOMAIN , LOGGER
33+ from .const import CONF_PROMPT , DOMAIN , LOGGER
34+
35+ # Max number of back and forth with the LLM to generate a response
36+ MAX_TOOL_ITERATIONS = 10
2337
2438
2539async def async_setup_entry (
@@ -35,13 +49,31 @@ async def async_setup_entry(
3549 )
3650
3751
52+ def _format_tool (
53+ tool : llm .Tool ,
54+ custom_serializer : Callable [[Any ], Any ] | None ,
55+ ) -> ChatCompletionToolParam :
56+ """Format tool specification."""
57+ tool_spec = FunctionDefinition (
58+ name = tool .name ,
59+ parameters = convert (tool .parameters , custom_serializer = custom_serializer ),
60+ )
61+ if tool .description :
62+ tool_spec ["description" ] = tool .description
63+ return ChatCompletionToolParam (type = "function" , function = tool_spec )
64+
65+
3866def _convert_content_to_chat_message (
3967 content : conversation .Content ,
4068) -> ChatCompletionMessageParam | None :
4169 """Convert any native chat message for this agent to the native format."""
4270 LOGGER .debug ("_convert_content_to_chat_message=%s" , content )
4371 if isinstance (content , conversation .ToolResultContent ):
44- return None
72+ return ChatCompletionToolMessageParam (
73+ role = "tool" ,
74+ tool_call_id = content .tool_call_id ,
75+ content = json .dumps (content .tool_result ),
76+ )
4577
4678 role : Literal ["user" , "assistant" , "system" ] = content .role
4779 if role == "system" and content .content :
@@ -51,13 +83,55 @@ def _convert_content_to_chat_message(
5183 return ChatCompletionUserMessageParam (role = "user" , content = content .content )
5284
5385 if role == "assistant" :
54- return ChatCompletionAssistantMessageParam (
55- role = "assistant" , content = content .content
86+ param = ChatCompletionAssistantMessageParam (
87+ role = "assistant" ,
88+ content = content .content ,
5689 )
90+ if isinstance (content , conversation .AssistantContent ) and content .tool_calls :
91+ param ["tool_calls" ] = [
92+ ChatCompletionMessageToolCallParam (
93+ type = "function" ,
94+ id = tool_call .id ,
95+ function = Function (
96+ arguments = json .dumps (tool_call .tool_args ),
97+ name = tool_call .tool_name ,
98+ ),
99+ )
100+ for tool_call in content .tool_calls
101+ ]
102+ return param
57103 LOGGER .warning ("Could not convert message to Completions API: %s" , content )
58104 return None
59105
60106
107+ def _decode_tool_arguments (arguments : str ) -> Any :
108+ """Decode tool call arguments."""
109+ try :
110+ return json .loads (arguments )
111+ except json .JSONDecodeError as err :
112+ raise HomeAssistantError (f"Unexpected tool argument response: { err } " ) from err
113+
114+
115+ async def _transform_response (
116+ message : ChatCompletionMessage ,
117+ ) -> AsyncGenerator [conversation .AssistantContentDeltaDict ]:
118+ """Transform the OpenRouter message to a ChatLog format."""
119+ data : conversation .AssistantContentDeltaDict = {
120+ "role" : message .role ,
121+ "content" : message .content ,
122+ }
123+ if message .tool_calls :
124+ data ["tool_calls" ] = [
125+ llm .ToolInput (
126+ id = tool_call .id ,
127+ tool_name = tool_call .function .name ,
128+ tool_args = _decode_tool_arguments (tool_call .function .arguments ),
129+ )
130+ for tool_call in message .tool_calls
131+ ]
132+ yield data
133+
134+
61135class OpenRouterConversationEntity (conversation .ConversationEntity ):
62136 """OpenRouter conversation agent."""
63137
@@ -75,6 +149,10 @@ def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> No
75149 name = subentry .title ,
76150 entry_type = DeviceEntryType .SERVICE ,
77151 )
152+ if self .subentry .data .get (CONF_LLM_HASS_API ):
153+ self ._attr_supported_features = (
154+ conversation .ConversationEntityFeature .CONTROL
155+ )
78156
79157 @property
80158 def supported_languages (self ) -> list [str ] | Literal ["*" ]:
@@ -93,12 +171,19 @@ async def _async_handle_message(
93171 await chat_log .async_provide_llm_data (
94172 user_input .as_llm_context (DOMAIN ),
95173 options .get (CONF_LLM_HASS_API ),
96- None ,
174+ options . get ( CONF_PROMPT ) ,
97175 user_input .extra_system_prompt ,
98176 )
99177 except conversation .ConverseError as err :
100178 return err .as_conversation_result ()
101179
180+ tools : list [ChatCompletionToolParam ] | None = None
181+ if chat_log .llm_api :
182+ tools = [
183+ _format_tool (tool , chat_log .llm_api .custom_serializer )
184+ for tool in chat_log .llm_api .tools
185+ ]
186+
102187 messages = [
103188 m
104189 for content in chat_log .content
@@ -107,27 +192,34 @@ async def _async_handle_message(
107192
108193 client = self .entry .runtime_data
109194
110- try :
111- result = await client .chat .completions .create (
112- model = self .model ,
113- messages = messages ,
114- user = chat_log .conversation_id ,
115- extra_headers = {
116- "X-Title" : "Home Assistant" ,
117- "HTTP-Referer" : "https://www.home-assistant.io/integrations/open_router" ,
118- },
195+ for _iteration in range (MAX_TOOL_ITERATIONS ):
196+ try :
197+ result = await client .chat .completions .create (
198+ model = self .model ,
199+ messages = messages ,
200+ tools = tools or NOT_GIVEN ,
201+ user = chat_log .conversation_id ,
202+ extra_headers = {
203+ "X-Title" : "Home Assistant" ,
204+ "HTTP-Referer" : "https://www.home-assistant.io/integrations/open_router" ,
205+ },
206+ )
207+ except openai .OpenAIError as err :
208+ LOGGER .error ("Error talking to API: %s" , err )
209+ raise HomeAssistantError ("Error talking to API" ) from err
210+
211+ result_message = result .choices [0 ].message
212+
213+ messages .extend (
214+ [
215+ msg
216+ async for content in chat_log .async_add_delta_content_stream (
217+ user_input .agent_id , _transform_response (result_message )
218+ )
219+ if (msg := _convert_content_to_chat_message (content ))
220+ ]
119221 )
120- except openai .OpenAIError as err :
121- LOGGER .error ("Error talking to API: %s" , err )
122- raise HomeAssistantError ("Error talking to API" ) from err
123-
124- result_message = result .choices [0 ].message
125-
126- chat_log .async_add_assistant_content_without_tools (
127- conversation .AssistantContent (
128- agent_id = user_input .agent_id ,
129- content = result_message .content ,
130- )
131- )
222+ if not chat_log .unresponded_tool_results :
223+ break
132224
133225 return conversation .async_get_result_from_chat_log (user_input , chat_log )
0 commit comments