Commit 31816bd

apresence and ylacombe authored
Prep for Voice Steering feature (#141)
* Prep for Voice Steering feature

  Credits:
  1. ylacombe - Add input_values to DACModel - dac_wrapper/modeling_dac.py - #110 (comment)
  2. stg2015 - Delay mask adjustment for input_values - modeling_parler_tts.py - #81 (comment)

* Prep for voice steering/cloning w/ fix for non-streaming generation
* Applied simpler input handling per Guppy16's suggestion
* Applied Guppy16's suggested optimization
* Applied Guppy17's suggested optimization for voice steering
* Update parler_tts/modeling_parler_tts.py

---------

Co-authored-by: apresence <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>
1 parent dcaed95 commit 31816bd

2 files changed: +14 -7 lines changed

parler_tts/dac_wrapper/modeling_dac.py

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@
 class DACModel(PreTrainedModel):
     config_class = DACConfig
 
+    # Set main input to 'input_values' for voice steering
+    main_input_name = "input_values"
+
     def __init__(self, config):
         super().__init__(config)
 
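
For context, `main_input_name` is a class attribute that generic library code reads to decide which keyword argument holds a model's primary input. The sketch below imitates that lookup in plain Python. `ToyBase`, `ToyAudioCodec`, and `pick_main_input` are hypothetical stand-ins invented for illustration; they are not part of Parler-TTS or transformers, and the real lookup is more involved.

```python
class ToyBase:
    # Hypothetical stand-in for a base model class: by default the
    # primary input is expected under the keyword "input_ids".
    main_input_name = "input_ids"

    def pick_main_input(self, **kwargs):
        """Return (name, value) of the primary input, or raise if it is missing."""
        name = self.main_input_name
        if kwargs.get(name) is None:
            raise ValueError(f"expected keyword argument {name!r}")
        return name, kwargs[name]


class ToyAudioCodec(ToyBase):
    # Overriding the attribute, as the diff does for DACModel, makes the
    # generic helper look for raw audio samples instead of token ids.
    main_input_name = "input_values"


# The codec subclass now accepts its audio under "input_values".
name, value = ToyAudioCodec().pick_main_input(input_values=[0.0, 0.1, -0.2])
```

With the override in place, passing only `input_ids` to the toy codec raises a `ValueError`, which is the kind of mismatch the one-line change is meant to avoid.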

parler_tts/modeling_parler_tts.py

Lines changed: 11 additions & 7 deletions
@@ -3483,13 +3483,17 @@ def generate(
         # Apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
 
-        # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        _, mask = self.decoder.build_delay_pattern_mask(
-            input_ids,
-            bos_token_id=generation_config._bos_token_tensor,
-            pad_token_id=generation_config._pad_token_tensor,
-            max_length=output_ids.shape[1],
-        )
+        if "input_values" in model_kwargs:
+            # Handle input_values for voice steering
+            mask = output_ids
+        else:
+            # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
+            _, mask = self.decoder.build_delay_pattern_mask(
+                input_ids,
+                bos_token_id=generation_config.bos_token_id,
+                pad_token_id=generation_config.pad_token_id,
+                max_length=output_ids.shape[1],
+            )
 
         mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
         output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)
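
In the new `input_values` branch, `mask = output_ids` means the following line keeps every position whose own value is neither BOS nor PAD, instead of consulting a rebuilt delay pattern mask. The plain-Python sketch below illustrates that filtering for a single batch item. The token ids, the toy delay layout (codebook `k` shifted right by `k` steps), and the helper names are assumptions made for illustration; the actual layout produced by `build_delay_pattern_mask` may differ.

```python
BOS, PAD = 1025, 1024  # hypothetical special token ids


def apply_toy_delay(codes, bos=BOS, pad=PAD):
    """Shift codebook k right by k steps; pad the front with BOS and the back with PAD."""
    num_codebooks = len(codes)
    return [
        [bos] * k + row + [pad] * (num_codebooks - 1 - k)
        for k, row in enumerate(codes)
    ]


def revert_delay(delayed, bos=BOS, pad=PAD):
    """Keep only entries that are neither BOS nor PAD, judged by their own value.

    This mirrors `mask = output_ids` followed by the BOS/PAD comparison and the
    reshape to (num_codebooks, -1), for one batch item.
    """
    return [[tok for tok in row if tok != bos and tok != pad] for row in delayed]


codes = [[10, 11, 12], [20, 21, 22], [30, 31, 32]]  # 3 codebooks, 3 frames
delayed = apply_toy_delay(codes)
restored = revert_delay(delayed)
```

The sketch relies on every codebook row keeping the same number of real tokens, so the filtered result stays rectangular, and on real codes never colliding with the BOS or PAD ids.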
