Skip to content

Commit e780e3d

Browse files
balloobclaudeCopilot
authored
Add chat log subscription endpoint (home-assistant#155287)
Co-authored-by: Claude <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 4ed2efa commit e780e3d

File tree

15 files changed

+874
-9
lines changed

15 files changed

+874
-9
lines changed

homeassistant/components/conversation/chat_log.py

Lines changed: 171 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from contextlib import contextmanager
88
from contextvars import ContextVar
99
from dataclasses import asdict, dataclass, field, replace
10+
from datetime import datetime
1011
import logging
1112
from pathlib import Path
1213
from typing import Any, Literal, TypedDict, cast
@@ -16,21 +17,59 @@
1617
from homeassistant.core import HomeAssistant, callback
1718
from homeassistant.exceptions import HomeAssistantError, TemplateError
1819
from homeassistant.helpers import chat_session, frame, intent, llm, template
20+
from homeassistant.util.dt import utcnow
1921
from homeassistant.util.hass_dict import HassKey
2022
from homeassistant.util.json import JsonObjectType
2123

2224
from . import trace
25+
from .const import ChatLogEventType
2326
from .models import ConversationInput, ConversationResult
2427

2528
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
26-
29+
DATA_SUBSCRIPTIONS: HassKey[
30+
list[Callable[[str, ChatLogEventType, dict[str, Any]], None]]
31+
] = HassKey("conversation_chat_log_subscriptions")
2732
LOGGER = logging.getLogger(__name__)
2833

2934
current_chat_log: ContextVar[ChatLog | None] = ContextVar(
3035
"current_chat_log", default=None
3136
)
3237

3338

39+
@callback
40+
def async_subscribe_chat_logs(
41+
hass: HomeAssistant,
42+
callback_func: Callable[[str, ChatLogEventType, dict[str, Any]], None],
43+
) -> Callable[[], None]:
44+
"""Subscribe to all chat logs."""
45+
subscriptions = hass.data.get(DATA_SUBSCRIPTIONS)
46+
if subscriptions is None:
47+
subscriptions = []
48+
hass.data[DATA_SUBSCRIPTIONS] = subscriptions
49+
50+
subscriptions.append(callback_func)
51+
52+
@callback
53+
def unsubscribe() -> None:
54+
"""Unsubscribe from chat logs."""
55+
subscriptions.remove(callback_func)
56+
57+
return unsubscribe
58+
59+
60+
@callback
61+
def _async_notify_subscribers(
62+
hass: HomeAssistant,
63+
conversation_id: str,
64+
event_type: ChatLogEventType,
65+
data: dict[str, Any],
66+
) -> None:
67+
"""Notify subscribers of a chat log event."""
68+
if subscriptions := hass.data.get(DATA_SUBSCRIPTIONS):
69+
for callback_func in subscriptions:
70+
callback_func(conversation_id, event_type, data)
71+
72+
3473
@contextmanager
3574
def async_get_chat_log(
3675
hass: HomeAssistant,
@@ -63,6 +102,8 @@ def async_get_chat_log(
63102
all_chat_logs = {}
64103
hass.data[DATA_CHAT_LOGS] = all_chat_logs
65104

105+
is_new_log = session.conversation_id not in all_chat_logs
106+
66107
if chat_log := all_chat_logs.get(session.conversation_id):
67108
chat_log = replace(chat_log, content=chat_log.content.copy())
68109
else:
@@ -71,6 +112,15 @@ def async_get_chat_log(
71112
if chat_log_delta_listener:
72113
chat_log.delta_listener = chat_log_delta_listener
73114

115+
# Fire CREATED event for new chat logs before any content is added
116+
if is_new_log:
117+
_async_notify_subscribers(
118+
hass,
119+
session.conversation_id,
120+
ChatLogEventType.CREATED,
121+
{"chat_log": chat_log.as_dict()},
122+
)
123+
74124
if user_input is not None:
75125
chat_log.async_add_user_content(UserContent(content=user_input.text))
76126

@@ -84,14 +134,28 @@ def async_get_chat_log(
84134
LOGGER.debug(
85135
"Chat Log opened but no assistant message was added, ignoring update"
86136
)
137+
# If this was a new log but nothing was added, fire DELETED to clean up
138+
if is_new_log:
139+
_async_notify_subscribers(
140+
hass,
141+
session.conversation_id,
142+
ChatLogEventType.DELETED,
143+
{},
144+
)
87145
return
88146

89-
if session.conversation_id not in all_chat_logs:
147+
if is_new_log:
90148

91149
@callback
92150
def do_cleanup() -> None:
93151
"""Handle cleanup."""
94152
all_chat_logs.pop(session.conversation_id)
153+
_async_notify_subscribers(
154+
hass,
155+
session.conversation_id,
156+
ChatLogEventType.DELETED,
157+
{},
158+
)
95159

96160
session.async_on_cleanup(do_cleanup)
97161

@@ -100,6 +164,16 @@ def do_cleanup() -> None:
100164

101165
all_chat_logs[session.conversation_id] = chat_log
102166

167+
# For new logs, CREATED was already fired before content was added
168+
# For existing logs, fire UPDATED
169+
if not is_new_log:
170+
_async_notify_subscribers(
171+
hass,
172+
session.conversation_id,
173+
ChatLogEventType.UPDATED,
174+
{"chat_log": chat_log.as_dict()},
175+
)
176+
103177

104178
class ConverseError(HomeAssistantError):
105179
"""Error during initialization of conversation.
@@ -129,6 +203,15 @@ class SystemContent:
129203

130204
role: Literal["system"] = field(init=False, default="system")
131205
content: str
206+
created: datetime = field(init=False, default_factory=utcnow)
207+
208+
def as_dict(self) -> dict[str, Any]:
209+
"""Return a dictionary representation of the content."""
210+
return {
211+
"role": self.role,
212+
"content": self.content,
213+
"created": self.created,
214+
}
132215

133216

134217
@dataclass(frozen=True)
@@ -138,6 +221,20 @@ class UserContent:
138221
role: Literal["user"] = field(init=False, default="user")
139222
content: str
140223
attachments: list[Attachment] | None = field(default=None)
224+
created: datetime = field(init=False, default_factory=utcnow)
225+
226+
def as_dict(self) -> dict[str, Any]:
227+
"""Return a dictionary representation of the content."""
228+
result: dict[str, Any] = {
229+
"role": self.role,
230+
"content": self.content,
231+
"created": self.created,
232+
}
233+
if self.attachments:
234+
result["attachments"] = [
235+
attachment.as_dict() for attachment in self.attachments
236+
]
237+
return result
141238

142239

143240
@dataclass(frozen=True)
@@ -153,6 +250,14 @@ class Attachment:
153250
path: Path
154251
"""Path to the attachment on disk."""
155252

253+
def as_dict(self) -> dict[str, Any]:
254+
"""Return a dictionary representation of the attachment."""
255+
return {
256+
"media_content_id": self.media_content_id,
257+
"mime_type": self.mime_type,
258+
"path": str(self.path),
259+
}
260+
156261

157262
@dataclass(frozen=True)
158263
class AssistantContent:
@@ -164,6 +269,22 @@ class AssistantContent:
164269
thinking_content: str | None = None
165270
tool_calls: list[llm.ToolInput] | None = None
166271
native: Any = None
272+
created: datetime = field(init=False, default_factory=utcnow)
273+
274+
def as_dict(self) -> dict[str, Any]:
275+
"""Return a dictionary representation of the content."""
276+
result: dict[str, Any] = {
277+
"role": self.role,
278+
"agent_id": self.agent_id,
279+
"created": self.created,
280+
}
281+
if self.content:
282+
result["content"] = self.content
283+
if self.thinking_content:
284+
result["thinking_content"] = self.thinking_content
285+
if self.tool_calls:
286+
result["tool_calls"] = self.tool_calls
287+
return result
167288

168289

169290
@dataclass(frozen=True)
@@ -175,6 +296,18 @@ class ToolResultContent:
175296
tool_call_id: str
176297
tool_name: str
177298
tool_result: JsonObjectType
299+
created: datetime = field(init=False, default_factory=utcnow)
300+
301+
def as_dict(self) -> dict[str, Any]:
302+
"""Return a dictionary representation of the content."""
303+
return {
304+
"role": self.role,
305+
"agent_id": self.agent_id,
306+
"tool_call_id": self.tool_call_id,
307+
"tool_name": self.tool_name,
308+
"tool_result": self.tool_result,
309+
"created": self.created,
310+
}
178311

179312

180313
type Content = SystemContent | UserContent | AssistantContent | ToolResultContent
@@ -210,6 +343,16 @@ class ChatLog:
210343
llm_api: llm.APIInstance | None = None
211344
delta_listener: Callable[[ChatLog, dict], None] | None = None
212345
llm_input_provided_index = 0
346+
created: datetime = field(init=False, default_factory=utcnow)
347+
348+
def as_dict(self) -> dict[str, Any]:
349+
"""Return a dictionary representation of the chat log."""
350+
return {
351+
"conversation_id": self.conversation_id,
352+
"continue_conversation": self.continue_conversation,
353+
"content": [c.as_dict() for c in self.content],
354+
"created": self.created,
355+
}
213356

214357
@property
215358
def continue_conversation(self) -> bool:
@@ -241,6 +384,12 @@ def async_add_user_content(self, content: UserContent) -> None:
241384
"""Add user content to the log."""
242385
LOGGER.debug("Adding user content: %s", content)
243386
self.content.append(content)
387+
_async_notify_subscribers(
388+
self.hass,
389+
self.conversation_id,
390+
ChatLogEventType.CONTENT_ADDED,
391+
{"content": content.as_dict()},
392+
)
244393

245394
@callback
246395
def async_add_assistant_content_without_tools(
@@ -259,6 +408,12 @@ def async_add_assistant_content_without_tools(
259408
):
260409
raise ValueError("Non-external tool calls not allowed")
261410
self.content.append(content)
411+
_async_notify_subscribers(
412+
self.hass,
413+
self.conversation_id,
414+
ChatLogEventType.CONTENT_ADDED,
415+
{"content": content.as_dict()},
416+
)
262417

263418
async def async_add_assistant_content(
264419
self,
@@ -317,6 +472,14 @@ async def async_add_assistant_content(
317472
tool_result=tool_result,
318473
)
319474
self.content.append(response_content)
475+
_async_notify_subscribers(
476+
self.hass,
477+
self.conversation_id,
478+
ChatLogEventType.CONTENT_ADDED,
479+
{
480+
"content": response_content.as_dict(),
481+
},
482+
)
320483
yield response_content
321484

322485
async def async_add_delta_content_stream(
@@ -593,6 +756,12 @@ async def async_provide_llm_data(
593756
self.llm_api = llm_api
594757
self.extra_system_prompt = extra_system_prompt
595758
self.content[0] = SystemContent(content=prompt)
759+
_async_notify_subscribers(
760+
self.hass,
761+
self.conversation_id,
762+
ChatLogEventType.UPDATED,
763+
{"chat_log": self.as_dict()},
764+
)
596765

597766
LOGGER.debug("Prompt: %s", self.content)
598767
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

homeassistant/components/conversation/const.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from enum import IntFlag
5+
from enum import IntFlag, StrEnum
66
from typing import TYPE_CHECKING
77

88
from homeassistant.util.hass_dict import HassKey
@@ -34,3 +34,13 @@ class ConversationEntityFeature(IntFlag):
3434

3535
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
3636
METADATA_CUSTOM_FILE = "hass_custom_file"
37+
38+
39+
class ChatLogEventType(StrEnum):
40+
"""Chat log event type."""
41+
42+
INITIAL_STATE = "initial_state"
43+
CREATED = "created"
44+
UPDATED = "updated"
45+
DELETED = "deleted"
46+
CONTENT_ADDED = "content_added"

0 commit comments

Comments
 (0)