Skip to content

Commit 50688bb

Browse files
authored
Add support for calling tools in Open Router (home-assistant#148881)
1 parent 073ea81 commit 50688bb

File tree

8 files changed

+499
-53
lines changed

8 files changed

+499
-53
lines changed

homeassistant/components/open_router/config_flow.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,20 @@
1616
ConfigSubentryFlow,
1717
SubentryFlowResult,
1818
)
19-
from homeassistant.const import CONF_API_KEY, CONF_MODEL
19+
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
2020
from homeassistant.core import callback
21+
from homeassistant.helpers import llm
2122
from homeassistant.helpers.aiohttp_client import async_get_clientsession
2223
from homeassistant.helpers.httpx_client import get_async_client
2324
from homeassistant.helpers.selector import (
2425
SelectOptionDict,
2526
SelectSelector,
2627
SelectSelectorConfig,
2728
SelectSelectorMode,
29+
TemplateSelector,
2830
)
2931

30-
from .const import DOMAIN
32+
from .const import CONF_PROMPT, DOMAIN, RECOMMENDED_CONVERSATION_OPTIONS
3133

3234
_LOGGER = logging.getLogger(__name__)
3335

@@ -90,6 +92,8 @@ async def async_step_user(
9092
) -> SubentryFlowResult:
9193
"""User flow to create a sensor subentry."""
9294
if user_input is not None:
95+
if not user_input.get(CONF_LLM_HASS_API):
96+
user_input.pop(CONF_LLM_HASS_API, None)
9397
return self.async_create_entry(
9498
title=self.options[user_input[CONF_MODEL]], data=user_input
9599
)
@@ -99,11 +103,17 @@ async def async_step_user(
99103
api_key=entry.data[CONF_API_KEY],
100104
http_client=get_async_client(self.hass),
101105
)
106+
hass_apis: list[SelectOptionDict] = [
107+
SelectOptionDict(
108+
label=api.name,
109+
value=api.id,
110+
)
111+
for api in llm.async_get_apis(self.hass)
112+
]
102113
options = []
103114
async for model in client.with_options(timeout=10.0).models.list():
104115
options.append(SelectOptionDict(value=model.id, label=model.name)) # type: ignore[attr-defined]
105116
self.options[model.id] = model.name # type: ignore[attr-defined]
106-
107117
return self.async_show_form(
108118
step_id="user",
109119
data_schema=vol.Schema(
@@ -113,6 +123,20 @@ async def async_step_user(
113123
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
114124
),
115125
),
126+
vol.Optional(
127+
CONF_PROMPT,
128+
description={
129+
"suggested_value": RECOMMENDED_CONVERSATION_OPTIONS[
130+
CONF_PROMPT
131+
]
132+
},
133+
): TemplateSelector(),
134+
vol.Optional(
135+
CONF_LLM_HASS_API,
136+
default=RECOMMENDED_CONVERSATION_OPTIONS[CONF_LLM_HASS_API],
137+
): SelectSelector(
138+
SelectSelectorConfig(options=hass_apis, multiple=True)
139+
),
116140
}
117141
),
118142
)

homeassistant/components/open_router/const.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,17 @@
22

33
import logging
44

5+
from homeassistant.const import CONF_LLM_HASS_API
6+
from homeassistant.helpers import llm
7+
58
DOMAIN = "open_router"
69
LOGGER = logging.getLogger(__package__)
10+
11+
CONF_PROMPT = "prompt"
12+
CONF_RECOMMENDED = "recommended"
13+
14+
RECOMMENDED_CONVERSATION_OPTIONS = {
15+
CONF_RECOMMENDED: True,
16+
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
17+
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
18+
}

homeassistant/components/open_router/conversation.py

Lines changed: 119 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,39 @@
11
"""Conversation support for OpenRouter."""
22

3-
from typing import Literal
3+
from collections.abc import AsyncGenerator, Callable
4+
import json
5+
from typing import Any, Literal
46

57
import openai
8+
from openai import NOT_GIVEN
69
from openai.types.chat import (
710
ChatCompletionAssistantMessageParam,
11+
ChatCompletionMessage,
812
ChatCompletionMessageParam,
13+
ChatCompletionMessageToolCallParam,
914
ChatCompletionSystemMessageParam,
15+
ChatCompletionToolMessageParam,
16+
ChatCompletionToolParam,
1017
ChatCompletionUserMessageParam,
1118
)
19+
from openai.types.chat.chat_completion_message_tool_call_param import Function
20+
from openai.types.shared_params import FunctionDefinition
21+
from voluptuous_openapi import convert
1222

1323
from homeassistant.components import conversation
1424
from homeassistant.config_entries import ConfigSubentry
1525
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
1626
from homeassistant.core import HomeAssistant
1727
from homeassistant.exceptions import HomeAssistantError
28+
from homeassistant.helpers import llm
1829
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
1930
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
2031

2132
from . import OpenRouterConfigEntry
22-
from .const import DOMAIN, LOGGER
33+
from .const import CONF_PROMPT, DOMAIN, LOGGER
34+
35+
# Max number of back and forth with the LLM to generate a response
36+
MAX_TOOL_ITERATIONS = 10
2337

2438

2539
async def async_setup_entry(
@@ -35,13 +49,31 @@ async def async_setup_entry(
3549
)
3650

3751

52+
def _format_tool(
53+
tool: llm.Tool,
54+
custom_serializer: Callable[[Any], Any] | None,
55+
) -> ChatCompletionToolParam:
56+
"""Format tool specification."""
57+
tool_spec = FunctionDefinition(
58+
name=tool.name,
59+
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
60+
)
61+
if tool.description:
62+
tool_spec["description"] = tool.description
63+
return ChatCompletionToolParam(type="function", function=tool_spec)
64+
65+
3866
def _convert_content_to_chat_message(
3967
content: conversation.Content,
4068
) -> ChatCompletionMessageParam | None:
4169
"""Convert any native chat message for this agent to the native format."""
4270
LOGGER.debug("_convert_content_to_chat_message=%s", content)
4371
if isinstance(content, conversation.ToolResultContent):
44-
return None
72+
return ChatCompletionToolMessageParam(
73+
role="tool",
74+
tool_call_id=content.tool_call_id,
75+
content=json.dumps(content.tool_result),
76+
)
4577

4678
role: Literal["user", "assistant", "system"] = content.role
4779
if role == "system" and content.content:
@@ -51,13 +83,55 @@ def _convert_content_to_chat_message(
5183
return ChatCompletionUserMessageParam(role="user", content=content.content)
5284

5385
if role == "assistant":
54-
return ChatCompletionAssistantMessageParam(
55-
role="assistant", content=content.content
86+
param = ChatCompletionAssistantMessageParam(
87+
role="assistant",
88+
content=content.content,
5689
)
90+
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
91+
param["tool_calls"] = [
92+
ChatCompletionMessageToolCallParam(
93+
type="function",
94+
id=tool_call.id,
95+
function=Function(
96+
arguments=json.dumps(tool_call.tool_args),
97+
name=tool_call.tool_name,
98+
),
99+
)
100+
for tool_call in content.tool_calls
101+
]
102+
return param
57103
LOGGER.warning("Could not convert message to Completions API: %s", content)
58104
return None
59105

60106

107+
def _decode_tool_arguments(arguments: str) -> Any:
108+
"""Decode tool call arguments."""
109+
try:
110+
return json.loads(arguments)
111+
except json.JSONDecodeError as err:
112+
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
113+
114+
115+
async def _transform_response(
116+
message: ChatCompletionMessage,
117+
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
118+
"""Transform the OpenRouter message to a ChatLog format."""
119+
data: conversation.AssistantContentDeltaDict = {
120+
"role": message.role,
121+
"content": message.content,
122+
}
123+
if message.tool_calls:
124+
data["tool_calls"] = [
125+
llm.ToolInput(
126+
id=tool_call.id,
127+
tool_name=tool_call.function.name,
128+
tool_args=_decode_tool_arguments(tool_call.function.arguments),
129+
)
130+
for tool_call in message.tool_calls
131+
]
132+
yield data
133+
134+
61135
class OpenRouterConversationEntity(conversation.ConversationEntity):
62136
"""OpenRouter conversation agent."""
63137

@@ -75,6 +149,10 @@ def __init__(self, entry: OpenRouterConfigEntry, subentry: ConfigSubentry) -> No
75149
name=subentry.title,
76150
entry_type=DeviceEntryType.SERVICE,
77151
)
152+
if self.subentry.data.get(CONF_LLM_HASS_API):
153+
self._attr_supported_features = (
154+
conversation.ConversationEntityFeature.CONTROL
155+
)
78156

79157
@property
80158
def supported_languages(self) -> list[str] | Literal["*"]:
@@ -93,12 +171,19 @@ async def _async_handle_message(
93171
await chat_log.async_provide_llm_data(
94172
user_input.as_llm_context(DOMAIN),
95173
options.get(CONF_LLM_HASS_API),
96-
None,
174+
options.get(CONF_PROMPT),
97175
user_input.extra_system_prompt,
98176
)
99177
except conversation.ConverseError as err:
100178
return err.as_conversation_result()
101179

180+
tools: list[ChatCompletionToolParam] | None = None
181+
if chat_log.llm_api:
182+
tools = [
183+
_format_tool(tool, chat_log.llm_api.custom_serializer)
184+
for tool in chat_log.llm_api.tools
185+
]
186+
102187
messages = [
103188
m
104189
for content in chat_log.content
@@ -107,27 +192,34 @@ async def _async_handle_message(
107192

108193
client = self.entry.runtime_data
109194

110-
try:
111-
result = await client.chat.completions.create(
112-
model=self.model,
113-
messages=messages,
114-
user=chat_log.conversation_id,
115-
extra_headers={
116-
"X-Title": "Home Assistant",
117-
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
118-
},
195+
for _iteration in range(MAX_TOOL_ITERATIONS):
196+
try:
197+
result = await client.chat.completions.create(
198+
model=self.model,
199+
messages=messages,
200+
tools=tools or NOT_GIVEN,
201+
user=chat_log.conversation_id,
202+
extra_headers={
203+
"X-Title": "Home Assistant",
204+
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
205+
},
206+
)
207+
except openai.OpenAIError as err:
208+
LOGGER.error("Error talking to API: %s", err)
209+
raise HomeAssistantError("Error talking to API") from err
210+
211+
result_message = result.choices[0].message
212+
213+
messages.extend(
214+
[
215+
msg
216+
async for content in chat_log.async_add_delta_content_stream(
217+
user_input.agent_id, _transform_response(result_message)
218+
)
219+
if (msg := _convert_content_to_chat_message(content))
220+
]
119221
)
120-
except openai.OpenAIError as err:
121-
LOGGER.error("Error talking to API: %s", err)
122-
raise HomeAssistantError("Error talking to API") from err
123-
124-
result_message = result.choices[0].message
125-
126-
chat_log.async_add_assistant_content_without_tools(
127-
conversation.AssistantContent(
128-
agent_id=user_input.agent_id,
129-
content=result_message.content,
130-
)
131-
)
222+
if not chat_log.unresponded_tool_results:
223+
break
132224

133225
return conversation.async_get_result_from_chat_log(user_input, chat_log)

homeassistant/components/open_router/strings.json

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@
2424
"user": {
2525
"description": "Configure the new conversation agent",
2626
"data": {
27-
"model": "Model"
27+
"model": "Model",
28+
"prompt": "Instructions",
29+
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
30+
},
31+
"data_description": {
32+
"model": "The model to use for the conversation agent",
33+
"prompt": "Instruct how the LLM should respond. This can be a template."
2834
}
2935
}
3036
},

tests/components/open_router/conftest.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22

33
from collections.abc import AsyncGenerator, Generator
44
from dataclasses import dataclass
5+
from typing import Any
56
from unittest.mock import AsyncMock, MagicMock, patch
67

78
from openai.types import CompletionUsage
89
from openai.types.chat import ChatCompletion, ChatCompletionMessage
910
from openai.types.chat.chat_completion import Choice
1011
import pytest
1112

12-
from homeassistant.components.open_router.const import DOMAIN
13+
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
1314
from homeassistant.config_entries import ConfigSubentryData
14-
from homeassistant.const import CONF_API_KEY, CONF_MODEL
15+
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
1516
from homeassistant.core import HomeAssistant
17+
from homeassistant.helpers import llm
1618
from homeassistant.setup import async_setup_component
1719

1820
from tests.common import MockConfigEntry
@@ -29,7 +31,27 @@ def mock_setup_entry() -> Generator[AsyncMock]:
2931

3032

3133
@pytest.fixture
32-
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
34+
def enable_assist() -> bool:
35+
"""Mock conversation subentry data."""
36+
return False
37+
38+
39+
@pytest.fixture
40+
def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
41+
"""Mock conversation subentry data."""
42+
res: dict[str, Any] = {
43+
CONF_MODEL: "gpt-3.5-turbo",
44+
CONF_PROMPT: "You are a helpful assistant.",
45+
}
46+
if enable_assist:
47+
res[CONF_LLM_HASS_API] = [llm.LLM_API_ASSIST]
48+
return res
49+
50+
51+
@pytest.fixture
52+
def mock_config_entry(
53+
hass: HomeAssistant, conversation_subentry_data: dict[str, Any]
54+
) -> MockConfigEntry:
3355
"""Mock a config entry."""
3456
return MockConfigEntry(
3557
title="OpenRouter",
@@ -39,7 +61,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
3961
},
4062
subentries_data=[
4163
ConfigSubentryData(
42-
data={CONF_MODEL: "gpt-3.5-turbo"},
64+
data=conversation_subentry_data,
4365
subentry_id="ABCDEF",
4466
subentry_type="conversation",
4567
title="GPT-3.5 Turbo",

0 commit comments

Comments
 (0)