Skip to content

Commit 87e7fe6

Browse files
authored
Add custom (external) wake words (home-assistant#152919)
1 parent c782489 commit 87e7fe6

File tree

11 files changed

+238
-3
lines changed

11 files changed

+238
-3
lines changed

homeassistant/components/esphome/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from homeassistant.helpers.issue_registry import async_delete_issue
1818
from homeassistant.helpers.typing import ConfigType
1919

20-
from . import dashboard, ffmpeg_proxy
20+
from . import assist_satellite, dashboard, ffmpeg_proxy
2121
from .const import CONF_BLUETOOTH_MAC_ADDRESS, CONF_NOISE_PSK, DOMAIN
2222
from .domain_data import DomainData
2323
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
@@ -31,6 +31,7 @@
3131
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
3232
"""Set up the esphome component."""
3333
ffmpeg_proxy.async_setup(hass)
34+
await assist_satellite.async_setup(hass)
3435
await dashboard.async_setup(hass)
3536
return True
3637

homeassistant/components/esphome/assist_satellite.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
import asyncio
66
from collections.abc import AsyncIterable
77
from functools import partial
8+
import hashlib
89
import io
910
from itertools import chain
11+
import json
1012
import logging
13+
from pathlib import Path
1114
import socket
1215
from typing import Any, cast
1316
import wave
@@ -19,16 +22,20 @@
1922
VoiceAssistantAudioSettings,
2023
VoiceAssistantCommandFlag,
2124
VoiceAssistantEventType,
25+
VoiceAssistantExternalWakeWord,
2226
VoiceAssistantFeature,
2327
VoiceAssistantTimerEventType,
2428
)
29+
import voluptuous as vol
30+
from voluptuous.humanize import humanize_error
2531

2632
from homeassistant.components import assist_satellite, tts
2733
from homeassistant.components.assist_pipeline import (
2834
PipelineEvent,
2935
PipelineEventType,
3036
PipelineStage,
3137
)
38+
from homeassistant.components.http import StaticPathConfig
3239
from homeassistant.components.intent import (
3340
TimerEventType,
3441
TimerInfo,
@@ -39,8 +46,11 @@
3946
from homeassistant.core import HomeAssistant, callback
4047
from homeassistant.helpers import entity_registry as er
4148
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
49+
from homeassistant.helpers.network import get_url
50+
from homeassistant.helpers.singleton import singleton
51+
from homeassistant.util.hass_dict import HassKey
4252

43-
from .const import DOMAIN
53+
from .const import DOMAIN, WAKE_WORDS_API_PATH, WAKE_WORDS_DIR_NAME
4454
from .entity import EsphomeAssistEntity, convert_api_error_ha_error
4555
from .entry_data import ESPHomeConfigEntry
4656
from .enum_mapper import EsphomeEnumMapper
@@ -84,6 +94,16 @@
8494

8595
_ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60 # 5 minutes
8696
_CONFIG_TIMEOUT_SEC = 5
97+
_WAKE_WORD_CONFIG_SCHEMA = vol.Schema(
98+
{
99+
vol.Required("type"): str,
100+
vol.Required("wake_word"): str,
101+
},
102+
extra=vol.ALLOW_EXTRA,
103+
)
104+
_DATA_WAKE_WORDS: HassKey[dict[str, VoiceAssistantExternalWakeWord]] = HassKey(
105+
"wake_word_cache"
106+
)
87107

88108

89109
async def async_setup_entry(
@@ -182,9 +202,14 @@ async def async_set_configuration(
182202

183203
async def _update_satellite_config(self) -> None:
184204
"""Get the latest satellite configuration from the device."""
205+
wake_words = await async_get_custom_wake_words(self.hass)
206+
if wake_words:
207+
_LOGGER.debug("Found custom wake words: %s", sorted(wake_words.keys()))
208+
185209
try:
186210
config = await self.cli.get_voice_assistant_configuration(
187-
_CONFIG_TIMEOUT_SEC
211+
_CONFIG_TIMEOUT_SEC,
212+
external_wake_words=list(wake_words.values()),
188213
)
189214
except TimeoutError:
190215
# Placeholder config will be used
@@ -784,3 +809,78 @@ def send_audio_bytes(self, data: bytes) -> None:
784809
return
785810

786811
self.transport.sendto(data, self.remote_addr)
812+
813+
814+
async def async_get_custom_wake_words(
815+
hass: HomeAssistant,
816+
) -> dict[str, VoiceAssistantExternalWakeWord]:
817+
"""Get available custom wake words."""
818+
return await hass.async_add_executor_job(_get_custom_wake_words, hass)
819+
820+
821+
@singleton(_DATA_WAKE_WORDS)
822+
def _get_custom_wake_words(
823+
hass: HomeAssistant,
824+
) -> dict[str, VoiceAssistantExternalWakeWord]:
825+
"""Get available custom wake words (singleton)."""
826+
wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
827+
wake_words: dict[str, VoiceAssistantExternalWakeWord] = {}
828+
829+
# Look for config/model files
830+
for config_path in wake_words_dir.glob("*.json"):
831+
wake_word_id = config_path.stem
832+
model_path = config_path.with_suffix(".tflite")
833+
if not model_path.exists():
834+
# Missing model file
835+
continue
836+
837+
with open(config_path, encoding="utf-8") as config_file:
838+
config_dict = json.load(config_file)
839+
try:
840+
config = _WAKE_WORD_CONFIG_SCHEMA(config_dict)
841+
except vol.Invalid as err:
842+
# Invalid config
843+
_LOGGER.debug(
844+
"Invalid wake word config: path=%s, error=%s",
845+
config_path,
846+
humanize_error(config_dict, err),
847+
)
848+
continue
849+
850+
with open(model_path, "rb") as model_file:
851+
model_hash = hashlib.sha256(model_file.read()).hexdigest()
852+
853+
model_size = model_path.stat().st_size
854+
config_rel_path = config_path.relative_to(wake_words_dir)
855+
856+
# Only intended for the internal network
857+
base_url = get_url(hass, prefer_external=False, allow_cloud=False)
858+
859+
wake_words[wake_word_id] = VoiceAssistantExternalWakeWord.from_dict(
860+
{
861+
"id": wake_word_id,
862+
"wake_word": config["wake_word"],
863+
"trained_languages": config_dict.get("trained_languages", []),
864+
"model_type": config["type"],
865+
"model_size": model_size,
866+
"model_hash": model_hash,
867+
"url": f"{base_url}{WAKE_WORDS_API_PATH}/{config_rel_path}",
868+
}
869+
)
870+
871+
return wake_words
872+
873+
874+
async def async_setup(hass: HomeAssistant) -> None:
875+
"""Set up the satellite."""
876+
wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
877+
878+
# Satellites will pull model files over HTTP
879+
await hass.http.async_register_static_paths(
880+
[
881+
StaticPathConfig(
882+
url_path=WAKE_WORDS_API_PATH,
883+
path=str(wake_words_dir),
884+
)
885+
]
886+
)

homeassistant/components/esphome/const.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@
2727
DEFAULT_URL = f"https://esphome.io/changelog/{STABLE_BLE_URL_VERSION}.html"
2828

2929
NO_WAKE_WORD: Final[str] = "no_wake_word"
30+
31+
WAKE_WORDS_DIR_NAME = "custom_wake_words"
32+
WAKE_WORDS_API_PATH = "/api/esphome/wake_words"

tests/components/esphome/test_assist_satellite.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
from dataclasses import replace
5+
from http import HTTPStatus
56
import io
67
import socket
78
from unittest.mock import ANY, AsyncMock, Mock, patch
@@ -55,6 +56,7 @@
5556
from .conftest import MockESPHomeDeviceType
5657

5758
from tests.components.tts.common import MockResultStream
59+
from tests.typing import ClientSessionGenerator
5860

5961

6062
@pytest.fixture
@@ -2087,3 +2089,80 @@ async def get_pipeline(wake_word_phrase):
20872089

20882090
# Primary pipeline should be restored after
20892091
assert (await get_pipeline(None)) == "Primary Pipeline"
2092+
2093+
2094+
async def test_custom_wake_words(
2095+
hass: HomeAssistant,
2096+
mock_client: APIClient,
2097+
mock_esphome_device: MockESPHomeDeviceType,
2098+
hass_client: ClientSessionGenerator,
2099+
) -> None:
2100+
"""Test exposing custom wake word models.
2101+
2102+
Expects 2 models in testing_config/custom_wake_words:
2103+
- hey_home_assistant
2104+
- choo_choo_homie
2105+
"""
2106+
http_client = await hass_client()
2107+
expected_config = AssistSatelliteConfiguration(
2108+
available_wake_words=[
2109+
AssistSatelliteWakeWord("1234", "okay nabu", ["en"]),
2110+
],
2111+
active_wake_words=["1234"],
2112+
max_active_wake_words=1,
2113+
)
2114+
gvac = mock_client.get_voice_assistant_configuration
2115+
gvac.return_value = expected_config
2116+
2117+
mock_device = await mock_esphome_device(
2118+
mock_client=mock_client,
2119+
device_info={
2120+
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
2121+
| VoiceAssistantFeature.ANNOUNCE
2122+
},
2123+
)
2124+
await hass.async_block_till_done()
2125+
2126+
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
2127+
assert satellite is not None
2128+
2129+
# Models should be present in testing_config/custom_wake_words
2130+
gvac.assert_called_once()
2131+
external_wake_words = gvac.call_args_list[0].kwargs["external_wake_words"]
2132+
assert len(external_wake_words) == 2
2133+
2134+
assert {external_wake_words[0].id, external_wake_words[1].id} == {
2135+
"hey_home_assistant",
2136+
"choo_choo_homie",
2137+
}
2138+
2139+
# Verify details
2140+
for eww in external_wake_words:
2141+
if eww.id == "hey_home_assistant":
2142+
assert eww.wake_word == "Hey Home Assistant"
2143+
else:
2144+
assert eww.wake_word == "Choo Choo Homie"
2145+
2146+
assert eww.model_type == "micro"
2147+
assert eww.model_size == 4 # tflite files contain "test"
2148+
assert (
2149+
eww.model_hash
2150+
== "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
2151+
)
2152+
assert eww.trained_languages == ["en"]
2153+
2154+
# GET config
2155+
config_url = eww.url[eww.url.find("/api") :]
2156+
req = await http_client.get(config_url)
2157+
assert req.status == HTTPStatus.OK
2158+
config_dict = await req.json()
2159+
2160+
# GET model
2161+
model = config_dict["model"]
2162+
model_url = config_url[: config_url.rfind("/")] + f"/{model}"
2163+
req = await http_client.get(model_url)
2164+
assert req.status == HTTPStatus.OK
2165+
2166+
# Check non-existent wake word
2167+
req = await http_client.get("/api/esphome/wake_words/wrong_wake_word.json")
2168+
assert req.status == HTTPStatus.NOT_FOUND
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
test
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"type": "micro",
3+
"wake_word": "Choo Choo Homie",
4+
"author": "Michael Hansen",
5+
"website": "https://www.home-assistant.io",
6+
"model": "choo_choo_homie.tflite",
7+
"trained_languages": ["en"],
8+
"version": 2,
9+
"micro": {
10+
"probability_cutoff": 0.97,
11+
"feature_step_size": 10,
12+
"sliding_window_size": 5,
13+
"tensor_arena_size": 30000,
14+
"minimum_esphome_version": "2024.7.0"
15+
}
16+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
test
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"type": "micro",
3+
"wake_word": "Hey Home Assistant",
4+
"author": "Michael Hansen",
5+
"website": "https://www.home-assistant.io",
6+
"model": "hey_home_assistant.tflite",
7+
"trained_languages": ["en"],
8+
"version": 2,
9+
"micro": {
10+
"probability_cutoff": 0.97,
11+
"feature_step_size": 10,
12+
"sliding_window_size": 5,
13+
"tensor_arena_size": 30000,
14+
"minimum_esphome_version": "2024.7.0"
15+
}
16+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
test

0 commit comments

Comments
 (0)