Skip to content

Commit af07ab4

Browse files
authored
Allow passing an LLM API to AI Task generate data (#151081)
1 parent 74b7315 commit af07ab4

File tree

4 files changed

+26
-3
lines changed

4 files changed

+26
-3
lines changed

homeassistant/components/ai_task/entity.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ async def _async_get_ai_task_chat_log(
6060
task: GenDataTask | GenImageTask,
6161
) -> AsyncGenerator[ChatLog]:
6262
"""Context manager used to manage the ChatLog used during an AI Task."""
63+
user_llm_hass_api: llm.API | None = None
64+
if isinstance(task, GenDataTask):
65+
user_llm_hass_api = task.llm_api
66+
6367
# pylint: disable-next=contextmanager-generator-missing-cleanup
6468
with (
6569
async_get_chat_log(
@@ -77,6 +81,7 @@ async def _async_get_ai_task_chat_log(
7781
device_id=None,
7882
),
7983
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
84+
user_llm_hass_api=user_llm_hass_api,
8085
)
8186

8287
chat_log.async_add_user_content(

homeassistant/components/ai_task/task.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from homeassistant.components.http.auth import async_sign_path
1717
from homeassistant.core import HomeAssistant, ServiceResponse, callback
1818
from homeassistant.exceptions import HomeAssistantError
19+
from homeassistant.helpers import llm
1920
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
2021
from homeassistant.helpers.event import async_call_later
2122
from homeassistant.helpers.network import get_url
@@ -116,6 +117,7 @@ async def async_generate_data(
116117
instructions: str,
117118
structure: vol.Schema | None = None,
118119
attachments: list[dict] | None = None,
120+
llm_api: llm.API | None = None,
119121
) -> GenDataTaskResult:
120122
"""Run a data generation task in the AI Task integration."""
121123
if entity_id is None:
@@ -151,6 +153,7 @@ async def async_generate_data(
151153
instructions=instructions,
152154
structure=structure,
153155
attachments=resolved_attachments or None,
156+
llm_api=llm_api,
154157
),
155158
)
156159

@@ -272,6 +275,9 @@ class GenDataTask:
272275
attachments: list[conversation.Attachment] | None = None
273276
"""List of attachments to go along the instructions."""
274277

278+
llm_api: llm.API | None = None
279+
"""API to provide to the LLM."""
280+
275281
def __str__(self) -> str:
276282
"""Return task as a string."""
277283
return f"<GenDataTask {self.name}: {id(self)}>"

homeassistant/components/conversation/chat_log.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,14 +507,18 @@ async def async_update_llm_data(
507507
async def async_provide_llm_data(
508508
self,
509509
llm_context: llm.LLMContext,
510-
user_llm_hass_api: str | list[str] | None = None,
510+
user_llm_hass_api: str | list[str] | llm.API | None = None,
511511
user_llm_prompt: str | None = None,
512512
user_extra_system_prompt: str | None = None,
513513
) -> None:
514514
"""Set the LLM system prompt."""
515515
llm_api: llm.APIInstance | None = None
516516

517-
if user_llm_hass_api:
517+
if user_llm_hass_api is None:
518+
pass
519+
elif isinstance(user_llm_hass_api, llm.API):
520+
llm_api = await user_llm_hass_api.async_get_api_instance(llm_context)
521+
else:
518522
try:
519523
llm_api = await llm.async_get_api(
520524
self.hass,

tests/components/ai_task/test_task.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from homeassistant.const import STATE_UNKNOWN
2121
from homeassistant.core import HomeAssistant
2222
from homeassistant.exceptions import HomeAssistantError
23-
from homeassistant.helpers import chat_session
23+
from homeassistant.helpers import chat_session, llm
2424
from homeassistant.util import dt as dt_util
2525

2626
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
@@ -78,10 +78,12 @@ async def test_generate_data_preferred_entity(
7878
assert state is not None
7979
assert state.state == STATE_UNKNOWN
8080

81+
llm_api = llm.AssistAPI(hass)
8182
result = await async_generate_data(
8283
hass,
8384
task_name="Test Task",
8485
instructions="Test prompt",
86+
llm_api=llm_api,
8587
)
8688
assert result.data == "Mock result"
8789
as_dict = result.as_dict()
@@ -91,6 +93,12 @@ async def test_generate_data_preferred_entity(
9193
assert state is not None
9294
assert state.state != STATE_UNKNOWN
9395

96+
with (
97+
chat_session.async_get_chat_session(hass, result.conversation_id) as session,
98+
async_get_chat_log(hass, session) as chat_log,
99+
):
100+
assert chat_log.llm_api.api is llm_api
101+
94102
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
95103
with pytest.raises(
96104
HomeAssistantError,

0 commit comments

Comments
 (0)