Skip to content

Commit 8f1abb6

Browse files
Add HomeAssistant Cloud ai_task (home-assistant#157015)
1 parent 242c028 commit 8f1abb6

File tree

9 files changed

+1331
-3
lines changed

9 files changed

+1331
-3
lines changed

homeassistant/components/cloud/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@
7777

7878
DEFAULT_MODE = MODE_PROD
7979

80-
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS]
80+
PLATFORMS = [
81+
Platform.AI_TASK,
82+
Platform.BINARY_SENSOR,
83+
Platform.STT,
84+
Platform.TTS,
85+
]
8186

8287
SERVICE_REMOTE_CONNECT = "remote_connect"
8388
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
"""AI Task integration for Home Assistant Cloud."""
2+
3+
from __future__ import annotations
4+
5+
import io
6+
from json import JSONDecodeError
7+
import logging
8+
9+
from hass_nabucasa.llm import (
10+
LLMAuthenticationError,
11+
LLMError,
12+
LLMImageAttachment,
13+
LLMRateLimitError,
14+
LLMResponseError,
15+
LLMServiceError,
16+
)
17+
from PIL import Image
18+
19+
from homeassistant.components import ai_task, conversation
20+
from homeassistant.config_entries import ConfigEntry
21+
from homeassistant.core import HomeAssistant
22+
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
23+
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
24+
from homeassistant.util.json import json_loads
25+
26+
from .const import AI_TASK_ENTITY_UNIQUE_ID, DATA_CLOUD
27+
from .entity import BaseCloudLLMEntity
28+
29+
_LOGGER = logging.getLogger(__name__)
30+
31+
32+
def _convert_image_for_editing(data: bytes) -> tuple[bytes, str]:
    """Re-encode image data as PNG so it is accepted by OpenAI image edits.

    Images whose mode is not RGBA, LA or L are converted to RGBA first.
    Every image is then saved as PNG, whatever its original format.

    Returns a tuple of the PNG-encoded bytes and the mime type "image/png".
    Errors raised by Pillow while opening, converting or saving the image
    propagate to the caller unchanged.
    """
    with Image.open(io.BytesIO(data)) as img:
        # After this line the mode is always RGBA, LA or L, so a separate
        # "keep the original format" fallback can never be reached.
        converted = img if img.mode in ("RGBA", "LA", "L") else img.convert("RGBA")

        output = io.BytesIO()
        converted.save(output, format="PNG")

    return output.getvalue(), "image/png"
47+
48+
49+
async def async_prepare_image_generation_attachments(
    hass: HomeAssistant, attachments: list[conversation.Attachment]
) -> list[LLMImageAttachment]:
    """Read and convert image attachments for an image generation request.

    The file reads and image conversion are run through
    ``hass.async_add_executor_job``. Attachments are processed in the order
    given and the first failure aborts the whole batch.

    Raises HomeAssistantError when an attachment has no ``image/`` mime type,
    when its path does not exist, or when the image conversion fails.
    """

    def _load_one(attachment: conversation.Attachment) -> LLMImageAttachment:
        """Validate, read and convert a single attachment."""
        declared_type = attachment.mime_type
        if not (declared_type and declared_type.startswith("image/")):
            raise HomeAssistantError(
                "Only image attachments are supported for image generation"
            )

        source = attachment.path
        if not source.exists():
            raise HomeAssistantError(f"`{source}` does not exist")

        raw = source.read_bytes()
        try:
            converted, converted_type = _convert_image_for_editing(raw)
        except HomeAssistantError:
            # Already a user-facing error; pass it through untouched.
            raise
        except Exception as err:
            raise HomeAssistantError("Failed to process image attachment") from err

        return LLMImageAttachment(
            filename=source.name,
            mime_type=converted_type,
            data=converted,
        )

    def _load_all() -> list[LLMImageAttachment]:
        """Convert every attachment, preserving input order."""
        return [_load_one(attachment) for attachment in attachments]

    return await hass.async_add_executor_job(_load_all)
88+
89+
90+
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Add the Home Assistant Cloud AI Task entity for this config entry.

    The entity is only added when ``cloud.llm.async_ensure_token()`` succeeds;
    an LLMError from that call ends setup without adding anything.
    """
    cloud = hass.data[DATA_CLOUD]

    try:
        await cloud.llm.async_ensure_token()
    except LLMError:
        # No token, no entity: leave the platform set up but empty.
        return
    else:
        async_add_entities([CloudLLMTaskEntity(cloud, config_entry)])
103+
104+
105+
class CloudLLMTaskEntity(ai_task.AITaskEntity, BaseCloudLLMEntity):
    """AI Task entity backed by the Home Assistant Cloud LLM."""

    _attr_has_entity_name = True
    _attr_supported_features = (
        ai_task.AITaskEntityFeature.GENERATE_DATA
        | ai_task.AITaskEntityFeature.GENERATE_IMAGE
        | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
    )
    _attr_translation_key = "cloud_ai"
    _attr_unique_id = AI_TASK_ENTITY_UNIQUE_ID

    @property
    def available(self) -> bool:
        """Return whether the cloud account is logged in with a valid subscription."""
        cloud = self._cloud
        return cloud.is_logged_in and cloud.valid_subscription

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Run a generate data task and return its result.

        The chat log is handed to ``_async_handle_chat_log`` and its last entry
        must then be an AssistantContent. Without a task structure the text of
        that entry is returned as-is; with a structure the text is parsed as
        JSON.

        Raises HomeAssistantError when the last chat log entry is not an
        AssistantContent or when the structured response is not valid JSON.
        """
        await self._async_handle_chat_log(
            "ai_task", chat_log, task.name, task.structure
        )

        last_entry = chat_log.content[-1]
        if not isinstance(last_entry, conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )

        text = last_entry.content or ""

        if task.structure:
            try:
                payload = json_loads(text)
            except JSONDecodeError as err:
                _LOGGER.error(
                    "Failed to parse JSON response: %s. Response: %s",
                    err,
                    text,
                )
                raise HomeAssistantError(
                    "Error with OpenAI structured response"
                ) from err
        else:
            # Unstructured tasks return the raw assistant text.
            payload = text

        return ai_task.GenDataTaskResult(
            conversation_id=chat_log.conversation_id,
            data=payload,
        )

    async def _async_generate_image(
        self,
        task: ai_task.GenImageTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenImageTaskResult:
        """Run a generate image task and return its result.

        When the task carries attachments they are prepared first and the
        request goes to ``async_edit_image``; otherwise ``async_generate_image``
        is called with the task instructions alone.

        Raises ConfigEntryAuthFailed on LLMAuthenticationError and
        HomeAssistantError for every other LLMError.
        """
        # Attachment preparation stays outside the try block so that its own
        # errors are not routed through the LLM error handlers below.
        prepared: list[LLMImageAttachment] | None = (
            await async_prepare_image_generation_attachments(
                self.hass, task.attachments
            )
            if task.attachments
            else None
        )

        try:
            if prepared is not None:
                image = await self._cloud.llm.async_edit_image(
                    prompt=task.instructions,
                    attachments=prepared,
                )
            else:
                image = await self._cloud.llm.async_generate_image(
                    prompt=task.instructions,
                )
        # Handler order is significant: the specific error types are listed
        # before the LLMError catch-all.
        except LLMAuthenticationError as err:
            raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
        except LLMRateLimitError as err:
            raise HomeAssistantError("Cloud LLM is rate limited") from err
        except LLMResponseError as err:
            raise HomeAssistantError(str(err)) from err
        except LLMServiceError as err:
            raise HomeAssistantError("Error talking to Cloud LLM") from err
        except LLMError as err:
            raise HomeAssistantError(str(err)) from err

        # mime_type and image_data are required keys; the rest are optional.
        return ai_task.GenImageTaskResult(
            conversation_id=chat_log.conversation_id,
            mime_type=image["mime_type"],
            image_data=image["image_data"],
            model=image.get("model"),
            width=image.get("width"),
            height=image.get("height"),
            revised_prompt=image.get("revised_prompt"),
        )

homeassistant/components/cloud/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191

9292
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
9393
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
94+
AI_TASK_ENTITY_UNIQUE_ID = "cloud-ai-task"
9495

9596
LOGIN_MFA_TIMEOUT = 60
9697

0 commit comments

Comments
 (0)