Commit 7c2741b

Add home assistant cloud conversation (home-assistant#157090)

Parent: d6fb268

6 files changed: +392, -74 lines

homeassistant/components/cloud/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@
 PLATFORMS = [
     Platform.AI_TASK,
     Platform.BINARY_SENSOR,
+    Platform.CONVERSATION,
     Platform.STT,
     Platform.TTS,
 ]
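
Note: adding Platform.CONVERSATION to PLATFORMS is what makes Home Assistant load the new conversation platform for the cloud config entry. A minimal sketch of the standard pattern that consumes this list (illustrative only; the cloud integration's real async_setup_entry is not shown in this commit):

# Illustrative only: how a PLATFORMS list like the one above is usually
# forwarded when a config entry is set up (standard Home Assistant pattern,
# assumed here, not copied from this commit).
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant

# The list from the diff above, repeated so the sketch is self-contained.
PLATFORMS = [
    Platform.AI_TASK,
    Platform.BINARY_SENSOR,
    Platform.CONVERSATION,
    Platform.STT,
    Platform.TTS,
]


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Forward the config entry to every platform, now including conversation."""
    await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
    return True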

homeassistant/components/cloud/const.py

Lines changed: 1 addition & 0 deletions
@@ -92,6 +92,7 @@
 STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
 TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
 AI_TASK_ENTITY_UNIQUE_ID = "cloud-ai-task"
+CONVERSATION_ENTITY_UNIQUE_ID = "cloud-conversation-agent"

 LOGIN_MFA_TIMEOUT = 60

homeassistant/components/cloud/conversation.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+"""Conversation support for Home Assistant Cloud."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+from hass_nabucasa.llm import LLMError
+
+from homeassistant.components import conversation
+from homeassistant.config_entries import ConfigEntry
+from homeassistant.const import MATCH_ALL
+from homeassistant.core import HomeAssistant
+from homeassistant.helpers import llm
+from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
+
+from .const import CONVERSATION_ENTITY_UNIQUE_ID, DATA_CLOUD, DOMAIN
+from .entity import BaseCloudLLMEntity
+
+
+async def async_setup_entry(
+    hass: HomeAssistant,
+    config_entry: ConfigEntry,
+    async_add_entities: AddConfigEntryEntitiesCallback,
+) -> None:
+    """Set up the Home Assistant Cloud conversation entity."""
+    cloud = hass.data[DATA_CLOUD]
+    try:
+        await cloud.llm.async_ensure_token()
+    except LLMError:
+        return
+
+    async_add_entities([CloudConversationEntity(cloud, config_entry)])
+
+
+class CloudConversationEntity(
+    conversation.ConversationEntity,
+    BaseCloudLLMEntity,
+):
+    """Home Assistant Cloud conversation agent."""
+
+    _attr_has_entity_name = True
+    _attr_name = "Home Assistant Cloud"
+    _attr_translation_key = "cloud_conversation"
+    _attr_unique_id = CONVERSATION_ENTITY_UNIQUE_ID
+    _attr_supported_features = conversation.ConversationEntityFeature.CONTROL
+
+    @property
+    def available(self) -> bool:
+        """Return if the entity is available."""
+        return self._cloud.is_logged_in and self._cloud.valid_subscription
+
+    @property
+    def supported_languages(self) -> list[str] | Literal["*"]:
+        """Return a list of supported languages."""
+        return MATCH_ALL
+
+    async def _async_handle_message(
+        self,
+        user_input: conversation.ConversationInput,
+        chat_log: conversation.ChatLog,
+    ) -> conversation.ConversationResult:
+        """Process a user input."""
+        try:
+            await chat_log.async_provide_llm_data(
+                user_input.as_llm_context(DOMAIN),
+                llm.LLM_API_ASSIST,
+                None,
+                user_input.extra_system_prompt,
+            )
+        except conversation.ConverseError as err:
+            return err.as_conversation_result()
+
+        await self._async_handle_chat_log("conversation", chat_log)
+
+        return conversation.async_get_result_from_chat_log(user_input, chat_log)
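
Note: a minimal sketch of how the new agent can be exercised from other async code once the entity is loaded, assuming the standard conversation.async_converse helper; the entity id conversation.home_assistant_cloud and the utterance are illustrative, not taken from this commit:

# Illustrative only: driving the cloud conversation agent from other async code
# (for example a test), assuming conversation.async_converse and an entity id of
# "conversation.home_assistant_cloud"; both are assumptions for this sketch.
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant


async def ask_cloud_agent(hass: HomeAssistant) -> conversation.ConversationResult:
    """Send one utterance to the cloud agent and return the result."""
    return await conversation.async_converse(
        hass,
        text="Turn on the kitchen light",
        conversation_id=None,
        context=Context(),
        agent_id="conversation.home_assistant_cloud",  # assumed entity id
    )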

homeassistant/components/cloud/entity.py

Lines changed: 120 additions & 48 deletions
@@ -1,7 +1,7 @@
 """Helpers for cloud LLM chat handling."""

 import base64
-from collections.abc import AsyncGenerator, Callable
+from collections.abc import AsyncGenerator, Callable, Iterable
 from enum import Enum
 import json
 import logging

@@ -16,13 +16,22 @@
     LLMResponseError,
     LLMServiceError,
 )
-from litellm import ResponseFunctionToolCall, ResponsesAPIStreamEvents
+from litellm import (
+    ResponseFunctionToolCall,
+    ResponseInputParam,
+    ResponsesAPIStreamEvents,
+)
 from openai.types.responses import (
     FunctionToolParam,
+    ResponseInputItemParam,
     ResponseReasoningItem,
     ToolParam,
     WebSearchToolParam,
 )
+from openai.types.responses.response_input_param import (
+    ImageGenerationCall as ImageGenerationCallParam,
+)
+from openai.types.responses.response_output_item import ImageGenerationCall
 import voluptuous as vol
 from voluptuous_openapi import convert

@@ -50,34 +59,97 @@ class ResponseItemType(str, Enum):
     IMAGE = "image"


-def _convert_content_to_chat_message(
-    content: conversation.Content,
-) -> dict[str, Any] | None:
-    """Convert ChatLog content to a responses message."""
-    if content.role not in ("user", "system", "tool_result", "assistant"):
-        return None
+def _convert_content_to_param(
+    chat_content: Iterable[conversation.Content],
+) -> ResponseInputParam:
+    """Convert any native chat message for this agent to the native format."""
+    messages: ResponseInputParam = []
+    reasoning_summary: list[str] = []
+    web_search_calls: dict[str, dict[str, Any]] = {}

-    text_content = cast(
-        conversation.SystemContent
-        | conversation.UserContent
-        | conversation.AssistantContent,
-        content,
-    )
+    for content in chat_content:
+        if isinstance(content, conversation.ToolResultContent):
+            if (
+                content.tool_name == "web_search_call"
+                and content.tool_call_id in web_search_calls
+            ):
+                web_search_call = web_search_calls.pop(content.tool_call_id)
+                web_search_call["status"] = content.tool_result.get(
+                    "status", "completed"
+                )
+                messages.append(cast("ResponseInputItemParam", web_search_call))
+            else:
+                messages.append(
+                    {
+                        "type": "function_call_output",
+                        "call_id": content.tool_call_id,
+                        "output": json.dumps(content.tool_result),
+                    }
+                )
+            continue
+
+        if content.content:
+            role: Literal["user", "assistant", "system", "developer"] = content.role
+            if role == "system":
+                role = "developer"
+            messages.append(
+                {"type": "message", "role": role, "content": content.content}
+            )
+
+        if isinstance(content, conversation.AssistantContent):
+            if content.tool_calls:
+                for tool_call in content.tool_calls:
+                    if (
+                        tool_call.external
+                        and tool_call.tool_name == "web_search_call"
+                        and "action" in tool_call.tool_args
+                    ):
+                        web_search_calls[tool_call.id] = {
+                            "type": "web_search_call",
+                            "id": tool_call.id,
+                            "action": tool_call.tool_args["action"],
+                            "status": "completed",
+                        }
+                    else:
+                        messages.append(
+                            {
+                                "type": "function_call",
+                                "name": tool_call.tool_name,
+                                "arguments": json.dumps(tool_call.tool_args),
+                                "call_id": tool_call.id,
+                            }
+                        )

-    if not text_content.content:
-        return None
+            if content.thinking_content:
+                reasoning_summary.append(content.thinking_content)

-    content_type = "output_text" if text_content.role == "assistant" else "input_text"
+            if isinstance(content.native, ResponseReasoningItem):
+                messages.append(
+                    {
+                        "type": "reasoning",
+                        "id": content.native.id,
+                        "summary": (
+                            [
+                                {
+                                    "type": "summary_text",
+                                    "text": summary,
+                                }
+                                for summary in reasoning_summary
+                            ]
+                            if content.thinking_content
+                            else []
+                        ),
+                        "encrypted_content": content.native.encrypted_content,
+                    }
+                )
+                reasoning_summary = []

-    return {
-        "role": text_content.role,
-        "content": [
-            {
-                "type": content_type,
-                "text": text_content.content,
-            }
-        ],
-    }
+            elif isinstance(content.native, ImageGenerationCall):
+                messages.append(
+                    cast(ImageGenerationCallParam, content.native.to_dict())
+                )
+
+    return messages


 def _format_tool(
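
Note: a hand-worked illustration of the mapping _convert_content_to_param performs for a small chat log (system prompt, user turn, assistant tool call, tool result). The concrete values (tool name HassTurnOn, call id, arguments) are assumptions for the sketch, not output captured from Home Assistant:

# Illustrative only: the Responses API input items that _convert_content_to_param
# would produce for a small chat log. Values are hand-written to mirror the
# mapping in the hunk above, not produced by running Home Assistant.
expected_input: list[dict] = [
    # SystemContent becomes a "developer" message.
    {"type": "message", "role": "developer", "content": "You are a voice assistant."},
    # UserContent becomes a "user" message.
    {"type": "message", "role": "user", "content": "Turn on the kitchen light"},
    # AssistantContent.tool_calls become function_call items (HassTurnOn is assumed).
    {
        "type": "function_call",
        "name": "HassTurnOn",
        "arguments": '{"name": "kitchen light"}',
        "call_id": "call_1",
    },
    # ToolResultContent becomes a function_call_output with a JSON-encoded result.
    {
        "type": "function_call_output",
        "call_id": "call_1",
        "output": '{"success": true}',
    },
]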
@@ -381,25 +453,16 @@ def __init__(self, cloud: Cloud[CloudClient], config_entry: ConfigEntry) -> None
     async def _prepare_chat_for_generation(
         self,
         chat_log: conversation.ChatLog,
+        messages: ResponseInputParam,
         response_format: dict[str, Any] | None = None,
     ) -> dict[str, Any]:
         """Prepare kwargs for Cloud LLM from the chat log."""

-        messages = [
-            message
-            for content in chat_log.content
-            if (message := _convert_content_to_chat_message(content))
-        ]
-
-        if not messages or messages[-1]["role"] != "user":
-            raise HomeAssistantError("No user prompt found")
-
-        last_content = chat_log.content[-1]
+        last_content: Any = chat_log.content[-1]
         if last_content.role == "user" and last_content.attachments:
             files = await self._async_prepare_files_for_prompt(last_content.attachments)
-            user_message = messages[-1]
-            current_content = user_message.get("content", [])
-            user_message["content"] = [*(current_content or []), *files]
+            current_content = last_content.content
+            last_content = [*(current_content or []), *files]

         tools: list[ToolParam] = []
         tool_choice: str | None = None
@@ -503,8 +566,11 @@ async def _async_handle_chat_log(
             },
         }

+        messages = _convert_content_to_param(chat_log.content)
+
         response_kwargs = await self._prepare_chat_for_generation(
             chat_log,
+            messages,
             response_format,
         )

@@ -518,15 +584,21 @@ async def _async_handle_chat_log(
                 **response_kwargs,
             )

-            async for _ in chat_log.async_add_delta_content_stream(
-                agent_id=self.entity_id,
-                stream=_transform_stream(
-                    chat_log,
-                    raw_stream,
-                    True,
-                ),
-            ):
-                pass
+            messages.extend(
+                _convert_content_to_param(
+                    [
+                        content
+                        async for content in chat_log.async_add_delta_content_stream(
+                            self.entity_id,
+                            _transform_stream(
+                                chat_log,
+                                raw_stream,
+                                True,
+                            ),
+                        )
+                    ]
+                )
+            )

         except LLMAuthenticationError as err:
             raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
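
Note: the last hunk stops draining the delta stream with "async for _ in ...: pass" and instead collects the streamed content so it can be converted and appended to messages, keeping the running request history complete for any follow-up call. A self-contained toy sketch of that collect-and-extend pattern (every name here is a stand-in, not code from this commit):

# Illustrative only: the collect-and-extend pattern adopted above, shown with
# stand-in data. fake_delta_stream plays the role of
# chat_log.async_add_delta_content_stream(...); nothing here is Home Assistant code.
import asyncio
from collections.abc import AsyncGenerator


async def fake_delta_stream() -> AsyncGenerator[dict, None]:
    """Yield a couple of stand-in content items, like a streamed model response."""
    yield {"role": "assistant", "content": "Turning on the kitchen light."}
    yield {"role": "tool_result", "output": '{"success": true}'}


async def main() -> None:
    # Running request history; in the real code this is the ResponseInputParam list.
    messages: list[dict] = [{"role": "user", "content": "Turn on the kitchen light"}]
    # Collect everything the stream yields instead of discarding it, then extend
    # the history so a follow-up request sees the new turns as well.
    new_items = [item async for item in fake_delta_stream()]
    messages.extend(new_items)
    print(messages)


asyncio.run(main())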
