Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion homeassistant/components/ai_task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType

from .const import (
ATTR_ATTACHMENTS,
ATTR_INSTRUCTIONS,
ATTR_REQUIRED,
ATTR_STRUCTURE,
Expand All @@ -32,14 +33,15 @@
)
from .entity import AITaskEntity
from .http import async_setup as async_setup_http
from .task import GenDataTask, GenDataTaskResult, async_generate_data
from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data

__all__ = [
"DOMAIN",
"AITaskEntity",
"AITaskEntityFeature",
"GenDataTask",
"GenDataTaskResult",
"PlayMediaWithId",
"async_generate_data",
"async_setup",
"async_setup_entry",
Expand Down Expand Up @@ -92,6 +94,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
_validate_structure_fields,
),
vol.Optional(ATTR_ATTACHMENTS): vol.All(
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
),
}
),
supports_response=SupportsResponse.ONLY,
Expand Down
4 changes: 4 additions & 0 deletions homeassistant/components/ai_task/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ATTR_TASK_NAME: Final = "task_name"
ATTR_STRUCTURE: Final = "structure"
ATTR_REQUIRED: Final = "required"
ATTR_ATTACHMENTS: Final = "attachments"

DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks."
Expand All @@ -34,3 +35,6 @@ class AITaskEntityFeature(IntFlag):

GENERATE_DATA = 1
"""Generate data based on instructions."""

SUPPORT_ATTACHMENTS = 2
"""Support attachments with generate data."""
2 changes: 1 addition & 1 deletion homeassistant/components/ai_task/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"domain": "ai_task",
"name": "AI Task",
"codeowners": ["@home-assistant/core"],
"dependencies": ["conversation"],
"dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task",
"integration_type": "system",
"quality_scale": "internal"
Expand Down
6 changes: 6 additions & 0 deletions homeassistant/components/ai_task/services.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@ generate_data:
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
selector:
object:
attachments:
required: false
selector:
media:
accept:
- "*"
4 changes: 4 additions & 0 deletions homeassistant/components/ai_task/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
"structure": {
"name": "Structured output",
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
},
"attachments": {
"name": "Attachments",
"description": "List of files to attach for multi-modal AI analysis."
}
}
}
Expand Down
45 changes: 44 additions & 1 deletion homeassistant/components/ai_task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,38 @@

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any

import voluptuous as vol

from homeassistant.components import media_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError

from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature


@dataclass(slots=True)
class PlayMediaWithId(media_source.PlayMedia):
"""Play media with a media content ID."""

media_content_id: str
"""Media source ID to play."""

def __str__(self) -> str:
"""Return media source ID as a string."""
return f"<PlayMediaWithId {self.media_content_id}>"


async def async_generate_data(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str | None = None,
instructions: str,
structure: vol.Schema | None = None,
attachments: list[dict] | None = None,
) -> GenDataTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
Expand All @@ -37,11 +51,37 @@ async def async_generate_data(
f"AI Task entity {entity_id} does not support generating data"
)

# Resolve attachments
resolved_attachments: list[PlayMediaWithId] | None = None

if attachments:
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)

resolved_attachments = []

for attachment in attachments:
media = await media_source.async_resolve_media(
hass, attachment["media_content_id"], None
)
resolved_attachments.append(
PlayMediaWithId(
**{
field.name: getattr(media, field.name)
for field in fields(media)
},
media_content_id=attachment["media_content_id"],
)
)

return await entity.internal_async_generate_data(
GenDataTask(
name=task_name,
instructions=instructions,
structure=structure,
attachments=resolved_attachments,
)
)

Expand All @@ -59,6 +99,9 @@ class GenDataTask:
structure: vol.Schema | None = None
"""Optional structure for the data to be generated."""

attachments: list[PlayMediaWithId] | None = None
"""List of attachments to go along the instructions."""

def __str__(self) -> str:
"""Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>"
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/rest/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
# Convert auth tuple to aiohttp.BasicAuth if needed
if isinstance(auth, tuple) and len(auth) == 2:
self._auth: aiohttp.BasicAuth | aiohttp.DigestAuthMiddleware | None = (
aiohttp.BasicAuth(auth[0], auth[1])
aiohttp.BasicAuth(auth[0], auth[1], encoding="utf-8")
)
else:
self._auth = auth
Expand Down
1 change: 1 addition & 0 deletions homeassistant/components/utility_meter/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ async def _validate_config(
max=28,
mode=selector.NumberSelectorMode.BOX,
unit_of_measurement="days",
translation_key=CONF_METER_OFFSET,
),
),
vol.Required(CONF_TARIFFS, default=[]): selector.SelectSelector(
Expand Down
5 changes: 5 additions & 0 deletions homeassistant/components/utility_meter/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
"quarterly": "Quarterly",
"yearly": "Yearly"
}
},
"offset": {
"unit_of_measurement": {
"days": "days"
}
}
},
"services": {
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ async def async_start(self) -> None:

This method is a coroutine.
"""
_LOGGER.info("Starting Home Assistant")
_LOGGER.info("Starting Home Assistant %s", __version__)

self.set_state(CoreState.starting)
self.bus.async_fire_internal(EVENT_CORE_CONFIG_UPDATE)
Expand Down
4 changes: 3 additions & 1 deletion tests/components/ai_task/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ class MockAITaskEntity(AITaskEntity):
"""Mock AI Task entity for testing."""

_attr_name = "Test Task Entity"
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
_attr_supported_features = (
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS
)

def __init__(self) -> None:
"""Initialize the mock entity."""
Expand Down
59 changes: 47 additions & 12 deletions tests/components/ai_task/test_init.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Test initialization of the AI Task component."""

from typing import Any
from unittest.mock import patch

from freezegun.api import FrozenDateTimeFactory
import pytest
import voluptuous as vol

from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.core import HomeAssistant
Expand Down Expand Up @@ -58,7 +60,15 @@ async def test_preferences_storage_load(
),
(
{},
{"entity_id": TEST_ENTITY_ID},
{
"entity_id": TEST_ENTITY_ID,
"attachments": [
{
"media_content_id": "media-source://mock/blah_blah_blah.mp4",
"media_content_type": "video/mp4",
}
],
},
),
],
)
Expand All @@ -68,25 +78,50 @@ async def test_generate_data_service(
freezer: FrozenDateTimeFactory,
set_preferences: dict[str, str | None],
msg_extra: dict[str, str],
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the generate data service."""
preferences = hass.data[DATA_PREFERENCES]
preferences.async_set_preferences(**set_preferences)

result = await hass.services.async_call(
"ai_task",
"generate_data",
{
"task_name": "Test Name",
"instructions": "Test prompt",
}
| msg_extra,
blocking=True,
return_response=True,
)
with patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=media_source.PlayMedia(
url="http://example.com/media.mp4",
mime_type="video/mp4",
),
):
result = await hass.services.async_call(
"ai_task",
"generate_data",
{
"task_name": "Test Name",
"instructions": "Test prompt",
}
| msg_extra,
blocking=True,
return_response=True,
)

assert result["data"] == "Mock result"

assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1
task = mock_ai_task_entity.mock_generate_data_tasks[0]

assert len(task.attachments or []) == len(
msg_attachments := msg_extra.get("attachments", [])
)

for msg_attachment, attachment in zip(
msg_attachments, task.attachments or [], strict=False
):
assert attachment.url == "http://example.com/media.mp4"
assert attachment.mime_type == "video/mp4"
assert attachment.media_content_id == msg_attachment["media_content_id"]
assert (
str(attachment) == f"<PlayMediaWithId {msg_attachment['media_content_id']}>"
)


async def test_generate_data_service_structure_fields(
hass: HomeAssistant,
Expand Down
37 changes: 32 additions & 5 deletions tests/components/ai_task/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from tests.typing import WebSocketGenerator


async def test_run_task_preferred_entity(
async def test_generate_data_preferred_entity(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test running a task with an unknown entity."""
"""Test generating data with entity via preferences."""
client = await hass_ws_client(hass)

with pytest.raises(
Expand Down Expand Up @@ -90,11 +90,11 @@ async def test_run_task_preferred_entity(
)


async def test_run_data_task_unknown_entity(
async def test_generate_data_unknown_entity(
hass: HomeAssistant,
init_components: None,
) -> None:
"""Test running a data task with an unknown entity."""
"""Test generating data with an unknown entity."""

with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
Expand All @@ -113,7 +113,7 @@ async def test_run_data_task_updates_chat_log(
init_components: None,
snapshot: SnapshotAssertion,
) -> None:
"""Test that running a data task updates the chat log."""
"""Test that generating data updates the chat log."""
result = await async_generate_data(
hass,
task_name="Test Task",
Expand All @@ -127,3 +127,30 @@ async def test_run_data_task_updates_chat_log(
async_get_chat_log(hass, session) as chat_log,
):
assert chat_log.content == snapshot


async def test_generate_data_attachments_not_supported(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test generating data with attachments when entity doesn't support them."""
# Remove attachment support from the entity
mock_ai_task_entity._attr_supported_features = AITaskEntityFeature.GENERATE_DATA

with pytest.raises(
HomeAssistantError,
match="AI Task entity ai_task.test_task_entity does not support attachments",
):
await async_generate_data(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
attachments=[
{
"media_content_id": "media-source://mock/test.mp4",
"media_content_type": "video/mp4",
}
],
)
Loading
Loading