Skip to content

Commit 0bab56b

Browse files
committed
fix name and add min new tokens
1 parent a0bc9e7 commit 0bab56b

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

training/run_parler_tts_training.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def main():
126126
"adam_beta2": training_args.adam_beta2,
127127
"temperature": model_args.temperature,
128128
},
129-
init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else None,
129+
init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else {},
130130
)
131131

132132
# Detecting last checkpoint and eventually continue from last checkpoint
@@ -750,6 +750,10 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
750750
"do_sample": model_args.do_sample,
751751
"temperature": model_args.temperature,
752752
"max_length": model_args.max_length,
753+
# Because of the delayed pattern mask, generation might stop earlier because of unexpected behaviour
754+
# on the first tokens of the codebooks that are delayed.
755+
# This fix the issue.
756+
"min_new_tokens": num_codebooks + 1,
753757
}
754758

755759
# Define gradient update step fn

0 commit comments

Comments
 (0)