diff --git a/parler_tts/streamer.py b/parler_tts/streamer.py index 57dab91..df8f386 100644 --- a/parler_tts/streamer.py +++ b/parler_tts/streamer.py @@ -1,7 +1,6 @@ - from .modeling_parler_tts import ParlerTTSForConditionalGeneration from transformers.generation.streamers import BaseStreamer -from typing import Optional +from typing import Optional, Generator import torch import numpy as np import math @@ -58,7 +57,7 @@ def __init__( self.stop_signal = None self.timeout = timeout - def apply_delay_pattern_mask(self, input_ids): + def apply_delay_pattern_mask(self, input_ids: torch.Tensor) -> np.ndarray: # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler) _, delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids[:, :1], @@ -97,7 +96,7 @@ def apply_delay_pattern_mask(self, input_ids): audio_values = output_values.audio_values[0, 0] return audio_values.cpu().float().numpy() - def put(self, value): + def put(self, value: torch.Tensor) -> None: batch_size = value.shape[0] // self.decoder.num_codebooks if batch_size > 1: raise ValueError("ParlerTTSStreamer only supports batch size 1") @@ -112,7 +111,7 @@ def put(self, value): self.on_finalized_audio(audio_values[self.to_yield : -self.stride]) self.to_yield += len(audio_values) - self.to_yield - self.stride - def end(self): + def end(self) -> None: """Flushes any remaining cache and appends the stop symbol.""" if self.token_cache is not None: audio_values = self.apply_delay_pattern_mask(self.token_cache) @@ -121,18 +120,19 @@ def end(self): self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True) - def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False): + def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False) -> None: """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue.""" self.audio_queue.put(audio, timeout=self.timeout) if stream_end: self.audio_queue.put(self.stop_signal, timeout=self.timeout) - def __iter__(self): + def __iter__(self) -> Generator[np.ndarray, None, None]: return self - def __next__(self): + def __next__(self) -> np.ndarray: value = self.audio_queue.get(timeout=self.timeout) if not isinstance(value, np.ndarray) and value == self.stop_signal: raise StopIteration() else: - return value \ No newline at end of file + return value +