@@ -3387,7 +3387,7 @@ def generate(
3387
3387
)
3388
3388
3389
3389
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
3390
- input_ids , decoder_delay_pattern_mask = self .decoder .build_delay_pattern_mask (
3390
+ delayed_input_ids , decoder_delay_pattern_mask = self .decoder .build_delay_pattern_mask (
3391
3391
input_ids ,
3392
3392
bos_token_id = generation_config ._bos_token_tensor ,
3393
3393
pad_token_id = generation_config ._pad_token_tensor ,
@@ -3398,7 +3398,7 @@ def generate(
3398
3398
3399
3399
# delayed_input_ids are ready to be placed on the streamer (if used)
3400
3400
if streamer is not None :
3401
- streamer .put (input_ids .cpu ())
3401
+ streamer .put (delayed_input_ids .cpu ())
3402
3402
3403
3403
# 7. determine generation mode
3404
3404
is_greedy_gen_mode = (
@@ -3419,7 +3419,7 @@ def generate(
3419
3419
encoder_input_ids = inputs_tensor ,
3420
3420
prefix_allowed_tokens_fn = None ,
3421
3421
logits_processor = logits_processor ,
3422
- device = input_ids .device ,
3422
+ device = delayed_input_ids .device ,
3423
3423
)
3424
3424
3425
3425
# 9. prepare stopping criteria
@@ -3436,7 +3436,7 @@ def generate(
3436
3436
3437
3437
# 10. run greedy search
3438
3438
outputs = self ._sample (
3439
- input_ids ,
3439
+ delayed_input_ids ,
3440
3440
logits_processor = logits_processor ,
3441
3441
stopping_criteria = stopping_criteria ,
3442
3442
generation_config = generation_config ,
@@ -3447,19 +3447,19 @@ def generate(
3447
3447
3448
3448
elif is_sample_gen_mode :
3449
3449
# 10. prepare logits warper
3450
- logits_warper = self ._get_logits_warper (generation_config , device = input_ids .device )
3450
+ logits_warper = self ._get_logits_warper (generation_config , device = delayed_input_ids .device )
3451
3451
3452
3452
# expand delayed_input_ids with `num_return_sequences` additional sequences per batch
3453
- input_ids , model_kwargs = self ._expand_inputs_for_generation (
3454
- input_ids = input_ids ,
3453
+ delayed_input_ids , model_kwargs = self ._expand_inputs_for_generation (
3454
+ input_ids = delayed_input_ids ,
3455
3455
expand_size = generation_config .num_return_sequences ,
3456
3456
is_encoder_decoder = self .config .is_encoder_decoder ,
3457
3457
** model_kwargs ,
3458
3458
)
3459
3459
3460
3460
# 11. run sample
3461
3461
outputs = self ._sample (
3462
- input_ids ,
3462
+ delayed_input_ids ,
3463
3463
logits_processor = logits_processor ,
3464
3464
logits_warper = logits_warper ,
3465
3465
stopping_criteria = stopping_criteria ,
@@ -3483,17 +3483,13 @@ def generate(
3483
3483
# Apply the pattern mask to the final ids
3484
3484
output_ids = self .decoder .apply_delay_pattern_mask (output_ids , model_kwargs ["decoder_delay_pattern_mask" ])
3485
3485
3486
- if "input_values" in model_kwargs :
3487
- # Handle input_values for voice steering
3488
- mask = output_ids
3489
- else :
3490
- # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
3491
- _ , mask = self .decoder .build_delay_pattern_mask (
3492
- input_ids ,
3493
- bos_token_id = generation_config .bos_token_id ,
3494
- pad_token_id = generation_config .pad_token_id ,
3495
- max_length = output_ids .shape [1 ],
3496
- )
3486
+ # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
3487
+ _ , mask = self .decoder .build_delay_pattern_mask (
3488
+ input_ids ,
3489
+ bos_token_id = generation_config .bos_token_id ,
3490
+ pad_token_id = generation_config .pad_token_id ,
3491
+ max_length = output_ids .shape [1 ],
3492
+ )
3497
3493
3498
3494
mask = (mask != generation_config .bos_token_id ) & (mask != generation_config .pad_token_id )
3499
3495
output_ids = output_ids [mask ].reshape (batch_size , self .decoder .num_codebooks , - 1 )
0 commit comments