Skip to content

Commit 77d10df

Browse files
authored
Fix Parler-TTS streamer (#170)
1 parent: 650288a · commit: 77d10df

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

parler_tts/streamer.py

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,11 @@ def __init__(
4242
self.audio_encoder = model.audio_encoder
4343
self.generation_config = model.generation_config
4444
self.device = device if device is not None else model.device
45+
self.use_audio_scales = model.use_audio_scales
46+
self.use_4dim_audio_codes = model.use_4dim_audio_codes
47+
self.audio_kwargs = {}
48+
if self.use_audio_scales:
49+
self.audio_kwargs["audio_scales"] = [None]
4550

4651
# variables used in the streaming process
4752
self.play_steps = play_steps
@@ -72,8 +77,10 @@ def apply_delay_pattern_mask(self, input_ids):
7277
# revert the pattern delay mask by filtering the pad token id
7378
mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
7479
input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
75-
# append the frame dimension back to the audio codes
76-
input_ids = input_ids[None, ...]
80+
81+
if self.use_4dim_audio_codes:
82+
# append the frame dimension back to the audio codes
83+
input_ids = input_ids[None, ...]
7784

7885
# send the input_ids to the correct device
7986
input_ids = input_ids.to(self.audio_encoder.device)
@@ -84,17 +91,19 @@ def apply_delay_pattern_mask(self, input_ids):
8491
or self.generation_config.eos_token_id in input_ids
8592
)
8693
if not decode_sequentially:
87-
output_values = self.audio_encoder.decode(
88-
input_ids,
89-
audio_scales=[None],
90-
)
94+
sample = self.audio_encoder.decode(
95+
audio_codes=input_ids,
96+
**self.audio_kwargs,
97+
).audio_values
98+
output_values = sample if sample.ndim == 3 else sample.unsqueeze(0)
9199
else:
92-
sample = input_ids[:, 0]
93-
sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
94-
sample = sample[:, :, sample_mask]
95-
output_values = self.audio_encoder.decode(sample[None, ...], [None])
100+
sample = input_ids[:, 0] if self.use_4dim_audio_codes else input_ids[0]
101+
sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0) if self.use_4dim_audio_codes else ((sample >= self.audio_encoder.config.codebook_size).sum(dim=0) == 0)
102+
sample = sample[:, :, sample_mask] if self.use_4dim_audio_codes else sample[:, sample_mask]
103+
sample = self.audio_encoder.decode(audio_codes=sample[None, ...], **self.audio_kwargs).audio_values
104+
output_values = sample if sample.ndim == 3 else sample.unsqueeze(0)
96105

97-
audio_values = output_values.audio_values[0, 0]
106+
audio_values = output_values[0, 0]
98107
return audio_values.cpu().float().numpy()
99108

100109
def put(self, value):

0 commit comments

Comments
 (0)