Skip to content

Commit 5d0aca9

Browse files
authored
Fix how delayed pattern mask is applied (#147)
* Fix delayed pattern mask
* Delete parler_tts/modeling_parler_tts.py~ (stray editor backup file left over from the upstream update)
1 parent 7f97328 commit 5d0aca9

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

parler_tts/modeling_parler_tts.py

Lines changed: 15 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -3387,7 +3387,7 @@ def generate(
33873387
)
33883388

33893389
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
3390-
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
3390+
delayed_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
33913391
input_ids,
33923392
bos_token_id=generation_config._bos_token_tensor,
33933393
pad_token_id=generation_config._pad_token_tensor,
@@ -3398,7 +3398,7 @@ def generate(
33983398

33993399
# input_ids are ready to be placed on the streamer (if used)
34003400
if streamer is not None:
3401-
streamer.put(input_ids.cpu())
3401+
streamer.put(delayed_input_ids.cpu())
34023402

34033403
# 7. determine generation mode
34043404
is_greedy_gen_mode = (
@@ -3419,7 +3419,7 @@ def generate(
34193419
encoder_input_ids=inputs_tensor,
34203420
prefix_allowed_tokens_fn=None,
34213421
logits_processor=logits_processor,
3422-
device=input_ids.device,
3422+
device=delayed_input_ids.device,
34233423
)
34243424

34253425
# 9. prepare stopping criteria
@@ -3436,7 +3436,7 @@ def generate(
34363436

34373437
# 10. run greedy search
34383438
outputs = self._sample(
3439-
input_ids,
3439+
delayed_input_ids,
34403440
logits_processor=logits_processor,
34413441
stopping_criteria=stopping_criteria,
34423442
generation_config=generation_config,
@@ -3447,19 +3447,19 @@ def generate(
34473447

34483448
elif is_sample_gen_mode:
34493449
# 10. prepare logits warper
3450-
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
3450+
logits_warper = self._get_logits_warper(generation_config, device=delayed_input_ids.device)
34513451

34523452
# expand input_ids with `num_return_sequences` additional sequences per batch
3453-
input_ids, model_kwargs = self._expand_inputs_for_generation(
3454-
input_ids=input_ids,
3453+
delayed_input_ids, model_kwargs = self._expand_inputs_for_generation(
3454+
input_ids=delayed_input_ids,
34553455
expand_size=generation_config.num_return_sequences,
34563456
is_encoder_decoder=self.config.is_encoder_decoder,
34573457
**model_kwargs,
34583458
)
34593459

34603460
# 11. run sample
34613461
outputs = self._sample(
3462-
input_ids,
3462+
delayed_input_ids,
34633463
logits_processor=logits_processor,
34643464
logits_warper=logits_warper,
34653465
stopping_criteria=stopping_criteria,
@@ -3483,17 +3483,13 @@ def generate(
34833483
# Apply the pattern mask to the final ids
34843484
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
34853485

3486-
if "input_values" in model_kwargs:
3487-
# Handle input_values for voice steering
3488-
mask = output_ids
3489-
else:
3490-
# Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
3491-
_, mask = self.decoder.build_delay_pattern_mask(
3492-
input_ids,
3493-
bos_token_id=generation_config.bos_token_id,
3494-
pad_token_id=generation_config.pad_token_id,
3495-
max_length=output_ids.shape[1],
3496-
)
3486+
# Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
3487+
_, mask = self.decoder.build_delay_pattern_mask(
3488+
input_ids,
3489+
bos_token_id=generation_config.bos_token_id,
3490+
pad_token_id=generation_config.pad_token_id,
3491+
max_length=output_ids.shape[1],
3492+
)
34973493

34983494
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
34993495
output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)

0 commit comments

Comments (0)