Skip to content

Commit 3dc8676

Browse files
authored
Add TTS streaming to Wyoming satellites (home-assistant#147438)
* Add TTS streaming using intent-progress
* Handle incomplete header
1 parent 0f112bb commit 3dc8676

File tree

2 files changed

+286
-31
lines changed

2 files changed

+286
-31
lines changed

homeassistant/components/wyoming/assist_satellite.py

Lines changed: 105 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ def __init__(
132132
# Used to ensure TTS timeout is acted on correctly.
133133
self._run_loop_id: str | None = None
134134

135+
# TTS streaming
136+
self._tts_stream_token: str | None = None
137+
self._is_tts_streaming: bool = False
138+
135139
@property
136140
def pipeline_entity_id(self) -> str | None:
137141
"""Return the entity ID of the pipeline to use for the next conversation."""
@@ -179,11 +183,20 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
179183
"""Set state based on pipeline stage."""
180184
assert self._client is not None
181185

182-
if event.type == assist_pipeline.PipelineEventType.RUN_END:
186+
if event.type == assist_pipeline.PipelineEventType.RUN_START:
187+
if event.data and (tts_output := event.data["tts_output"]):
188+
# Get stream token early.
189+
# If "tts_start_streaming" is True in INTENT_PROGRESS event, we
190+
# can start streaming TTS before the TTS_END event.
191+
self._tts_stream_token = tts_output["token"]
192+
self._is_tts_streaming = False
193+
elif event.type == assist_pipeline.PipelineEventType.RUN_END:
183194
# Pipeline run is complete
184195
self._is_pipeline_running = False
185196
self._pipeline_ended_event.set()
186197
self.device.set_is_active(False)
198+
self._tts_stream_token = None
199+
self._is_tts_streaming = False
187200
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
188201
self.config_entry.async_create_background_task(
189202
self.hass,
@@ -245,6 +258,20 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
245258
self._client.write_event(Transcript(text=stt_text).event()),
246259
f"{self.entity_id} {event.type}",
247260
)
261+
elif event.type == assist_pipeline.PipelineEventType.INTENT_PROGRESS:
262+
if (
263+
event.data
264+
and event.data.get("tts_start_streaming")
265+
and self._tts_stream_token
266+
and (stream := tts.async_get_stream(self.hass, self._tts_stream_token))
267+
):
268+
# Start streaming TTS early (before TTS_END).
269+
self._is_tts_streaming = True
270+
self.config_entry.async_create_background_task(
271+
self.hass,
272+
self._stream_tts(stream),
273+
f"{self.entity_id} {event.type}",
274+
)
248275
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
249276
# Text-to-speech text
250277
if event.data:
@@ -267,8 +294,10 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
267294
if (
268295
event.data
269296
and (tts_output := event.data["tts_output"])
297+
and not self._is_tts_streaming
270298
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
271299
):
300+
# Send TTS only if we haven't already started streaming it in INTENT_PROGRESS.
272301
self.config_entry.async_create_background_task(
273302
self.hass,
274303
self._stream_tts(stream),
@@ -711,39 +740,62 @@ async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
711740
start_time = time.monotonic()
712741

713742
try:
714-
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
715-
716-
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
717-
sample_rate = wav_file.getframerate()
718-
sample_width = wav_file.getsampwidth()
719-
sample_channels = wav_file.getnchannels()
720-
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
721-
722-
timestamp = 0
723-
await self._client.write_event(
724-
AudioStart(
725-
rate=sample_rate,
726-
width=sample_width,
727-
channels=sample_channels,
728-
timestamp=timestamp,
729-
).event()
743+
header_data = b""
744+
header_complete = False
745+
sample_rate: int | None = None
746+
sample_width: int | None = None
747+
sample_channels: int | None = None
748+
timestamp = 0
749+
750+
async for data_chunk in tts_result.async_stream_result():
751+
if not header_complete:
752+
# Accumulate data until we can parse the header and get
753+
# sample rate, etc.
754+
header_data += data_chunk
755+
# Most WAVE headers are 44 bytes in length
756+
if (len(header_data) >= 44) and (
757+
audio_info := _try_parse_wav_header(header_data)
758+
):
759+
# Overwrite chunk with audio after header
760+
sample_rate, sample_width, sample_channels, data_chunk = (
761+
audio_info
762+
)
763+
await self._client.write_event(
764+
AudioStart(
765+
rate=sample_rate,
766+
width=sample_width,
767+
channels=sample_channels,
768+
timestamp=timestamp,
769+
).event()
770+
)
771+
header_complete = True
772+
773+
if not data_chunk:
774+
# No audio after header
775+
continue
776+
else:
777+
# Header is incomplete
778+
continue
779+
780+
# Streaming audio
781+
assert sample_rate is not None
782+
assert sample_width is not None
783+
assert sample_channels is not None
784+
785+
audio_chunk = AudioChunk(
786+
rate=sample_rate,
787+
width=sample_width,
788+
channels=sample_channels,
789+
audio=data_chunk,
790+
timestamp=timestamp,
730791
)
731792

732-
# Stream audio chunks
733-
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
734-
chunk = AudioChunk(
735-
rate=sample_rate,
736-
width=sample_width,
737-
channels=sample_channels,
738-
audio=audio_bytes,
739-
timestamp=timestamp,
740-
)
741-
await self._client.write_event(chunk.event())
742-
timestamp += chunk.milliseconds
743-
total_seconds += chunk.seconds
793+
await self._client.write_event(audio_chunk.event())
794+
timestamp += audio_chunk.milliseconds
795+
total_seconds += audio_chunk.seconds
744796

745-
await self._client.write_event(AudioStop(timestamp=timestamp).event())
746-
_LOGGER.debug("TTS streaming complete")
797+
await self._client.write_event(AudioStop(timestamp=timestamp).event())
798+
_LOGGER.debug("TTS streaming complete")
747799
finally:
748800
send_duration = time.monotonic() - start_time
749801
timeout_seconds = max(0, total_seconds - send_duration + _TTS_TIMEOUT_EXTRA)
@@ -812,3 +864,25 @@ def _handle_timer(
812864
self.config_entry.async_create_background_task(
813865
self.hass, self._client.write_event(event), "wyoming timer event"
814866
)
867+
868+
869+
def _try_parse_wav_header(header_data: bytes) -> tuple[int, int, int, bytes] | None:
870+
"""Try to parse a WAV header from a buffer.
871+
872+
If successful, return (rate, width, channels, audio).
873+
"""
874+
try:
875+
with io.BytesIO(header_data) as wav_io:
876+
wav_file: wave.Wave_read = wave.open(wav_io, "rb")
877+
with wav_file:
878+
return (
879+
wav_file.getframerate(),
880+
wav_file.getsampwidth(),
881+
wav_file.getnchannels(),
882+
wav_file.readframes(wav_file.getnframes()),
883+
)
884+
except wave.Error:
885+
# Ignore errors and return None
886+
pass
887+
888+
return None

tests/components/wyoming/test_satellite.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,3 +1472,184 @@ def tts_response_finished(self):
14721472
# Stop the satellite
14731473
await hass.config_entries.async_unload(entry.entry_id)
14741474
await hass.async_block_till_done()
1475+
1476+
1477+
async def test_satellite_tts_streaming(hass: HomeAssistant) -> None:
1478+
"""Test running a streaming TTS pipeline with a satellite."""
1479+
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
1480+
1481+
events = [
1482+
RunPipeline(start_stage=PipelineStage.ASR, end_stage=PipelineStage.TTS).event(),
1483+
]
1484+
1485+
pipeline_kwargs: dict[str, Any] = {}
1486+
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
1487+
None
1488+
)
1489+
run_pipeline_called = asyncio.Event()
1490+
audio_chunk_received = asyncio.Event()
1491+
1492+
async def async_pipeline_from_audio_stream(
1493+
hass: HomeAssistant,
1494+
context,
1495+
event_callback,
1496+
stt_metadata,
1497+
stt_stream,
1498+
**kwargs,
1499+
) -> None:
1500+
nonlocal pipeline_kwargs, pipeline_event_callback
1501+
pipeline_kwargs = kwargs
1502+
pipeline_event_callback = event_callback
1503+
1504+
run_pipeline_called.set()
1505+
async for chunk in stt_stream:
1506+
if chunk:
1507+
audio_chunk_received.set()
1508+
break
1509+
1510+
with (
1511+
patch(
1512+
"homeassistant.components.wyoming.data.load_wyoming_info",
1513+
return_value=SATELLITE_INFO,
1514+
),
1515+
patch(
1516+
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
1517+
SatelliteAsyncTcpClient(events),
1518+
) as mock_client,
1519+
patch(
1520+
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
1521+
async_pipeline_from_audio_stream,
1522+
),
1523+
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
1524+
):
1525+
entry = await setup_config_entry(hass)
1526+
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
1527+
assert device is not None
1528+
1529+
async with asyncio.timeout(1):
1530+
await mock_client.connect_event.wait()
1531+
await mock_client.run_satellite_event.wait()
1532+
1533+
async with asyncio.timeout(1):
1534+
await run_pipeline_called.wait()
1535+
1536+
assert pipeline_event_callback is not None
1537+
assert pipeline_kwargs.get("device_id") == device.device_id
1538+
1539+
# Send TTS info early
1540+
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
1541+
pipeline_event_callback(
1542+
assist_pipeline.PipelineEvent(
1543+
assist_pipeline.PipelineEventType.RUN_START,
1544+
{"tts_output": {"token": mock_tts_result_stream.token}},
1545+
)
1546+
)
1547+
1548+
# Speech-to-text started
1549+
pipeline_event_callback(
1550+
assist_pipeline.PipelineEvent(
1551+
assist_pipeline.PipelineEventType.STT_START,
1552+
{"metadata": {"language": "en"}},
1553+
)
1554+
)
1555+
async with asyncio.timeout(1):
1556+
await mock_client.transcribe_event.wait()
1557+
1558+
# Push in some audio
1559+
mock_client.inject_event(
1560+
AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
1561+
)
1562+
1563+
# User started speaking
1564+
pipeline_event_callback(
1565+
assist_pipeline.PipelineEvent(
1566+
assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
1567+
)
1568+
)
1569+
async with asyncio.timeout(1):
1570+
await mock_client.voice_started_event.wait()
1571+
1572+
# User stopped speaking
1573+
pipeline_event_callback(
1574+
assist_pipeline.PipelineEvent(
1575+
assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
1576+
)
1577+
)
1578+
async with asyncio.timeout(1):
1579+
await mock_client.voice_stopped_event.wait()
1580+
1581+
# Speech-to-text transcription
1582+
pipeline_event_callback(
1583+
assist_pipeline.PipelineEvent(
1584+
assist_pipeline.PipelineEventType.STT_END,
1585+
{"stt_output": {"text": "test transcript"}},
1586+
)
1587+
)
1588+
async with asyncio.timeout(1):
1589+
await mock_client.transcript_event.wait()
1590+
1591+
# Intent progress starts TTS streaming early with info received in the
1592+
# run-start event.
1593+
pipeline_event_callback(
1594+
assist_pipeline.PipelineEvent(
1595+
assist_pipeline.PipelineEventType.INTENT_PROGRESS,
1596+
{"tts_start_streaming": True},
1597+
)
1598+
)
1599+
1600+
# TTS events are sent now. In practice, these would be streamed as text
1601+
# chunks are generated.
1602+
async with asyncio.timeout(1):
1603+
await mock_client.tts_audio_start_event.wait()
1604+
await mock_client.tts_audio_chunk_event.wait()
1605+
await mock_client.tts_audio_stop_event.wait()
1606+
1607+
# Verify audio chunk from test WAV
1608+
assert mock_client.tts_audio_chunk is not None
1609+
assert mock_client.tts_audio_chunk.rate == 22050
1610+
assert mock_client.tts_audio_chunk.width == 2
1611+
assert mock_client.tts_audio_chunk.channels == 1
1612+
assert mock_client.tts_audio_chunk.audio == b"1234"
1613+
1614+
# Text-to-speech text
1615+
pipeline_event_callback(
1616+
assist_pipeline.PipelineEvent(
1617+
assist_pipeline.PipelineEventType.TTS_START,
1618+
{
1619+
"tts_input": "test text to speak",
1620+
"voice": "test voice",
1621+
},
1622+
)
1623+
)
1624+
1625+
# synthesize event is sent with complete message for non-streaming clients
1626+
async with asyncio.timeout(1):
1627+
await mock_client.synthesize_event.wait()
1628+
1629+
assert mock_client.synthesize is not None
1630+
assert mock_client.synthesize.text == "test text to speak"
1631+
assert mock_client.synthesize.voice is not None
1632+
assert mock_client.synthesize.voice.name == "test voice"
1633+
1634+
# Because we started streaming TTS after intent progress, we should not
1635+
# stream it again on tts-end.
1636+
with patch(
1637+
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts"
1638+
) as mock_stream_tts:
1639+
pipeline_event_callback(
1640+
assist_pipeline.PipelineEvent(
1641+
assist_pipeline.PipelineEventType.TTS_END,
1642+
{"tts_output": {"token": mock_tts_result_stream.token}},
1643+
)
1644+
)
1645+
1646+
mock_stream_tts.assert_not_called()
1647+
1648+
# Pipeline finished
1649+
pipeline_event_callback(
1650+
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
1651+
)
1652+
1653+
# Stop the satellite
1654+
await hass.config_entries.async_unload(entry.entry_id)
1655+
await hass.async_block_till_done()

0 commit comments

Comments (0)