Skip to content

Commit cefc882

Browse files
authored
Support streaming TTS in wyoming (home-assistant#147392)
* Support streaming TTS in wyoming * Add test * Refactor to avoid repeated task creation * Manually manage client lifecycle
1 parent 3dc8676 commit cefc882

File tree

5 files changed

+242
-7
lines changed

5 files changed

+242
-7
lines changed

homeassistant/components/wyoming/tts.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
"""Support for Wyoming text-to-speech services."""
22

33
from collections import defaultdict
4+
from collections.abc import AsyncGenerator
45
import io
56
import logging
67
import wave
78

8-
from wyoming.audio import AudioChunk, AudioStop
9+
from wyoming.audio import AudioChunk, AudioStart, AudioStop
910
from wyoming.client import AsyncTcpClient
10-
from wyoming.tts import Synthesize, SynthesizeVoice
11+
from wyoming.tts import (
12+
Synthesize,
13+
SynthesizeChunk,
14+
SynthesizeStart,
15+
SynthesizeStop,
16+
SynthesizeStopped,
17+
SynthesizeVoice,
18+
)
1119

1220
from homeassistant.components import tts
1321
from homeassistant.config_entries import ConfigEntry
@@ -45,6 +53,7 @@ def __init__(
4553
service: WyomingService,
4654
) -> None:
4755
"""Set up provider."""
56+
self.config_entry = config_entry
4857
self.service = service
4958
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
5059

@@ -150,3 +159,98 @@ async def async_get_tts_audio(self, message, language, options):
150159
return (None, None)
151160

152161
return ("wav", data)
162+
163+
def async_supports_streaming_input(self) -> bool:
    """Return the streaming-synthesis capability flag of the selected TTS service.

    The value is read from ``supports_synthesize_streaming`` on the TTS
    program chosen when the provider was set up.
    """
    tts_program = self._tts_service
    return tts_program.supports_synthesize_streaming
166+
167+
async def async_stream_tts_audio(
    self, request: tts.TTSAudioRequest
) -> tts.TTSAudioResponse:
    """Generate speech from an incoming stream of message chunks.

    A client connection is opened to the Wyoming service, a background task
    forwards the text chunks from ``request.message_gen`` to it, and the
    returned response wraps an async generator that yields WAV data as the
    service produces it.

    The client is disconnected when the audio generator finishes or is
    closed. The writer task is cancelled at the same point so it cannot keep
    writing to a client that has already been disconnected (for example when
    the consumer stops reading early or the server drops the connection).

    Exceptions raised by ``client.connect()`` propagate to the caller.
    """
    voice_name: str | None = request.options.get(tts.ATTR_VOICE)
    voice_speaker: str | None = request.options.get(ATTR_SPEAKER)
    voice: SynthesizeVoice | None = None
    if voice_name is not None:
        voice = SynthesizeVoice(name=voice_name, speaker=voice_speaker)

    client = AsyncTcpClient(self.service.host, self.service.port)
    await client.connect()

    # Stream text chunks to client. The task handle is kept so the writer can
    # be stopped as soon as reading ends.
    write_task = self.config_entry.async_create_background_task(
        self.hass,
        self._write_tts_message(request.message_gen, client, voice),
        "wyoming tts write",
    )

    async def data_gen() -> AsyncGenerator[bytes]:
        # Stream audio bytes from client
        try:
            async for data_chunk in self._read_tts_audio(client):
                yield data_chunk
        finally:
            # Cancelling a task that has already finished is a no-op, so the
            # normal path (writer done before the last audio event) is
            # unaffected.
            write_task.cancel()
            await client.disconnect()

    return tts.TTSAudioResponse("wav", data_gen())
196+
197+
async def _write_tts_message(
    self,
    message_gen: AsyncGenerator[str],
    client: AsyncTcpClient,
    voice: SynthesizeVoice | None,
) -> None:
    """Forward text chunks from ``message_gen`` to the client.

    Events are written in this order: one ``SynthesizeStart``, one
    ``SynthesizeChunk`` per incoming text chunk, a single ``Synthesize``
    carrying the whole concatenated text (kept for backwards compatibility),
    and finally ``SynthesizeStop``.

    ``OSError`` and ``WyomingError`` are treated as a disconnection: a
    warning is logged and the method returns without raising.
    """
    text_parts: list[str] = []

    try:
        # Open the streaming request
        await client.write_event(SynthesizeStart(voice=voice).event())

        async for text_part in message_gen:
            # Remember every chunk so the full text can be sent afterwards
            text_parts.append(text_part)
            await client.write_event(SynthesizeChunk(text=text_part).event())

        # Whole message in one event, for backwards compatibility
        full_text = "".join(text_parts)
        await client.write_event(Synthesize(text=full_text, voice=voice).event())

        # Close the streaming request
        await client.write_event(SynthesizeStop().event())
    except (OSError, WyomingError):
        # Disconnected
        _LOGGER.warning("Unexpected disconnection from TTS client")
223+
224+
async def _read_tts_audio(self, client: AsyncTcpClient) -> AsyncGenerator[bytes]:
    """Read audio events from the client and yield WAV audio chunks.

    The first ``AudioStart`` event produces a WAV header, which is yielded
    before any audio. The header carries a frame count of 0 because the
    number of frames is not known ahead of time when streaming. After the
    header, the payload of each ``AudioChunk`` event is yielded as-is.

    Reading stops on ``SynthesizeStopped`` or when ``read_event`` returns a
    falsy value. Audio chunks received before the header are dropped, and
    other events are ignored. ``OSError`` and ``WyomingError`` are logged as
    a disconnection and end the generator without raising.
    """
    header_emitted = False

    try:
        while event := await client.read_event():
            if header_emitted:
                if AudioChunk.is_type(event.type):
                    # PCM audio
                    yield AudioChunk.from_event(event).audio
                    continue
            elif AudioStart.is_type(event.type):
                # WAV header with nframes = 0 for streaming
                stream_format = AudioStart.from_event(event)
                with io.BytesIO() as header_io:
                    header_writer: wave.Wave_write = wave.open(header_io, "wb")
                    with header_writer:
                        header_writer.setframerate(stream_format.rate)
                        header_writer.setsampwidth(stream_format.width)
                        header_writer.setnchannels(stream_format.channels)

                    # Closing the writer above is what emits the header bytes
                    header_bytes = header_io.getvalue()

                yield header_bytes
                header_emitted = True
                continue

            if SynthesizeStopped.is_type(event.type):
                # All TTS audio has been received
                break
    except (OSError, WyomingError):
        # Disconnected
        _LOGGER.warning("Unexpected disconnection from TTS client")

tests/components/wyoming/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,29 @@
6969
)
7070
]
7171
)
72+
# Service info for a TTS program that advertises streaming synthesis
# (supports_synthesize_streaming=True), with a single installed voice.
TTS_STREAMING_INFO = Info(
    tts=[
        TtsProgram(
            name="Test Streaming TTS",
            description="Test Streaming TTS",
            attribution=TEST_ATTR,
            installed=True,
            version=None,
            supports_synthesize_streaming=True,
            voices=[
                TtsVoice(
                    name="Test Voice",
                    description="Test Voice",
                    attribution=TEST_ATTR,
                    installed=True,
                    version=None,
                    languages=["en-US"],
                    speakers=[TtsVoiceSpeaker(name="Test Speaker")],
                )
            ],
        )
    ]
)
7295
WAKE_WORD_INFO = Info(
7396
wake=[
7497
WakeProgram(
@@ -155,9 +178,15 @@ def __init__(self, responses: list[Event]) -> None:
155178
self.port: int | None = None
156179
self.written: list[Event] = []
157180
self.responses = responses
181+
self.is_connected: bool | None = None
158182

159183
async def connect(self) -> None:
    """Record that the mock client has been connected."""
    self.is_connected = True
186+
187+
async def disconnect(self) -> None:
    """Record that the mock client has been disconnected."""
    self.is_connected = False
161190

162191
async def write_event(self, event: Event):
163192
"""Send."""

tests/components/wyoming/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
SATELLITE_INFO,
2020
STT_INFO,
2121
TTS_INFO,
22+
TTS_STREAMING_INFO,
2223
WAKE_WORD_INFO,
2324
)
2425

@@ -148,6 +149,20 @@ async def init_wyoming_tts(
148149
return tts_config_entry
149150

150151

152+
@pytest.fixture
async def init_wyoming_streaming_tts(
    hass: HomeAssistant, tts_config_entry: ConfigEntry
) -> ConfigEntry:
    """Set up the TTS config entry with streaming-capable service info.

    ``load_wyoming_info`` is patched to return ``TTS_STREAMING_INFO`` while
    the config entry is being set up; the entry itself is returned.
    """
    info_patcher = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=TTS_STREAMING_INFO,
    )
    with info_patcher:
        await hass.config_entries.async_setup(tts_config_entry.entry_id)

    return tts_config_entry
164+
165+
151166
@pytest.fixture
152167
async def init_wyoming_wake_word(
153168
hass: HomeAssistant, wake_word_config_entry: ConfigEntry

tests/components/wyoming/snapshots/test_tts.ambr

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,43 @@
3232
}),
3333
])
3434
# ---
35+
# name: test_get_tts_audio_streaming
36+
list([
37+
dict({
38+
'data': dict({
39+
}),
40+
'payload': None,
41+
'type': 'synthesize-start',
42+
}),
43+
dict({
44+
'data': dict({
45+
'text': 'Hello ',
46+
}),
47+
'payload': None,
48+
'type': 'synthesize-chunk',
49+
}),
50+
dict({
51+
'data': dict({
52+
'text': 'Word.',
53+
}),
54+
'payload': None,
55+
'type': 'synthesize-chunk',
56+
}),
57+
dict({
58+
'data': dict({
59+
'text': 'Hello Word.',
60+
}),
61+
'payload': None,
62+
'type': 'synthesize',
63+
}),
64+
dict({
65+
'data': dict({
66+
}),
67+
'payload': None,
68+
'type': 'synthesize-stop',
69+
}),
70+
])
71+
# ---
3572
# name: test_voice_speaker
3673
list([
3774
dict({

tests/components/wyoming/test_tts.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
import pytest
1010
from syrupy.assertion import SnapshotAssertion
11-
from wyoming.audio import AudioChunk, AudioStop
11+
from wyoming.audio import AudioChunk, AudioStart, AudioStop
12+
from wyoming.tts import SynthesizeStopped
1213

1314
from homeassistant.components import tts, wyoming
1415
from homeassistant.core import HomeAssistant
@@ -43,11 +44,11 @@ async def test_get_tts_audio(
4344
hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion
4445
) -> None:
4546
"""Test get audio."""
47+
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
48+
assert entity is not None
49+
assert not entity.async_supports_streaming_input()
50+
4651
audio = bytes(100)
47-
audio_events = [
48-
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
49-
AudioStop().event(),
50-
]
5152

5253
# Verify audio
5354
audio_events = [
@@ -215,3 +216,52 @@ async def test_voice_speaker(
215216
),
216217
)
217218
assert mock_client.written == snapshot
219+
220+
221+
async def test_get_tts_audio_streaming(
    hass: HomeAssistant, init_wyoming_streaming_tts, snapshot: SnapshotAssertion
) -> None:
    """Test streaming text in and receiving streamed WAV audio back."""
    entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_streaming_tts")
    assert entity is not None
    assert entity.async_supports_streaming_input()

    audio = bytes(100)

    # Events the mock server replies with: format, one PCM chunk, then the
    # end-of-audio and end-of-synthesis markers.
    server_events = [
        AudioStart(rate=16000, width=2, channels=1).event(),
        AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
        AudioStop().event(),
        SynthesizeStopped().event(),
    ]

    async def text_chunks():
        for text_part in ("Hello ", "Word."):
            yield text_part

    with patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient",
        MockAsyncTcpClient(server_events),
    ) as mock_client:
        stream = tts.async_create_stream(
            hass,
            "tts.test_streaming_tts",
            "en-US",
            options={tts.ATTR_PREFERRED_FORMAT: "wav"},
        )
        stream.async_set_message_stream(text_chunks())
        received = [chunk async for chunk in stream.async_stream_result()]
        data = b"".join(received)

    # Ensure client was disconnected properly
    assert mock_client.is_connected is False

    assert data is not None
    with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
        assert wav_file.getframerate() == 16000
        assert wav_file.getsampwidth() == 2
        assert wav_file.getnchannels() == 1
        assert wav_file.getnframes() == 0  # streaming
        assert data[44:] == audio  # WAV header is 44 bytes

    # Text events written to the server must match the recorded snapshot
    assert mock_client.written == snapshot

0 commit comments

Comments
 (0)