Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions homeassistant/components/devolo_home_network/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async def async_update_last_restart(self) -> int:


class DevoloWifiConnectedStationsGetCoordinator(
DevoloDataUpdateCoordinator[list[ConnectedStationInfo]]
DevoloDataUpdateCoordinator[dict[str, ConnectedStationInfo]]
):
"""Class to manage fetching data from the WifiGuestAccessGet endpoint."""

Expand All @@ -230,10 +230,11 @@ def __init__(
)
self.update_method = self.async_get_wifi_connected_station

async def async_get_wifi_connected_station(self) -> list[ConnectedStationInfo]:
async def async_get_wifi_connected_station(self) -> dict[str, ConnectedStationInfo]:
"""Fetch data from API endpoint."""
assert self.device.device
return await self.device.device.async_get_wifi_connected_station()
clients = await self.device.device.async_get_wifi_connected_station()
return {client.mac_address: client for client in clients}


class DevoloWifiGuestAccessGetCoordinator(
Expand Down
35 changes: 13 additions & 22 deletions homeassistant/components/devolo_home_network/device_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,26 @@ async def async_setup_entry(
) -> None:
"""Get all devices and sensors and setup them via config entry."""
device = entry.runtime_data.device
coordinators: dict[str, DevoloDataUpdateCoordinator[list[ConnectedStationInfo]]] = (
entry.runtime_data.coordinators
)
coordinators: dict[
str, DevoloDataUpdateCoordinator[dict[str, ConnectedStationInfo]]
] = entry.runtime_data.coordinators
registry = er.async_get(hass)
tracked = set()

@callback
def new_device_callback() -> None:
"""Add new devices if needed."""
new_entities = []
for station in coordinators[CONNECTED_WIFI_CLIENTS].data:
if station.mac_address in tracked:
for mac_address in coordinators[CONNECTED_WIFI_CLIENTS].data:
if mac_address in tracked:
continue

new_entities.append(
DevoloScannerEntity(
coordinators[CONNECTED_WIFI_CLIENTS], device, station.mac_address
coordinators[CONNECTED_WIFI_CLIENTS], device, mac_address
)
)
tracked.add(station.mac_address)
tracked.add(mac_address)
async_add_entities(new_entities)

@callback
Expand Down Expand Up @@ -82,7 +82,7 @@ def restore_entities() -> None:

# The pylint disable is needed because of https://github.com/pylint-dev/pylint/issues/9138
class DevoloScannerEntity( # pylint: disable=hass-enforce-class-module
CoordinatorEntity[DevoloDataUpdateCoordinator[list[ConnectedStationInfo]]],
CoordinatorEntity[DevoloDataUpdateCoordinator[dict[str, ConnectedStationInfo]]],
ScannerEntity,
):
"""Representation of a devolo device tracker."""
Expand All @@ -92,7 +92,7 @@ class DevoloScannerEntity( # pylint: disable=hass-enforce-class-module

def __init__(
self,
coordinator: DevoloDataUpdateCoordinator[list[ConnectedStationInfo]],
coordinator: DevoloDataUpdateCoordinator[dict[str, ConnectedStationInfo]],
device: Device,
mac: str,
) -> None:
Expand All @@ -109,14 +109,8 @@ def extra_state_attributes(self) -> dict[str, str]:
if not self.coordinator.data:
return {}

station = next(
(
station
for station in self.coordinator.data
if station.mac_address == self.mac_address
),
None,
)
assert self.mac_address
station = self.coordinator.data.get(self.mac_address)
if station:
attrs["wifi"] = WIFI_APTYPE.get(station.vap_type, STATE_UNKNOWN)
attrs["band"] = (
Expand All @@ -129,11 +123,8 @@ def extra_state_attributes(self) -> dict[str, str]:
@property
def is_connected(self) -> bool:
"""Return true if the device is connected to the network."""
return any(
station
for station in self.coordinator.data
if station.mac_address == self.mac_address
)
assert self.mac_address
return self.coordinator.data.get(self.mac_address) is not None

@property
def unique_id(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/devolo_home_network/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
type _DataType = (
LogicalNetwork
| DataRate
| list[ConnectedStationInfo]
| dict[str, ConnectedStationInfo]
| list[NeighborAPInfo]
| WifiGuestAccessGet
| bool
Expand Down
8 changes: 6 additions & 2 deletions homeassistant/components/devolo_home_network/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def _last_restart(runtime: int) -> datetime:


type _CoordinatorDataType = (
LogicalNetwork | DataRate | list[ConnectedStationInfo] | list[NeighborAPInfo] | int
LogicalNetwork
| DataRate
| dict[str, ConnectedStationInfo]
| list[NeighborAPInfo]
| int
)
type _SensorDataType = int | float | datetime

Expand Down Expand Up @@ -79,7 +83,7 @@ class DevoloSensorEntityDescription[
),
),
CONNECTED_WIFI_CLIENTS: DevoloSensorEntityDescription[
list[ConnectedStationInfo], int
dict[str, ConnectedStationInfo], int
](
key=CONNECTED_WIFI_CLIENTS,
state_class=SensorStateClass.MEASUREMENT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from __future__ import annotations

from json import JSONDecodeError

from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads

from .const import LOGGER
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
Expand Down Expand Up @@ -42,7 +45,7 @@ async def _async_generate_data(
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log)
await self._async_handle_chat_log(chat_log, task.structure)

if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(
Expand All @@ -51,7 +54,25 @@ async def _async_generate_data(
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE)

text = chat_log.content[-1].content or ""

if not task.structure:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=text,
)

try:
data = json_loads(text)
except JSONDecodeError as err:
LOGGER.error(
"Failed to parse JSON response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE) from err

return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=chat_log.content[-1].content or "",
data=data,
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Schema,
Tool,
)
import voluptuous as vol
from voluptuous_openapi import convert

from homeassistant.components import conversation
Expand Down Expand Up @@ -324,6 +325,7 @@ def __init__(
async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
structure: vol.Schema | None = None,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
Expand Down Expand Up @@ -402,6 +404,18 @@ async def _async_handle_chat_log(
generateContentConfig.automatic_function_calling = (
AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None)
)
if structure:
generateContentConfig.response_mime_type = "application/json"
generateContentConfig.response_schema = _format_schema(
convert(
structure,
custom_serializer=(
chat_log.llm_api.custom_serializer
if chat_log.llm_api
else llm.selector_serializer
),
)
)

if not supports_system_instruction:
messages = [
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/miele/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ class PlatePowerStep(MieleEnum):
plate_step_11 = 11
plate_step_12 = 12
plate_step_13 = 13
plate_step_14 = 4
plate_step_14 = 14
plate_step_15 = 15
plate_step_16 = 16
plate_step_17 = 17
Expand Down
13 changes: 8 additions & 5 deletions homeassistant/components/squeezebox/browse_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,16 @@ def _get_item_thumbnail(
) -> str | None:
"""Construct path to thumbnail image."""
item_thumbnail: str | None = None
if artwork_track_id := item.get("artwork_track_id"):
track_id = item.get("artwork_track_id") or (
item.get("id") if item_type == "track" else None
)

if track_id:
if internal_request:
item_thumbnail = player.generate_image_url_from_track_id(artwork_track_id)
item_thumbnail = player.generate_image_url_from_track_id(track_id)
elif item_type is not None:
item_thumbnail = entity.get_browse_image_url(
item_type, item["id"], artwork_track_id
item_type, item["id"], track_id
)

elif search_type in ["apps", "radios"]:
Expand Down Expand Up @@ -311,8 +315,7 @@ async def build_item_response(
title=item["title"],
media_content_type=item_type,
media_class=CONTENT_TYPE_MEDIA_CLASS[item_type]["item"],
can_expand=CONTENT_TYPE_MEDIA_CLASS[item_type]["children"]
is not None,
can_expand=bool(CONTENT_TYPE_MEDIA_CLASS[item_type]["children"]),
can_play=True,
)

Expand Down
10 changes: 5 additions & 5 deletions homeassistant/helpers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class NamespacedTool(Tool):
def __init__(self, namespace: str, tool: Tool) -> None:
"""Init the class."""
self.namespace = namespace
self.name = f"{namespace}.{tool.name}"
self.name = f"{namespace}__{tool.name}"
self.description = tool.description
self.parameters = tool.parameters
self.tool = tool
Expand Down Expand Up @@ -458,7 +458,7 @@ async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
llm_context=llm_context,
tools=self._async_get_tools(llm_context, exposed_entities),
custom_serializer=_selector_serializer,
custom_serializer=selector_serializer,
)

@callback
Expand Down Expand Up @@ -701,7 +701,7 @@ def _get_exposed_entities(
return data


def _selector_serializer(schema: Any) -> Any: # noqa: C901
def selector_serializer(schema: Any) -> Any: # noqa: C901
"""Convert selectors into OpenAPI schema."""
if not isinstance(schema, selector.Selector):
return UNSUPPORTED
Expand Down Expand Up @@ -782,7 +782,7 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901
result["properties"] = {
field: convert(
selector.selector(field_schema["selector"]),
custom_serializer=_selector_serializer,
custom_serializer=selector_serializer,
)
for field, field_schema in fields.items()
}
Expand Down Expand Up @@ -915,7 +915,7 @@ def __init__(
"""Init the class."""
self._domain = domain
self._action = action
self.name = f"{domain}.{action}"
self.name = f"{domain}__{action}"
# Note: _get_cached_action_parameters only works for services which
# add their description directly to the service description cache.
# This is not the case for most services, but it is for scripts.
Expand Down
23 changes: 15 additions & 8 deletions tests/components/google_generative_ai_conversation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,26 @@ async def setup_ha(hass: HomeAssistant) -> None:


@pytest.fixture
def mock_send_message_stream() -> Generator[AsyncMock]:
def mock_chat_create() -> Generator[AsyncMock]:
"""Mock stream response."""

async def mock_generator(stream):
for value in stream:
yield value

mock_send_message_stream = AsyncMock()
mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
mock_send_message_stream.return_value.pop(0)
)

with patch(
"google.genai.chats.AsyncChat.send_message_stream",
AsyncMock(),
) as mock_send_message_stream:
mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
mock_send_message_stream.return_value.pop(0)
)
"google.genai.chats.AsyncChats.create",
return_value=AsyncMock(send_message_stream=mock_send_message_stream),
) as mock_create:
yield mock_create

yield mock_send_message_stream

@pytest.fixture
def mock_send_message_stream(mock_chat_create) -> Generator[AsyncMock]:
"""Mock stream response."""
return mock_chat_create.return_value.send_message_stream
Loading
Loading