Skip to content

Commit 8cb9cad

Browse files
balloob, Copilot, and allenporter authored
Extract files_to_prompt from Gemini action (home-assistant#148203)
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Allen Porter <[email protected]>
1 parent 075efb4 commit 8cb9cad

File tree

4 files changed

+92
-54
lines changed

4 files changed

+92
-54
lines changed

homeassistant/components/google_generative_ai_conversation/__init__.py

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22

33
from __future__ import annotations
44

5-
import asyncio
65
from functools import partial
7-
import mimetypes
86
from pathlib import Path
97
from types import MappingProxyType
108

119
from google.genai import Client
1210
from google.genai.errors import APIError, ClientError
13-
from google.genai.types import File, FileState
1411
from requests.exceptions import Timeout
1512
import voluptuous as vol
1613

@@ -42,13 +39,13 @@
4239
DEFAULT_TITLE,
4340
DEFAULT_TTS_NAME,
4441
DOMAIN,
45-
FILE_POLLING_INTERVAL_SECONDS,
4642
LOGGER,
4743
RECOMMENDED_AI_TASK_OPTIONS,
4844
RECOMMENDED_CHAT_MODEL,
4945
RECOMMENDED_TTS_OPTIONS,
5046
TIMEOUT_MILLIS,
5147
)
48+
from .entity import async_prepare_files_for_prompt
5249

5350
SERVICE_GENERATE_CONTENT = "generate_content"
5451
CONF_IMAGE_FILENAME = "image_filename"
@@ -92,58 +89,22 @@ async def generate_content(call: ServiceCall) -> ServiceResponse:
9289

9390
client = config_entry.runtime_data
9491

95-
def append_files_to_prompt():
96-
image_filenames = call.data[CONF_IMAGE_FILENAME]
97-
filenames = call.data[CONF_FILENAMES]
98-
for filename in set(image_filenames + filenames):
92+
files = call.data[CONF_IMAGE_FILENAME] + call.data[CONF_FILENAMES]
93+
94+
if files:
95+
for filename in files:
9996
if not hass.config.is_allowed_path(filename):
10097
raise HomeAssistantError(
10198
f"Cannot read `{filename}`, no access to path; "
10299
"`allowlist_external_dirs` may need to be adjusted in "
103100
"`configuration.yaml`"
104101
)
105-
if not Path(filename).exists():
106-
raise HomeAssistantError(f"`{filename}` does not exist")
107-
mimetype = mimetypes.guess_type(filename)[0]
108-
with open(filename, "rb") as file:
109-
uploaded_file = client.files.upload(
110-
file=file, config={"mime_type": mimetype}
111-
)
112-
prompt_parts.append(uploaded_file)
113-
114-
async def wait_for_file_processing(uploaded_file: File) -> None:
115-
"""Wait for file processing to complete."""
116-
while True:
117-
uploaded_file = await client.aio.files.get(
118-
name=uploaded_file.name,
119-
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
120-
)
121-
if uploaded_file.state not in (
122-
FileState.STATE_UNSPECIFIED,
123-
FileState.PROCESSING,
124-
):
125-
break
126-
LOGGER.debug(
127-
"Waiting for file `%s` to be processed, current state: %s",
128-
uploaded_file.name,
129-
uploaded_file.state,
130-
)
131-
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
132102

133-
if uploaded_file.state == FileState.FAILED:
134-
raise HomeAssistantError(
135-
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
103+
prompt_parts.extend(
104+
await async_prepare_files_for_prompt(
105+
hass, client, [Path(filename) for filename in files]
136106
)
137-
138-
await hass.async_add_executor_job(append_files_to_prompt)
139-
140-
tasks = [
141-
asyncio.create_task(wait_for_file_processing(part))
142-
for part in prompt_parts
143-
if isinstance(part, File) and part.state != FileState.ACTIVE
144-
]
145-
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
146-
await asyncio.gather(*tasks)
107+
)
147108

148109
try:
149110
response = await client.aio.models.generate_content(

homeassistant/components/google_generative_ai_conversation/entity.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,21 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import codecs
67
from collections.abc import AsyncGenerator, Callable
78
from dataclasses import replace
9+
import mimetypes
10+
from pathlib import Path
811
from typing import Any, cast
912

13+
from google.genai import Client
1014
from google.genai.errors import APIError, ClientError
1115
from google.genai.types import (
1216
AutomaticFunctionCallingConfig,
1317
Content,
18+
File,
19+
FileState,
1420
FunctionDeclaration,
1521
GenerateContentConfig,
1622
GenerateContentResponse,
@@ -26,6 +32,7 @@
2632

2733
from homeassistant.components import conversation
2834
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
35+
from homeassistant.core import HomeAssistant
2936
from homeassistant.exceptions import HomeAssistantError
3037
from homeassistant.helpers import device_registry as dr, llm
3138
from homeassistant.helpers.entity import Entity
@@ -42,13 +49,15 @@
4249
CONF_TOP_P,
4350
CONF_USE_GOOGLE_SEARCH_TOOL,
4451
DOMAIN,
52+
FILE_POLLING_INTERVAL_SECONDS,
4553
LOGGER,
4654
RECOMMENDED_CHAT_MODEL,
4755
RECOMMENDED_HARM_BLOCK_THRESHOLD,
4856
RECOMMENDED_MAX_TOKENS,
4957
RECOMMENDED_TEMPERATURE,
5058
RECOMMENDED_TOP_K,
5159
RECOMMENDED_TOP_P,
60+
TIMEOUT_MILLIS,
5261
)
5362

5463
# Max number of back and forth with the LLM to generate a response
@@ -494,3 +503,68 @@ def create_generate_content_config(self) -> GenerateContentConfig:
494503
),
495504
],
496505
)
506+
507+
508+
async def async_prepare_files_for_prompt(
    hass: HomeAssistant, client: Client, files: list[Path]
) -> list[File]:
    """Upload files and wait until they are ready to be used in a prompt.

    Each path in `files` is uploaded with `client.files.upload` (run through
    `hass.async_add_executor_job`). Uploads that are not already reported as
    ACTIVE are then polled with `client.aio.files.get` until they leave the
    STATE_UNSPECIFIED / PROCESSING states.

    Caller needs to ensure that the files are allowed.

    Returns:
        The `File` objects returned by the upload calls, in the order of
        `files`.

    Raises:
        HomeAssistantError: If a path does not exist or the processing of an
            uploaded file ends in the FAILED state.
        TimeoutError: If polling does not finish within TIMEOUT_MILLIS.
    """

    def upload_files() -> list[File]:
        """Upload every file and collect the resulting File objects."""
        prompt_parts: list[File] = []
        for filename in files:
            if not filename.exists():
                raise HomeAssistantError(f"`{filename}` does not exist")
            mimetype = mimetypes.guess_type(filename)[0]
            prompt_parts.append(
                client.files.upload(
                    file=filename,
                    config={
                        "mime_type": mimetype,
                        "display_name": filename.name,
                    },
                )
            )
        return prompt_parts

    async def wait_for_file_processing(uploaded_file: File) -> None:
        """Wait for file processing to complete."""
        # The first poll happens immediately; only later polls log and sleep.
        first = True
        while uploaded_file.state in (
            FileState.STATE_UNSPECIFIED,
            FileState.PROCESSING,
        ):
            if first:
                first = False
            else:
                LOGGER.debug(
                    "Waiting for file `%s` to be processed, current state: %s",
                    uploaded_file.name,
                    uploaded_file.state,
                )
                await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)

            uploaded_file = await client.aio.files.get(
                name=uploaded_file.name,
                config={"http_options": {"timeout": TIMEOUT_MILLIS}},
            )

        if uploaded_file.state == FileState.FAILED:
            # `error` may be unset; do not let an AttributeError mask the
            # processing failure we actually want to report.
            error = uploaded_file.error
            reason = error.message if error is not None else "unknown"
            raise HomeAssistantError(
                f"File `{uploaded_file.name}` processing failed, reason: {reason}"
            )

    prompt_parts = await hass.async_add_executor_job(upload_files)

    tasks = [
        asyncio.create_task(wait_for_file_processing(part))
        for part in prompt_parts
        if part.state != FileState.ACTIVE
    ]
    try:
        async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
            await asyncio.gather(*tasks)
    finally:
        # gather() propagates the first failure without stopping the other
        # awaitables; cancel whatever is still polling so no task is left
        # running in the background. Cancelling a finished task is a no-op.
        for task in tasks:
            task.cancel()

    return prompt_parts

tests/components/google_generative_ai_conversation/snapshots/test_init.ambr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@
122122
dict({
123123
'contents': list([
124124
'Describe this image from my doorbell camera',
125-
b'some file',
126-
b'some file',
125+
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
126+
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
127127
]),
128128
'model': 'models/gemini-2.5-flash',
129129
}),

tests/components/google_generative_ai_conversation/test_init.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ async def test_generate_content_service_with_image(
8080
) as mock_generate,
8181
patch(
8282
"google.genai.files.Files.upload",
83-
return_value=b"some file",
83+
side_effect=[
84+
File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE),
85+
File(name="context.txt", state=FileState.ACTIVE),
86+
],
8487
),
8588
patch("pathlib.Path.exists", return_value=True),
8689
patch.object(hass.config, "is_allowed_path", return_value=True),
@@ -92,7 +95,7 @@ async def test_generate_content_service_with_image(
9295
"generate_content",
9396
{
9497
"prompt": "Describe this image from my doorbell camera",
95-
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
98+
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
9699
},
97100
blocking=True,
98101
return_response=True,
@@ -146,7 +149,7 @@ async def test_generate_content_file_processing_succeeds(
146149
"generate_content",
147150
{
148151
"prompt": "Describe this image from my doorbell camera",
149-
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
152+
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
150153
},
151154
blocking=True,
152155
return_response=True,
@@ -208,7 +211,7 @@ async def test_generate_content_file_processing_fails(
208211
"generate_content",
209212
{
210213
"prompt": "Describe this image from my doorbell camera",
211-
"filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"],
214+
"filenames": ["doorbell_snapshot.jpg", "context.txt"],
212215
},
213216
blocking=True,
214217
return_response=True,

0 commit comments

Comments
 (0)