Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions homeassistant/components/anthropic/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,11 @@ async def _async_handle_message(
options = self.entry.options

try:
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
await chat_log.async_provide_llm_data(
user_input.as_llm_context(DOMAIN),
options.get(CONF_LLM_HASS_API),
options.get(CONF_PROMPT),
user_input.extra_system_prompt,
)
except conversation.ConverseError as err:
return err.as_conversation_result()
Expand Down
48 changes: 30 additions & 18 deletions homeassistant/components/conversation/chat_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import chat_session, intent, llm, template
from homeassistant.helpers import chat_session, frame, intent, llm, template
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType

from . import trace
from .const import DOMAIN
from .models import ConversationInput, ConversationResult

DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
Expand Down Expand Up @@ -359,7 +358,7 @@ async def _async_expand_prompt_template(
self,
llm_context: llm.LLMContext,
prompt: str,
language: str,
language: str | None,
user_name: str | None = None,
) -> str:
try:
Expand All @@ -373,7 +372,7 @@ async def _async_expand_prompt_template(
)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=language)
intent_response = intent.IntentResponse(language=language or "")
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem with my template",
Expand All @@ -392,14 +391,25 @@ async def async_update_llm_data(
user_llm_prompt: str | None = None,
) -> None:
"""Set the LLM system prompt."""
llm_context = llm.LLMContext(
platform=conversing_domain,
context=user_input.context,
language=user_input.language,
assistant=DOMAIN,
device_id=user_input.device_id,
frame.report_usage(
"ChatLog.async_update_llm_data",
breaks_in_ha_version="2026.1",
)
return await self.async_provide_llm_data(
llm_context=user_input.as_llm_context(conversing_domain),
user_llm_hass_api=user_llm_hass_api,
user_llm_prompt=user_llm_prompt,
user_extra_system_prompt=user_input.extra_system_prompt,
)

async def async_provide_llm_data(
self,
llm_context: llm.LLMContext,
user_llm_hass_api: str | list[str] | None = None,
user_llm_prompt: str | None = None,
user_extra_system_prompt: str | None = None,
) -> None:
"""Set the LLM system prompt."""
llm_api: llm.APIInstance | None = None

if user_llm_hass_api:
Expand All @@ -413,10 +423,12 @@ async def async_update_llm_data(
LOGGER.error(
"Error getting LLM API %s for %s: %s",
user_llm_hass_api,
conversing_domain,
llm_context.platform,
err,
)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response = intent.IntentResponse(
language=llm_context.language or ""
)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Error preparing LLM API",
Expand All @@ -430,10 +442,10 @@ async def async_update_llm_data(
user_name: str | None = None

if (
user_input.context
and user_input.context.user_id
llm_context.context
and llm_context.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
user := await self.hass.auth.async_get_user(llm_context.context.user_id)
)
):
user_name = user.name
Expand All @@ -443,7 +455,7 @@ async def async_update_llm_data(
await self._async_expand_prompt_template(
llm_context,
(user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
user_input.language,
llm_context.language,
user_name,
)
)
Expand All @@ -455,14 +467,14 @@ async def async_update_llm_data(
await self._async_expand_prompt_template(
llm_context,
llm.BASE_PROMPT,
user_input.language,
llm_context.language,
user_name,
)
)

if extra_system_prompt := (
# Take new system prompt if one was given
user_input.extra_system_prompt or self.extra_system_prompt
user_extra_system_prompt or self.extra_system_prompt
):
prompt_parts.append(extra_system_prompt)

Expand Down
14 changes: 13 additions & 1 deletion homeassistant/components/conversation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from typing import Any, Literal

from homeassistant.core import Context
from homeassistant.helpers import intent
from homeassistant.helpers import intent, llm

from .const import DOMAIN


@dataclass(frozen=True)
Expand Down Expand Up @@ -56,6 +58,16 @@ def as_dict(self) -> dict[str, Any]:
"extra_system_prompt": self.extra_system_prompt,
}

def as_llm_context(self, conversing_domain: str) -> llm.LLMContext:
    """Build an llm.LLMContext describing this conversation input.

    The calling integration's domain becomes the context platform, while
    the conversation component itself (DOMAIN) is recorded as assistant.
    """
    context_kwargs = {
        "platform": conversing_domain,
        "context": self.context,
        "language": self.language,
        "assistant": DOMAIN,
        "device_id": self.device_id,
    }
    return llm.LLMContext(**context_kwargs)


@dataclass(slots=True)
class ConversationResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ async def _async_handle_message(
options = self.entry.options

try:
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
await chat_log.async_provide_llm_data(
user_input.as_llm_context(DOMAIN),
options.get(CONF_LLM_HASS_API),
options.get(CONF_PROMPT),
user_input.extra_system_prompt,
)
except conversation.ConverseError as err:
return err.as_conversation_result()
Expand Down
6 changes: 3 additions & 3 deletions homeassistant/components/ollama/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ async def _async_handle_message(
settings = {**self.entry.data, **self.entry.options}

try:
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
await chat_log.async_provide_llm_data(
user_input.as_llm_context(DOMAIN),
settings.get(CONF_LLM_HASS_API),
settings.get(CONF_PROMPT),
user_input.extra_system_prompt,
)
except conversation.ConverseError as err:
return err.as_conversation_result()
Expand Down
6 changes: 3 additions & 3 deletions homeassistant/components/openai_conversation/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,11 @@ async def _async_handle_message(
options = self.entry.options

try:
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
await chat_log.async_provide_llm_data(
user_input.as_llm_context(DOMAIN),
options.get(CONF_LLM_HASS_API),
options.get(CONF_PROMPT),
user_input.extra_system_prompt,
)
except conversation.ConverseError as err:
return err.as_conversation_result()
Expand Down
6 changes: 3 additions & 3 deletions tests/components/assist_pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,11 +1779,11 @@ async def stream_llm_response():
conversation_input,
) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=conversation_input,
await chat_log.async_provide_llm_data(
conversation_input.as_llm_context("test"),
user_llm_hass_api="assist",
user_llm_prompt=None,
user_extra_system_prompt=conversation_input.extra_system_prompt,
)
async for _content in chat_log.async_add_delta_content_stream(
agent_id, stream_llm_response()
Expand Down
64 changes: 28 additions & 36 deletions tests/components/conversation/test_chat_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,8 @@ async def test_llm_api(
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api="assist",
user_llm_prompt=None,
)
Expand All @@ -128,9 +127,8 @@ async def test_unknown_llm_api(
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
pytest.raises(ConverseError) as exc_info,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api="unknown-api",
user_llm_prompt=None,
)
Expand Down Expand Up @@ -170,9 +168,8 @@ async def async_get_api_instance(
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api=["assist", "my-api"],
user_llm_prompt=None,
)
Expand All @@ -192,9 +189,8 @@ async def test_template_error(
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
pytest.raises(ConverseError) as exc_info,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api=None,
user_llm_prompt="{{ invalid_syntax",
)
Expand All @@ -217,9 +213,8 @@ async def test_template_variables(
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api=None,
user_llm_prompt=(
"The instance name is {{ ha_name }}. "
Expand Down Expand Up @@ -249,11 +244,11 @@ async def test_extra_systen_prompt(
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api=None,
user_llm_prompt=None,
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
Expand All @@ -273,11 +268,11 @@ async def test_extra_systen_prompt(
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api=None,
user_llm_prompt=None,
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
)

assert chat_log.extra_system_prompt == extra_system_prompt
Expand All @@ -290,11 +285,11 @@ async def test_extra_systen_prompt(
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api=None,
user_llm_prompt=None,
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
Expand All @@ -314,11 +309,11 @@ async def test_extra_systen_prompt(
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api=None,
user_llm_prompt=None,
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
)

assert chat_log.extra_system_prompt == extra_system_prompt2
Expand Down Expand Up @@ -357,9 +352,8 @@ async def test_tool_call(
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api="assist",
user_llm_prompt=None,
)
Expand Down Expand Up @@ -434,9 +428,8 @@ async def test_tool_call_exception(
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
mock_get_tools.return_value = [mock_tool]
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api="assist",
user_llm_prompt=None,
)
Expand Down Expand Up @@ -595,9 +588,8 @@ async def stream():
) as chat_log,
):
mock_get_tools.return_value = [mock_tool]
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
await chat_log.async_provide_llm_data(
mock_conversation_input.as_llm_context("test"),
user_llm_hass_api="assist",
user_llm_prompt=None,
)
Expand Down
Loading
Loading