diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 1e368e4..bf6c444 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -433,7 +433,7 @@ def pass_through_processors(description, prompt): def apply_audio_decoder(batch): len_audio = batch.pop("len_audio") audio_decoder.to(batch["input_values"].device).eval() - if bandwidth is not None: + if bandwidth in encoder_signature: batch["bandwidth"] = bandwidth elif "num_quantizers" in encoder_signature: batch["num_quantizers"] = num_codebooks