Skip to content

Commit 1a9d1a9

Browse files
authored
Handle non-streaming TTS case correctly (#150218)
1 parent cb7c776 commit 1a9d1a9

File tree

5 files changed

+77
-81
lines changed

5 files changed

+77
-81
lines changed

homeassistant/components/tts/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -976,11 +976,15 @@ async def _async_generate_tts_audio(
976976
if engine_instance.name is None or engine_instance.name is UNDEFINED:
977977
raise HomeAssistantError("TTS engine name is not set.")
978978

979-
if isinstance(engine_instance, Provider):
979+
if isinstance(engine_instance, Provider) or (
980+
not engine_instance.async_supports_streaming_input()
981+
):
982+
# Non-streaming
980983
if isinstance(message_or_stream, str):
981984
message = message_or_stream
982985
else:
983986
message = "".join([chunk async for chunk in message_or_stream])
987+
984988
extension, data = await engine_instance.async_internal_get_tts_audio(
985989
message, language, options
986990
)
@@ -996,6 +1000,7 @@ async def make_data_generator(data: bytes) -> AsyncGenerator[bytes]:
9961000
data_gen = make_data_generator(data)
9971001

9981002
else:
1003+
# Streaming
9991004
if isinstance(message_or_stream, str):
10001005

10011006
async def gen_stream() -> AsyncGenerator[str]:

homeassistant/components/tts/entity.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,18 @@ def get_tts_audio(
191191
"""Load tts audio file from the engine."""
192192
raise NotImplementedError
193193

194+
@final
195+
async def async_internal_get_tts_audio(
196+
self, message: str, language: str, options: dict[str, Any]
197+
) -> TtsAudioType:
198+
"""Load tts audio file from the engine and update state.
199+
200+
Return a tuple of file extension and data as bytes.
201+
"""
202+
self.__last_tts_loaded = dt_util.utcnow().isoformat()
203+
self.async_write_ha_state()
204+
return await self.async_get_tts_audio(message, language, options=options)
205+
194206
async def async_get_tts_audio(
195207
self, message: str, language: str, options: dict[str, Any]
196208
) -> TtsAudioType:

tests/components/tts/test_entity.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,31 @@ def get_tts_audio(
175175

176176
sync_non_streaming_entity = SyncNonStreamingEntity()
177177
assert sync_non_streaming_entity.async_supports_streaming_input() is False
178+
179+
180+
async def test_internal_get_tts_audio_writes_state(
181+
hass: HomeAssistant,
182+
mock_tts_entity: MockTTSEntity,
183+
) -> None:
184+
"""Test that only async_internal_get_tts_audio updates and writes the state."""
185+
186+
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
187+
188+
config_entry = await mock_config_entry_setup(hass, mock_tts_entity)
189+
assert config_entry.state is ConfigEntryState.LOADED
190+
state1 = hass.states.get(entity_id)
191+
assert state1 is not None
192+
193+
# State should *not* change with external method
194+
await mock_tts_entity.async_get_tts_audio("test message", hass.config.language, {})
195+
state2 = hass.states.get(entity_id)
196+
assert state2 is not None
197+
assert state1.state == state2.state
198+
199+
# State *should* change with internal method
200+
await mock_tts_entity.async_internal_get_tts_audio(
201+
"test message", hass.config.language, {}
202+
)
203+
state3 = hass.states.get(entity_id)
204+
assert state3 is not None
205+
assert state1.state != state3.state

tests/components/tts/test_init.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2032,3 +2032,34 @@ async def consume_cache(cache: tts.TTSCache):
20322032
assert await consume_mid_data_task == b"012"
20332033
with pytest.raises(ValueError):
20342034
assert await consume_pre_data_loaded_task == b"012"
2035+
2036+
2037+
async def test_async_internal_get_tts_audio_called(
2038+
hass: HomeAssistant,
2039+
mock_tts_entity: MockTTSEntity,
2040+
hass_client: ClientSessionGenerator,
2041+
) -> None:
2042+
"""Test that non-streaming entity has its async_internal_get_tts_audio method called."""
2043+
2044+
await mock_config_entry_setup(hass, mock_tts_entity)
2045+
2046+
# Non-streaming
2047+
assert mock_tts_entity.async_supports_streaming_input() is False
2048+
2049+
with patch(
2050+
"homeassistant.components.tts.entity.TextToSpeechEntity.async_internal_get_tts_audio"
2051+
) as internal_get_tts_audio:
2052+
media_source_id = tts.generate_media_source_id(
2053+
hass,
2054+
"test message",
2055+
"tts.test",
2056+
"en_US",
2057+
cache=None,
2058+
)
2059+
2060+
url = await get_media_source_url(hass, media_source_id)
2061+
client = await hass_client()
2062+
await client.get(url)
2063+
2064+
# async_internal_get_tts_audio is called
2065+
internal_get_tts_audio.assert_called_once_with("test message", "en_US", {})
Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,35 @@
11
# serializer version: 1
22
# name: test_get_tts_audio
33
list([
4-
dict({
5-
'data': dict({
6-
}),
7-
'payload': None,
8-
'type': 'synthesize-start',
9-
}),
10-
dict({
11-
'data': dict({
12-
'text': 'Hello world',
13-
}),
14-
'payload': None,
15-
'type': 'synthesize-chunk',
16-
}),
174
dict({
185
'data': dict({
196
'text': 'Hello world',
207
}),
218
'payload': None,
229
'type': 'synthesize',
2310
}),
24-
dict({
25-
'data': dict({
26-
}),
27-
'payload': None,
28-
'type': 'synthesize-stop',
29-
}),
3011
])
3112
# ---
3213
# name: test_get_tts_audio_different_formats
3314
list([
34-
dict({
35-
'data': dict({
36-
}),
37-
'payload': None,
38-
'type': 'synthesize-start',
39-
}),
40-
dict({
41-
'data': dict({
42-
'text': 'Hello world',
43-
}),
44-
'payload': None,
45-
'type': 'synthesize-chunk',
46-
}),
4715
dict({
4816
'data': dict({
4917
'text': 'Hello world',
5018
}),
5119
'payload': None,
5220
'type': 'synthesize',
5321
}),
54-
dict({
55-
'data': dict({
56-
}),
57-
'payload': None,
58-
'type': 'synthesize-stop',
59-
}),
6022
])
6123
# ---
6224
# name: test_get_tts_audio_different_formats.1
6325
list([
64-
dict({
65-
'data': dict({
66-
}),
67-
'payload': None,
68-
'type': 'synthesize-start',
69-
}),
70-
dict({
71-
'data': dict({
72-
'text': 'Hello world',
73-
}),
74-
'payload': None,
75-
'type': 'synthesize-chunk',
76-
}),
7726
dict({
7827
'data': dict({
7928
'text': 'Hello world',
8029
}),
8130
'payload': None,
8231
'type': 'synthesize',
8332
}),
84-
dict({
85-
'data': dict({
86-
}),
87-
'payload': None,
88-
'type': 'synthesize-stop',
89-
}),
9033
])
9134
# ---
9235
# name: test_get_tts_audio_streaming
@@ -128,23 +71,6 @@
12871
# ---
12972
# name: test_voice_speaker
13073
list([
131-
dict({
132-
'data': dict({
133-
'voice': dict({
134-
'name': 'voice1',
135-
'speaker': 'speaker1',
136-
}),
137-
}),
138-
'payload': None,
139-
'type': 'synthesize-start',
140-
}),
141-
dict({
142-
'data': dict({
143-
'text': 'Hello world',
144-
}),
145-
'payload': None,
146-
'type': 'synthesize-chunk',
147-
}),
14874
dict({
14975
'data': dict({
15076
'text': 'Hello world',
@@ -156,11 +82,5 @@
15682
'payload': None,
15783
'type': 'synthesize',
15884
}),
159-
dict({
160-
'data': dict({
161-
}),
162-
'payload': None,
163-
'type': 'synthesize-stop',
164-
}),
16585
])
16686
# ---

0 commit comments

Comments
 (0)