@@ -132,6 +132,10 @@ def __init__(
132132 # Used to ensure TTS timeout is acted on correctly.
133133 self ._run_loop_id : str | None = None
134134
135+ # TTS streaming
136+ self ._tts_stream_token : str | None = None
137+ self ._is_tts_streaming : bool = False
138+
135139 @property
136140 def pipeline_entity_id (self ) -> str | None :
137141 """Return the entity ID of the pipeline to use for the next conversation."""
@@ -179,11 +183,20 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
179183 """Set state based on pipeline stage."""
180184 assert self ._client is not None
181185
182- if event .type == assist_pipeline .PipelineEventType .RUN_END :
186+ if event .type == assist_pipeline .PipelineEventType .RUN_START :
187+ if event .data and (tts_output := event .data ["tts_output" ]):
188+ # Get stream token early.
189+ # If "tts_start_streaming" is True in INTENT_PROGRESS event, we
190+ # can start streaming TTS before the TTS_END event.
191+ self ._tts_stream_token = tts_output ["token" ]
192+ self ._is_tts_streaming = False
193+ elif event .type == assist_pipeline .PipelineEventType .RUN_END :
183194 # Pipeline run is complete
184195 self ._is_pipeline_running = False
185196 self ._pipeline_ended_event .set ()
186197 self .device .set_is_active (False )
198+ self ._tts_stream_token = None
199+ self ._is_tts_streaming = False
187200 elif event .type == assist_pipeline .PipelineEventType .WAKE_WORD_START :
188201 self .config_entry .async_create_background_task (
189202 self .hass ,
@@ -245,6 +258,20 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
245258 self ._client .write_event (Transcript (text = stt_text ).event ()),
246259 f"{ self .entity_id } { event .type } " ,
247260 )
261+ elif event .type == assist_pipeline .PipelineEventType .INTENT_PROGRESS :
262+ if (
263+ event .data
264+ and event .data .get ("tts_start_streaming" )
265+ and self ._tts_stream_token
266+ and (stream := tts .async_get_stream (self .hass , self ._tts_stream_token ))
267+ ):
268+ # Start streaming TTS early (before TTS_END).
269+ self ._is_tts_streaming = True
270+ self .config_entry .async_create_background_task (
271+ self .hass ,
272+ self ._stream_tts (stream ),
273+ f"{ self .entity_id } { event .type } " ,
274+ )
248275 elif event .type == assist_pipeline .PipelineEventType .TTS_START :
249276 # Text-to-speech text
250277 if event .data :
@@ -267,8 +294,10 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
267294 if (
268295 event .data
269296 and (tts_output := event .data ["tts_output" ])
297+ and not self ._is_tts_streaming
270298 and (stream := tts .async_get_stream (self .hass , tts_output ["token" ]))
271299 ):
300+ # Send TTS only if we haven't already started streaming it in INTENT_PROGRESS.
272301 self .config_entry .async_create_background_task (
273302 self .hass ,
274303 self ._stream_tts (stream ),
@@ -711,39 +740,62 @@ async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
711740 start_time = time .monotonic ()
712741
713742 try :
714- data = b"" .join ([chunk async for chunk in tts_result .async_stream_result ()])
715-
716- with io .BytesIO (data ) as wav_io , wave .open (wav_io , "rb" ) as wav_file :
717- sample_rate = wav_file .getframerate ()
718- sample_width = wav_file .getsampwidth ()
719- sample_channels = wav_file .getnchannels ()
720- _LOGGER .debug ("Streaming %s TTS sample(s)" , wav_file .getnframes ())
721-
722- timestamp = 0
723- await self ._client .write_event (
724- AudioStart (
725- rate = sample_rate ,
726- width = sample_width ,
727- channels = sample_channels ,
728- timestamp = timestamp ,
729- ).event ()
743+ header_data = b""
744+ header_complete = False
745+ sample_rate : int | None = None
746+ sample_width : int | None = None
747+ sample_channels : int | None = None
748+ timestamp = 0
749+
750+ async for data_chunk in tts_result .async_stream_result ():
751+ if not header_complete :
752+ # Accumulate data until we can parse the header and get
753+ # sample rate, etc.
754+ header_data += data_chunk
755+ # Most WAVE headers are 44 bytes in length
756+ if (len (header_data ) >= 44 ) and (
757+ audio_info := _try_parse_wav_header (header_data )
758+ ):
759+ # Overwrite chunk with audio after header
760+ sample_rate , sample_width , sample_channels , data_chunk = (
761+ audio_info
762+ )
763+ await self ._client .write_event (
764+ AudioStart (
765+ rate = sample_rate ,
766+ width = sample_width ,
767+ channels = sample_channels ,
768+ timestamp = timestamp ,
769+ ).event ()
770+ )
771+ header_complete = True
772+
773+ if not data_chunk :
774+ # No audio after header
775+ continue
776+ else :
777+ # Header is incomplete
778+ continue
779+
780+ # Streaming audio
781+ assert sample_rate is not None
782+ assert sample_width is not None
783+ assert sample_channels is not None
784+
785+ audio_chunk = AudioChunk (
786+ rate = sample_rate ,
787+ width = sample_width ,
788+ channels = sample_channels ,
789+ audio = data_chunk ,
790+ timestamp = timestamp ,
730791 )
731792
732- # Stream audio chunks
733- while audio_bytes := wav_file .readframes (_SAMPLES_PER_CHUNK ):
734- chunk = AudioChunk (
735- rate = sample_rate ,
736- width = sample_width ,
737- channels = sample_channels ,
738- audio = audio_bytes ,
739- timestamp = timestamp ,
740- )
741- await self ._client .write_event (chunk .event ())
742- timestamp += chunk .milliseconds
743- total_seconds += chunk .seconds
793+ await self ._client .write_event (audio_chunk .event ())
794+ timestamp += audio_chunk .milliseconds
795+ total_seconds += audio_chunk .seconds
744796
745- await self ._client .write_event (AudioStop (timestamp = timestamp ).event ())
746- _LOGGER .debug ("TTS streaming complete" )
797+ await self ._client .write_event (AudioStop (timestamp = timestamp ).event ())
798+ _LOGGER .debug ("TTS streaming complete" )
747799 finally :
748800 send_duration = time .monotonic () - start_time
749801 timeout_seconds = max (0 , total_seconds - send_duration + _TTS_TIMEOUT_EXTRA )
@@ -812,3 +864,25 @@ def _handle_timer(
812864 self .config_entry .async_create_background_task (
813865 self .hass , self ._client .write_event (event ), "wyoming timer event"
814866 )
867+
868+
869+ def _try_parse_wav_header (header_data : bytes ) -> tuple [int , int , int , bytes ] | None :
870+ """Try to parse a WAV header from a buffer.
871+
872+ If successful, return (rate, width, channels, audio).
873+ """
874+ try :
875+ with io .BytesIO (header_data ) as wav_io :
876+ wav_file : wave .Wave_read = wave .open (wav_io , "rb" )
877+ with wav_file :
878+ return (
879+ wav_file .getframerate (),
880+ wav_file .getsampwidth (),
881+ wav_file .getnchannels (),
882+ wav_file .readframes (wav_file .getnframes ()),
883+ )
884+ except wave .Error :
885+ # Ignore errors and return None
886+ pass
887+
888+ return None
0 commit comments