@@ -3387,7 +3387,7 @@ def generate(
33873387 )
33883388
33893389 # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
3390- input_ids , decoder_delay_pattern_mask = self .decoder .build_delay_pattern_mask (
3390+ delayed_input_ids , decoder_delay_pattern_mask = self .decoder .build_delay_pattern_mask (
33913391 input_ids ,
33923392 bos_token_id = generation_config ._bos_token_tensor ,
33933393 pad_token_id = generation_config ._pad_token_tensor ,
@@ -3398,7 +3398,7 @@ def generate(
33983398
33993399 # input_ids are ready to be placed on the streamer (if used)
34003400 if streamer is not None :
3401- streamer .put (input_ids .cpu ())
3401+ streamer .put (delayed_input_ids .cpu ())
34023402
34033403 # 7. determine generation mode
34043404 is_greedy_gen_mode = (
@@ -3419,7 +3419,7 @@ def generate(
34193419 encoder_input_ids = inputs_tensor ,
34203420 prefix_allowed_tokens_fn = None ,
34213421 logits_processor = logits_processor ,
3422- device = input_ids .device ,
3422+ device = delayed_input_ids .device ,
34233423 )
34243424
34253425 # 9. prepare stopping criteria
@@ -3436,7 +3436,7 @@ def generate(
34363436
34373437 # 10. run greedy search
34383438 outputs = self ._sample (
3439- input_ids ,
3439+ delayed_input_ids ,
34403440 logits_processor = logits_processor ,
34413441 stopping_criteria = stopping_criteria ,
34423442 generation_config = generation_config ,
@@ -3447,19 +3447,19 @@ def generate(
34473447
34483448 elif is_sample_gen_mode :
34493449 # 10. prepare logits warper
3450- logits_warper = self ._get_logits_warper (generation_config , device = input_ids .device )
3450+ logits_warper = self ._get_logits_warper (generation_config , device = delayed_input_ids .device )
34513451
34523452 # expand input_ids with `num_return_sequences` additional sequences per batch
3453- input_ids , model_kwargs = self ._expand_inputs_for_generation (
3454- input_ids = input_ids ,
3453+ delayed_input_ids , model_kwargs = self ._expand_inputs_for_generation (
3454+ input_ids = delayed_input_ids ,
34553455 expand_size = generation_config .num_return_sequences ,
34563456 is_encoder_decoder = self .config .is_encoder_decoder ,
34573457 ** model_kwargs ,
34583458 )
34593459
34603460 # 11. run sample
34613461 outputs = self ._sample (
3462- input_ids ,
3462+ delayed_input_ids ,
34633463 logits_processor = logits_processor ,
34643464 logits_warper = logits_warper ,
34653465 stopping_criteria = stopping_criteria ,
@@ -3483,17 +3483,13 @@ def generate(
34833483 # Apply the pattern mask to the final ids
34843484 output_ids = self .decoder .apply_delay_pattern_mask (output_ids , model_kwargs ["decoder_delay_pattern_mask" ])
34853485
3486- if "input_values" in model_kwargs :
3487- # Handle input_values for voice steering
3488- mask = output_ids
3489- else :
3490- # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
3491- _ , mask = self .decoder .build_delay_pattern_mask (
3492- input_ids ,
3493- bos_token_id = generation_config .bos_token_id ,
3494- pad_token_id = generation_config .pad_token_id ,
3495- max_length = output_ids .shape [1 ],
3496- )
3486+ # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
3487+ _ , mask = self .decoder .build_delay_pattern_mask (
3488+ input_ids ,
3489+ bos_token_id = generation_config .bos_token_id ,
3490+ pad_token_id = generation_config .pad_token_id ,
3491+ max_length = output_ids .shape [1 ],
3492+ )
34973493
34983494 mask = (mask != generation_config .bos_token_id ) & (mask != generation_config .pad_token_id )
34993495 output_ids = output_ids [mask ].reshape (batch_size , self .decoder .num_codebooks , - 1 )
0 commit comments