18 changes: 9 additions & 9 deletions parler_tts/streamer.py
@@ -1,7 +1,6 @@
-
 from .modeling_parler_tts import ParlerTTSForConditionalGeneration
 from transformers.generation.streamers import BaseStreamer
-from typing import Optional
+from typing import Optional, Generator
 import torch
 import numpy as np
 import math
@@ -58,7 +57,7 @@ def __init__(
         self.stop_signal = None
         self.timeout = timeout
 
-    def apply_delay_pattern_mask(self, input_ids):
+    def apply_delay_pattern_mask(self, input_ids: torch.Tensor) -> np.ndarray:
         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
         _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids[:, :1],
@@ -97,7 +96,7 @@ def apply_delay_pattern_mask(self, input_ids):
         audio_values = output_values.audio_values[0, 0]
         return audio_values.cpu().float().numpy()
 
-    def put(self, value):
+    def put(self, value: torch.Tensor) -> None:
         batch_size = value.shape[0] // self.decoder.num_codebooks
         if batch_size > 1:
             raise ValueError("ParlerTTSStreamer only supports batch size 1")
@@ -112,7 +111,7 @@ def put(self, value):
             self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
             self.to_yield += len(audio_values) - self.to_yield - self.stride
 
-    def end(self):
+    def end(self) -> None:
         """Flushes any remaining cache and appends the stop symbol."""
         if self.token_cache is not None:
             audio_values = self.apply_delay_pattern_mask(self.token_cache)
@@ -121,18 +120,19 @@ def end(self):
 
         self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
 
-    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
+    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False) -> None:
         """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
         self.audio_queue.put(audio, timeout=self.timeout)
         if stream_end:
             self.audio_queue.put(self.stop_signal, timeout=self.timeout)
 
-    def __iter__(self):
+    def __iter__(self) -> Generator[np.ndarray, None, None]:
         return self
 
-    def __next__(self):
+    def __next__(self) -> np.ndarray:
         value = self.audio_queue.get(timeout=self.timeout)
         if not isinstance(value, np.ndarray) and value == self.stop_signal:
             raise StopIteration()
         else:
-            return value
+            return value
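
For context, a minimal usage sketch (not part of this diff) of how the annotated streamer is typically driven: `generate` runs in a background thread, calling `put()` once per decoding step and `end()` on completion, while the main thread iterates over the streamer and receives `np.ndarray` chunks until the stop signal is queued. The constructor signature `ParlerTTSStreamer(model, device=..., play_steps=...)`, the checkpoint name, and the use of a single tokenizer for both the description and the prompt are assumptions for illustration, not something introduced by this PR.

# Hypothetical usage sketch: stream audio chunks while generation runs in a thread.
from threading import Thread

import torch
from transformers import AutoTokenizer

from parler_tts import ParlerTTSForConditionalGeneration
from parler_tts.streamer import ParlerTTSStreamer

device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "parler-tts/parler-tts-mini-v1"  # assumed checkpoint name

model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

description = "A female speaker with a clear, close-sounding recording."
prompt = "Streaming audio, one chunk at a time."

# play_steps controls how many decoder steps are buffered before a chunk is emitted
# (assumed constructor signature).
streamer = ParlerTTSStreamer(model, device=device, play_steps=160)

generation_kwargs = dict(
    input_ids=tokenizer(description, return_tensors="pt").input_ids.to(device),
    prompt_input_ids=tokenizer(prompt, return_tensors="pt").input_ids.to(device),
    streamer=streamer,
)

# generate() feeds the streamer via put() and signals completion via end(),
# so the iteration below terminates when the stop signal reaches the queue.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

for audio_chunk in streamer:  # each chunk is a 1-D np.ndarray of float audio
    print(f"received chunk with {audio_chunk.shape[0]} samples")

thread.join()

The `play_steps` value trades latency for chunk size: smaller values emit audio sooner but in shorter fragments, since the streamer only decodes and yields once `token_cache` has grown by another `play_steps` tokens.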