Skip to content

Commit 008e2a3

Browse files
authored
Add attachment support to AI task (home-assistant#148120)
1 parent 699c60f commit 008e2a3

File tree

9 files changed

+147
-21
lines changed

9 files changed

+147
-21
lines changed

homeassistant/components/ai_task/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
2121

2222
from .const import (
23+
ATTR_ATTACHMENTS,
2324
ATTR_INSTRUCTIONS,
2425
ATTR_REQUIRED,
2526
ATTR_STRUCTURE,
@@ -32,14 +33,15 @@
3233
)
3334
from .entity import AITaskEntity
3435
from .http import async_setup as async_setup_http
35-
from .task import GenDataTask, GenDataTaskResult, async_generate_data
36+
from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data
3637

3738
__all__ = [
3839
"DOMAIN",
3940
"AITaskEntity",
4041
"AITaskEntityFeature",
4142
"GenDataTask",
4243
"GenDataTaskResult",
44+
"PlayMediaWithId",
4345
"async_generate_data",
4446
"async_setup",
4547
"async_setup_entry",
@@ -92,6 +94,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
9294
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
9395
_validate_structure_fields,
9496
),
97+
vol.Optional(ATTR_ATTACHMENTS): vol.All(
98+
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
99+
),
95100
}
96101
),
97102
supports_response=SupportsResponse.ONLY,

homeassistant/components/ai_task/const.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ATTR_TASK_NAME: Final = "task_name"
2424
ATTR_STRUCTURE: Final = "structure"
2525
ATTR_REQUIRED: Final = "required"
26+
ATTR_ATTACHMENTS: Final = "attachments"
2627

2728
DEFAULT_SYSTEM_PROMPT = (
2829
"You are a Home Assistant expert and help users with their tasks."
@@ -34,3 +35,6 @@ class AITaskEntityFeature(IntFlag):
3435

3536
GENERATE_DATA = 1
3637
"""Generate data based on instructions."""
38+
39+
SUPPORT_ATTACHMENTS = 2
40+
"""Support attachments with generate data."""

homeassistant/components/ai_task/manifest.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"domain": "ai_task",
33
"name": "AI Task",
44
"codeowners": ["@home-assistant/core"],
5-
"dependencies": ["conversation"],
5+
"dependencies": ["conversation", "media_source"],
66
"documentation": "https://www.home-assistant.io/integrations/ai_task",
77
"integration_type": "system",
88
"quality_scale": "internal"

homeassistant/components/ai_task/services.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,9 @@ generate_data:
2323
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
2424
selector:
2525
object:
26+
attachments:
27+
required: false
28+
selector:
29+
media:
30+
accept:
31+
- "*"

homeassistant/components/ai_task/strings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
"structure": {
2020
"name": "Structured output",
2121
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
22+
},
23+
"attachments": {
24+
"name": "Attachments",
25+
"description": "List of files to attach for multi-modal AI analysis."
2226
}
2327
}
2428
}

homeassistant/components/ai_task/task.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,38 @@
22

33
from __future__ import annotations
44

5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, fields
66
from typing import Any
77

88
import voluptuous as vol
99

10+
from homeassistant.components import media_source
1011
from homeassistant.core import HomeAssistant
1112
from homeassistant.exceptions import HomeAssistantError
1213

1314
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
1415

1516

17+
@dataclass(slots=True)
18+
class PlayMediaWithId(media_source.PlayMedia):
19+
"""Play media with a media content ID."""
20+
21+
media_content_id: str
22+
"""Media source ID to play."""
23+
24+
def __str__(self) -> str:
25+
"""Return media source ID as a string."""
26+
return f"<PlayMediaWithId {self.media_content_id}>"
27+
28+
1629
async def async_generate_data(
1730
hass: HomeAssistant,
1831
*,
1932
task_name: str,
2033
entity_id: str | None = None,
2134
instructions: str,
2235
structure: vol.Schema | None = None,
36+
attachments: list[dict] | None = None,
2337
) -> GenDataTaskResult:
2438
"""Run a task in the AI Task integration."""
2539
if entity_id is None:
@@ -37,11 +51,37 @@ async def async_generate_data(
3751
f"AI Task entity {entity_id} does not support generating data"
3852
)
3953

54+
# Resolve attachments
55+
resolved_attachments: list[PlayMediaWithId] | None = None
56+
57+
if attachments:
58+
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
59+
raise HomeAssistantError(
60+
f"AI Task entity {entity_id} does not support attachments"
61+
)
62+
63+
resolved_attachments = []
64+
65+
for attachment in attachments:
66+
media = await media_source.async_resolve_media(
67+
hass, attachment["media_content_id"], None
68+
)
69+
resolved_attachments.append(
70+
PlayMediaWithId(
71+
**{
72+
field.name: getattr(media, field.name)
73+
for field in fields(media)
74+
},
75+
media_content_id=attachment["media_content_id"],
76+
)
77+
)
78+
4079
return await entity.internal_async_generate_data(
4180
GenDataTask(
4281
name=task_name,
4382
instructions=instructions,
4483
structure=structure,
84+
attachments=resolved_attachments,
4585
)
4686
)
4787

@@ -59,6 +99,9 @@ class GenDataTask:
5999
structure: vol.Schema | None = None
60100
"""Optional structure for the data to be generated."""
61101

102+
attachments: list[PlayMediaWithId] | None = None
103+
"""List of attachments to go along the instructions."""
104+
62105
def __str__(self) -> str:
63106
"""Return task as a string."""
64107
return f"<GenDataTask {self.name}: {id(self)}>"

tests/components/ai_task/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ class MockAITaskEntity(AITaskEntity):
3535
"""Mock AI Task entity for testing."""
3636

3737
_attr_name = "Test Task Entity"
38-
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
38+
_attr_supported_features = (
39+
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS
40+
)
3941

4042
def __init__(self) -> None:
4143
"""Initialize the mock entity."""

tests/components/ai_task/test_init.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Test initialization of the AI Task component."""
22

33
from typing import Any
4+
from unittest.mock import patch
45

56
from freezegun.api import FrozenDateTimeFactory
67
import pytest
78
import voluptuous as vol
89

10+
from homeassistant.components import media_source
911
from homeassistant.components.ai_task import AITaskPreferences
1012
from homeassistant.components.ai_task.const import DATA_PREFERENCES
1113
from homeassistant.core import HomeAssistant
@@ -58,7 +60,15 @@ async def test_preferences_storage_load(
5860
),
5961
(
6062
{},
61-
{"entity_id": TEST_ENTITY_ID},
63+
{
64+
"entity_id": TEST_ENTITY_ID,
65+
"attachments": [
66+
{
67+
"media_content_id": "media-source://mock/blah_blah_blah.mp4",
68+
"media_content_type": "video/mp4",
69+
}
70+
],
71+
},
6272
),
6373
],
6474
)
@@ -68,25 +78,50 @@ async def test_generate_data_service(
6878
freezer: FrozenDateTimeFactory,
6979
set_preferences: dict[str, str | None],
7080
msg_extra: dict[str, str],
81+
mock_ai_task_entity: MockAITaskEntity,
7182
) -> None:
7283
"""Test the generate data service."""
7384
preferences = hass.data[DATA_PREFERENCES]
7485
preferences.async_set_preferences(**set_preferences)
7586

76-
result = await hass.services.async_call(
77-
"ai_task",
78-
"generate_data",
79-
{
80-
"task_name": "Test Name",
81-
"instructions": "Test prompt",
82-
}
83-
| msg_extra,
84-
blocking=True,
85-
return_response=True,
86-
)
87+
with patch(
88+
"homeassistant.components.media_source.async_resolve_media",
89+
return_value=media_source.PlayMedia(
90+
url="http://example.com/media.mp4",
91+
mime_type="video/mp4",
92+
),
93+
):
94+
result = await hass.services.async_call(
95+
"ai_task",
96+
"generate_data",
97+
{
98+
"task_name": "Test Name",
99+
"instructions": "Test prompt",
100+
}
101+
| msg_extra,
102+
blocking=True,
103+
return_response=True,
104+
)
87105

88106
assert result["data"] == "Mock result"
89107

108+
assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1
109+
task = mock_ai_task_entity.mock_generate_data_tasks[0]
110+
111+
assert len(task.attachments or []) == len(
112+
msg_attachments := msg_extra.get("attachments", [])
113+
)
114+
115+
for msg_attachment, attachment in zip(
116+
msg_attachments, task.attachments or [], strict=False
117+
):
118+
assert attachment.url == "http://example.com/media.mp4"
119+
assert attachment.mime_type == "video/mp4"
120+
assert attachment.media_content_id == msg_attachment["media_content_id"]
121+
assert (
122+
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
123+
)
124+
90125

91126
async def test_generate_data_service_structure_fields(
92127
hass: HomeAssistant,

tests/components/ai_task/test_task.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
from tests.typing import WebSocketGenerator
1717

1818

19-
async def test_run_task_preferred_entity(
19+
async def test_generate_data_preferred_entity(
2020
hass: HomeAssistant,
2121
init_components: None,
2222
mock_ai_task_entity: MockAITaskEntity,
2323
hass_ws_client: WebSocketGenerator,
2424
) -> None:
25-
"""Test running a task with an unknown entity."""
25+
"""Test generating data with entity via preferences."""
2626
client = await hass_ws_client(hass)
2727

2828
with pytest.raises(
@@ -90,11 +90,11 @@ async def test_run_task_preferred_entity(
9090
)
9191

9292

93-
async def test_run_data_task_unknown_entity(
93+
async def test_generate_data_unknown_entity(
9494
hass: HomeAssistant,
9595
init_components: None,
9696
) -> None:
97-
"""Test running a data task with an unknown entity."""
97+
"""Test generating data with an unknown entity."""
9898

9999
with pytest.raises(
100100
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
@@ -113,7 +113,7 @@ async def test_run_data_task_updates_chat_log(
113113
init_components: None,
114114
snapshot: SnapshotAssertion,
115115
) -> None:
116-
"""Test that running a data task updates the chat log."""
116+
"""Test that generating data updates the chat log."""
117117
result = await async_generate_data(
118118
hass,
119119
task_name="Test Task",
@@ -127,3 +127,30 @@ async def test_run_data_task_updates_chat_log(
127127
async_get_chat_log(hass, session) as chat_log,
128128
):
129129
assert chat_log.content == snapshot
130+
131+
132+
async def test_generate_data_attachments_not_supported(
133+
hass: HomeAssistant,
134+
init_components: None,
135+
mock_ai_task_entity: MockAITaskEntity,
136+
) -> None:
137+
"""Test generating data with attachments when entity doesn't support them."""
138+
# Remove attachment support from the entity
139+
mock_ai_task_entity._attr_supported_features = AITaskEntityFeature.GENERATE_DATA
140+
141+
with pytest.raises(
142+
HomeAssistantError,
143+
match="AI Task entity ai_task.test_task_entity does not support attachments",
144+
):
145+
await async_generate_data(
146+
hass,
147+
task_name="Test Task",
148+
entity_id=TEST_ENTITY_ID,
149+
instructions="Test prompt",
150+
attachments=[
151+
{
152+
"media_content_id": "media-source://mock/test.mp4",
153+
"media_content_type": "video/mp4",
154+
}
155+
],
156+
)

0 commit comments

Comments
 (0)