Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
"""Base class for assist satellite entities."""

from dataclasses import asdict
import logging
from pathlib import Path
from typing import Any

from hassil.util import (
PUNCTUATION_END,
PUNCTUATION_END_WORD,
PUNCTUATION_START,
PUNCTUATION_START_WORD,
)
import voluptuous as vol

from homeassistant.components.http import StaticPathConfig
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
Expand All @@ -23,6 +33,7 @@
)
from .entity import (
AssistSatelliteAnnouncement,
AssistSatelliteAnswer,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
Expand All @@ -34,6 +45,7 @@
__all__ = [
"DOMAIN",
"AssistSatelliteAnnouncement",
"AssistSatelliteAnswer",
"AssistSatelliteConfiguration",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
Expand Down Expand Up @@ -86,6 +98,62 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"async_internal_start_conversation",
[AssistSatelliteEntityFeature.START_CONVERSATION],
)

async def handle_ask_question(call: ServiceCall) -> dict[str, Any]:
"""Handle a Show View service call."""
satellite_entity_id: str = call.data[ATTR_ENTITY_ID]
satellite_entity: AssistSatelliteEntity | None = component.get_entity(
satellite_entity_id
)
if satellite_entity is None:
raise HomeAssistantError(
f"Invalid Assist satellite entity id: {satellite_entity_id}"
)

ask_question_args = {
"question": call.data.get("question"),
"question_media_id": call.data.get("question_media_id"),
"preannounce": call.data.get("preannounce", False),
"answers": call.data.get("answers"),
}

if preannounce_media_id := call.data.get("preannounce_media_id"):
ask_question_args["preannounce_media_id"] = preannounce_media_id

answer = await satellite_entity.async_internal_ask_question(**ask_question_args)

if answer is None:
raise HomeAssistantError("No answer from satellite")

return asdict(answer)

hass.services.async_register(
domain=DOMAIN,
service="ask_question",
service_func=handle_ask_question,
schema=vol.All(
{
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
vol.Optional("question"): str,
vol.Optional("question_media_id"): str,
vol.Optional("preannounce"): bool,
vol.Optional("preannounce_media_id"): str,
vol.Optional("answers"): [
{
vol.Required("id"): str,
vol.Required("sentences"): vol.All(
cv.ensure_list,
[cv.string],
has_one_non_empty_item,
has_no_punctuation,
),
}
],
},
cv.has_at_least_one_key("question", "question_media_id"),
),
supports_response=SupportsResponse.ONLY,
)
hass.data[CONNECTION_TEST_DATA] = {}
async_register_websocket_api(hass)
hass.http.register_view(ConnectionTestView())
Expand All @@ -110,3 +178,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)


def has_no_punctuation(value: list[str]) -> list[str]:
"""Validate result does not contain punctuation."""
for sentence in value:
if (
PUNCTUATION_START.search(sentence)
or PUNCTUATION_END.search(sentence)
or PUNCTUATION_START_WORD.search(sentence)
or PUNCTUATION_END_WORD.search(sentence)
):
raise vol.Invalid("sentence should not contain punctuation")

return value


def has_one_non_empty_item(value: list[str]) -> list[str]:
"""Validate result has at least one item."""
if len(value) < 1:
raise vol.Invalid("at least one sentence is required")

for sentence in value:
if not sentence:
raise vol.Invalid("sentences cannot be empty")

return value
161 changes: 160 additions & 1 deletion homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
import asyncio
from collections.abc import AsyncIterable
import contextlib
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import StrEnum
import logging
import time
from typing import Any, Literal, final

from hassil import Intents, recognize
from hassil.expression import Expression, ListReference, Sequence
from hassil.intents import WildcardSlotList

from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
Expand Down Expand Up @@ -105,6 +109,20 @@ class AssistSatelliteAnnouncement:
"""Media ID to be played before announcement."""


@dataclass
class AssistSatelliteAnswer:
"""Answer to a question."""

id: str | None
"""Matched answer id or None if no answer was matched."""

sentence: str
"""Raw sentence text from user response."""

slots: dict[str, Any] = field(default_factory=dict)
"""Matched slots from answer."""


class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""

Expand All @@ -120,8 +138,10 @@ class AssistSatelliteEntity(entity.Entity):
_is_announcing = False
_extra_system_prompt: str | None = None
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
_stt_intercept_future: asyncio.Future[str | None] | None = None
_attr_tts_options: dict[str, Any] | None = None
_pipeline_task: asyncio.Task | None = None
_ask_question_future: asyncio.Future[str | None] | None = None

__assist_satellite_state = AssistSatelliteState.IDLE

Expand Down Expand Up @@ -309,6 +329,112 @@ async def async_start_conversation(
"""Start a conversation from the satellite."""
raise NotImplementedError

async def async_internal_ask_question(
self,
question: str | None = None,
question_media_id: str | None = None,
preannounce: bool = True,
preannounce_media_id: str = PREANNOUNCE_URL,
answers: list[dict[str, Any]] | None = None,
) -> AssistSatelliteAnswer | None:
"""Ask a question and get a user's response from the satellite.

If question_media_id is not provided, question is synthesized to audio
with the selected pipeline.

If question_media_id is provided, it is played directly. It is possible
to omit the message and the satellite will not show any text.

If preannounce is True, a sound is played before the start message or media.
If preannounce_media_id is provided, it overrides the default sound.

Calls async_start_conversation.
"""
await self._cancel_running_pipeline()

if question is None:
question = ""

announcement = await self._resolve_announcement_media_id(
question,
question_media_id,
preannounce_media_id=preannounce_media_id if preannounce else None,
)

if self._is_announcing:
raise SatelliteBusyError

self._is_announcing = True
self._set_state(AssistSatelliteState.RESPONDING)
self._ask_question_future = asyncio.Future()

try:
# Wait for announcement to finish
await self.async_start_conversation(announcement)

# Wait for response text
response_text = await self._ask_question_future
if response_text is None:
raise HomeAssistantError("No answer from question")

if not answers:
return AssistSatelliteAnswer(id=None, sentence=response_text)

return self._question_response_to_answer(response_text, answers)
finally:
self._is_announcing = False
self._set_state(AssistSatelliteState.IDLE)
self._ask_question_future = None

def _question_response_to_answer(
self, response_text: str, answers: list[dict[str, Any]]
) -> AssistSatelliteAnswer:
"""Match text to a pre-defined set of answers."""

# Build intents and match
intents = Intents.from_dict(
{
"language": self.hass.config.language,
"intents": {
"QuestionIntent": {
"data": [
{
"sentences": answer["sentences"],
"metadata": {"answer_id": answer["id"]},
}
for answer in answers
]
}
},
}
)

# Assume slot list references are wildcards
wildcard_names: set[str] = set()
for intent in intents.intents.values():
for intent_data in intent.data:
for sentence in intent_data.sentences:
_collect_list_references(sentence, wildcard_names)

for wildcard_name in wildcard_names:
intents.slot_lists[wildcard_name] = WildcardSlotList(wildcard_name)

# Match response text
result = recognize(response_text, intents)
if result is None:
# No match
return AssistSatelliteAnswer(id=None, sentence=response_text)

assert result.intent_metadata
return AssistSatelliteAnswer(
id=result.intent_metadata["answer_id"],
sentence=response_text,
slots={
entity_name: entity.value
for entity_name, entity in result.entities.items()
},
)

async def async_accept_pipeline_from_satellite(
self,
audio_stream: AsyncIterable[bytes],
Expand Down Expand Up @@ -351,6 +477,11 @@ async def async_accept_pipeline_from_satellite(
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
return

if (self._ask_question_future is not None) and (
start_stage == PipelineStage.STT
):
end_stage = PipelineStage.STT

device_id = self.registry_entry.device_id if self.registry_entry else None

# Refresh context if necessary
Expand Down Expand Up @@ -433,6 +564,16 @@ def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
self._set_state(AssistSatelliteState.IDLE)
elif event.type is PipelineEventType.STT_START:
self._set_state(AssistSatelliteState.LISTENING)
elif event.type is PipelineEventType.STT_END:
# Intercepting text for ask question
if (
(self._ask_question_future is not None)
and (not self._ask_question_future.done())
and event.data
):
self._ask_question_future.set_result(
event.data.get("stt_output", {}).get("text")
)
elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type is PipelineEventType.TTS_START:
Expand All @@ -443,6 +584,12 @@ def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
if not self._run_has_tts:
self._set_state(AssistSatelliteState.IDLE)

if (self._ask_question_future is not None) and (
not self._ask_question_future.done()
):
# No text for ask question
self._ask_question_future.set_result(None)

self.on_pipeline_event(event)

@callback
Expand Down Expand Up @@ -577,3 +724,15 @@ async def _resolve_announcement_media_id(
media_id_source=media_id_source,
preannounce_media_id=preannounce_media_id,
)


def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
"""Collect list reference names recursively."""
if isinstance(expression, Sequence):
seq: Sequence = expression
for item in seq.items:
_collect_list_references(item, list_names)
elif isinstance(expression, ListReference):
# {list}
list_ref: ListReference = expression
list_names.add(list_ref.slot_name)
3 changes: 3 additions & 0 deletions homeassistant/components/assist_satellite/icons.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
},
"start_conversation": {
"service": "mdi:forum"
},
"ask_question": {
"service": "mdi:microphone-question"
}
}
}
3 changes: 2 additions & 1 deletion homeassistant/components/assist_satellite/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity",
"quality_scale": "internal"
"quality_scale": "internal",
"requirements": ["hassil==2.2.3"]
}
Loading
Loading