File tree Expand file tree Collapse file tree 2 files changed +14
-7
lines changed Expand file tree Collapse file tree 2 files changed +14
-7
lines changed Original file line number Diff line number Diff line change 12
12
class DACModel (PreTrainedModel ):
13
13
config_class = DACConfig
14
14
15
+ # Set main input to 'input_values' for voice steering
16
+ main_input_name = "input_values"
17
+
15
18
def __init__ (self , config ):
16
19
super ().__init__ (config )
17
20
Original file line number Diff line number Diff line change @@ -3483,13 +3483,17 @@ def generate(
3483
3483
# Apply the pattern mask to the final ids
3484
3484
output_ids = self .decoder .apply_delay_pattern_mask (output_ids , model_kwargs ["decoder_delay_pattern_mask" ])
3485
3485
3486
- # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
3487
- _ , mask = self .decoder .build_delay_pattern_mask (
3488
- input_ids ,
3489
- bos_token_id = generation_config ._bos_token_tensor ,
3490
- pad_token_id = generation_config ._pad_token_tensor ,
3491
- max_length = output_ids .shape [1 ],
3492
- )
3486
+ if "input_values" in model_kwargs :
3487
+ # Handle input_values for voice steering
3488
+ mask = output_ids
3489
+ else :
3490
+ # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
3491
+ _ , mask = self .decoder .build_delay_pattern_mask (
3492
+ input_ids ,
3493
+ bos_token_id = generation_config .bos_token_id ,
3494
+ pad_token_id = generation_config .pad_token_id ,
3495
+ max_length = output_ids .shape [1 ],
3496
+ )
3493
3497
3494
3498
mask = (mask != generation_config .bos_token_id ) & (mask != generation_config .pad_token_id )
3495
3499
output_ids = output_ids [mask ].reshape (batch_size , self .decoder .num_codebooks , - 1 )
You can’t perform that action at this time.
0 commit comments