Skip to content

Commit 124931b

Browse files
TTS to always stream when available (home-assistant#148695)
Co-authored-by: Michael Hansen <[email protected]>
1 parent c27a67d commit 124931b

File tree

6 files changed

+107
-9
lines changed

6 files changed

+107
-9
lines changed

homeassistant/components/tts/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ async def write_input() -> None:
382382
assert process.stderr
383383
stderr_data = await process.stderr.read()
384384
_LOGGER.error(stderr_data.decode())
385-
raise RuntimeError(
385+
raise HomeAssistantError(
386386
f"Unexpected error while running ffmpeg with arguments: {command}. "
387387
"See log for details."
388388
)
@@ -976,7 +976,7 @@ async def _async_generate_tts_audio(
976976
if engine_instance.name is None or engine_instance.name is UNDEFINED:
977977
raise HomeAssistantError("TTS engine name is not set.")
978978

979-
if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str):
979+
if isinstance(engine_instance, Provider):
980980
if isinstance(message_or_stream, str):
981981
message = message_or_stream
982982
else:
@@ -996,8 +996,18 @@ async def make_data_generator(data: bytes) -> AsyncGenerator[bytes]:
996996
data_gen = make_data_generator(data)
997997

998998
else:
999+
if isinstance(message_or_stream, str):
1000+
1001+
async def gen_stream() -> AsyncGenerator[str]:
1002+
yield message_or_stream
1003+
1004+
stream = gen_stream()
1005+
1006+
else:
1007+
stream = message_or_stream
1008+
9991009
tts_result = await engine_instance.internal_async_stream_tts_audio(
1000-
TTSAudioRequest(language, options, message_or_stream)
1010+
TTSAudioRequest(language, options, stream)
10011011
)
10021012
extension = tts_result.extension
10031013
data_gen = tts_result.data_gen

tests/components/assist_pipeline/snapshots/test_pipeline.ambr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# serializer version: 1
2-
# name: test_chat_log_tts_streaming[to_stream_deltas0-0-]
2+
# name: test_chat_log_tts_streaming[to_stream_deltas0-1-hello, how are you?]
33
list([
44
dict({
55
'data': dict({

tests/components/assist_pipeline/test_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,9 +1550,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
15501550
"?",
15511551
],
15521552
),
1553-
# We are not streaming, so 0 chunks via streaming method
1554-
0,
1555-
"",
1553+
# We always stream when possible, so 1 chunk via streaming method
1554+
1,
1555+
"hello, how are you?",
15561556
),
15571557
# Size above STREAM_RESPONSE_CHUNKS
15581558
(

tests/components/tts/test_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1835,7 +1835,7 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
18351835
async def bad_data_gen():
18361836
yield bytes(0)
18371837

1838-
with pytest.raises(RuntimeError):
1838+
with pytest.raises(HomeAssistantError):
18391839
# Simulate a bad WAV file
18401840
async for _chunk in tts._async_convert_audio(
18411841
hass, "wav", bad_data_gen(), "mp3"

tests/components/wyoming/snapshots/test_tts.ambr

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,92 @@
11
# serializer version: 1
22
# name: test_get_tts_audio
33
list([
4+
dict({
5+
'data': dict({
6+
}),
7+
'payload': None,
8+
'type': 'synthesize-start',
9+
}),
10+
dict({
11+
'data': dict({
12+
'text': 'Hello world',
13+
}),
14+
'payload': None,
15+
'type': 'synthesize-chunk',
16+
}),
417
dict({
518
'data': dict({
619
'text': 'Hello world',
720
}),
821
'payload': None,
922
'type': 'synthesize',
1023
}),
24+
dict({
25+
'data': dict({
26+
}),
27+
'payload': None,
28+
'type': 'synthesize-stop',
29+
}),
1130
])
1231
# ---
1332
# name: test_get_tts_audio_different_formats
1433
list([
34+
dict({
35+
'data': dict({
36+
}),
37+
'payload': None,
38+
'type': 'synthesize-start',
39+
}),
40+
dict({
41+
'data': dict({
42+
'text': 'Hello world',
43+
}),
44+
'payload': None,
45+
'type': 'synthesize-chunk',
46+
}),
1547
dict({
1648
'data': dict({
1749
'text': 'Hello world',
1850
}),
1951
'payload': None,
2052
'type': 'synthesize',
2153
}),
54+
dict({
55+
'data': dict({
56+
}),
57+
'payload': None,
58+
'type': 'synthesize-stop',
59+
}),
2260
])
2361
# ---
2462
# name: test_get_tts_audio_different_formats.1
2563
list([
64+
dict({
65+
'data': dict({
66+
}),
67+
'payload': None,
68+
'type': 'synthesize-start',
69+
}),
70+
dict({
71+
'data': dict({
72+
'text': 'Hello world',
73+
}),
74+
'payload': None,
75+
'type': 'synthesize-chunk',
76+
}),
2677
dict({
2778
'data': dict({
2879
'text': 'Hello world',
2980
}),
3081
'payload': None,
3182
'type': 'synthesize',
3283
}),
84+
dict({
85+
'data': dict({
86+
}),
87+
'payload': None,
88+
'type': 'synthesize-stop',
89+
}),
3390
])
3491
# ---
3592
# name: test_get_tts_audio_streaming
@@ -71,6 +128,23 @@
71128
# ---
72129
# name: test_voice_speaker
73130
list([
131+
dict({
132+
'data': dict({
133+
'voice': dict({
134+
'name': 'voice1',
135+
'speaker': 'speaker1',
136+
}),
137+
}),
138+
'payload': None,
139+
'type': 'synthesize-start',
140+
}),
141+
dict({
142+
'data': dict({
143+
'text': 'Hello world',
144+
}),
145+
'payload': None,
146+
'type': 'synthesize-chunk',
147+
}),
74148
dict({
75149
'data': dict({
76150
'text': 'Hello world',
@@ -82,5 +156,11 @@
82156
'payload': None,
83157
'type': 'synthesize',
84158
}),
159+
dict({
160+
'data': dict({
161+
}),
162+
'payload': None,
163+
'type': 'synthesize-stop',
164+
}),
85165
])
86166
# ---

tests/components/wyoming/test_tts.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ async def test_get_tts_audio(
5252

5353
# Verify audio
5454
audio_events = [
55+
AudioStart(rate=16000, width=2, channels=1).event(),
5556
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
5657
AudioStop().event(),
5758
]
@@ -77,7 +78,10 @@ async def test_get_tts_audio(
7778
assert wav_file.getframerate() == 16000
7879
assert wav_file.getsampwidth() == 2
7980
assert wav_file.getnchannels() == 1
80-
assert wav_file.readframes(wav_file.getnframes()) == audio
81+
82+
# nframes = 0 due to streaming
83+
assert len(data) == len(audio) + 44 # WAVE header is 44 bytes
84+
assert data[44:] == audio
8185

8286
assert mock_client.written == snapshot
8387

@@ -88,6 +92,7 @@ async def test_get_tts_audio_different_formats(
8892
"""Test changing preferred audio format."""
8993
audio = bytes(16000 * 2 * 1) # one second
9094
audio_events = [
95+
AudioStart(rate=16000, width=2, channels=1).event(),
9196
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
9297
AudioStop().event(),
9398
]
@@ -123,6 +128,7 @@ async def test_get_tts_audio_different_formats(
123128

124129
# MP3 is the default
125130
audio_events = [
131+
AudioStart(rate=16000, width=2, channels=1).event(),
126132
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
127133
AudioStop().event(),
128134
]
@@ -167,6 +173,7 @@ async def test_get_tts_audio_audio_oserror(
167173
"""Test get audio and error raising."""
168174
audio = bytes(100)
169175
audio_events = [
176+
AudioStart(rate=16000, width=2, channels=1).event(),
170177
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
171178
AudioStop().event(),
172179
]
@@ -197,6 +204,7 @@ async def test_voice_speaker(
197204
"""Test using a different voice and speaker."""
198205
audio = bytes(100)
199206
audio_events = [
207+
AudioStart(rate=16000, width=2, channels=1).event(),
200208
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
201209
AudioStop().event(),
202210
]

0 commit comments

Comments
 (0)