@@ -65,16 +65,16 @@ def maybe_to(x, dtype):
6565
6666
6767@contextmanager
68- def ensures_16_precision ( mixed_dtype ):
68+ def ensures_target_precision ( target_dtype ):
6969 """
7070 Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting.
7171 In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models.
7272 This context manager restores default float32 precision and runs the computation in float32 autocast context.
7373 """
7474 default_dtype = torch .get_default_dtype ()
75- torch .set_default_dtype (mixed_dtype )
75+ torch .set_default_dtype (target_dtype )
7676 try :
77- with torch .amp .autocast (device_type = "cuda" if torch .cuda .is_available () else "cpu" , dtype = mixed_dtype ):
77+ with torch .amp .autocast (device_type = "cuda" if torch .cuda .is_available () else "cpu" , dtype = target_dtype ):
7878 yield
7979 finally :
8080 torch .set_default_dtype (default_dtype )
@@ -175,7 +175,7 @@ def new_forward(*args, **kwargs):
175175 for k , v in kwargs .items ()
176176 }
177177 # with torch.cuda.amp.autocast(enabled=True, dtype=mixed_dtype):
178- with ensures_16_precision (mixed_dtype ):
178+ with ensures_target_precision (mixed_dtype ):
179179 return module ._original_forward (* new_args , ** new_kwargs )
180180
181181 module .forward = new_forward
@@ -315,10 +315,10 @@ def setup_rvq_audio_codec(model):
315315
316316 Includes a workaround for PTL auto-downcasting the codec model to bf16 with bf16-true precision.
317317 """
318- if hasattr (model , "audio_codec" ) and next (model .audio_codec .parameters ()).dtype == torch . float :
318+ if hasattr (model , "audio_codec" ) and next (model .audio_codec .parameters ()).dtype == model . audio_codec_run_dtype :
319319 return # skip if already set up and has the right dtype
320320
321- with fp32_precision ( ):
321+ with ensures_target_precision ( model . audio_codec_run_dtype ):
322322 if model .cfg .get ("pretrained_ae_dir" , None ):
323323 model .audio_codec = (
324324 RVQVAEModel .from_pretrained (
@@ -448,6 +448,9 @@ def __init__(self, cfg: dict) -> None:
448448 # delete llm because we use it only to get the embbeding tokens
449449 del self .language_model
450450
451+ # get codec run precision
452+ self .audio_codec_run_dtype = getattr (torch , self .cfg .get ("audio_codec_run_dtype" , "bfloat16" ), torch .float32 )
453+
451454 # instanciate eartts model and codec
452455 self ._load_tts_model (self .cfg )
453456 self ._codebook_size = self .tts_model .config .codebook_size
@@ -495,7 +498,7 @@ def get_codec_silence_frame_last_one(self):
495498 audio_len = torch .tensor ([audio .size (- 1 )]).long ()
496499 audio , audio_len = self .pad_audio_to_factor (audio , audio_len , self .target_samples_per_frame )
497500
498- with fp32_precision ( ), torch .no_grad ():
501+ with ensures_target_precision ( self . audio_codec_run_dtype ), torch .no_grad ():
499502 sil_codes , sil_codes_lens = self .audio_codec .encode (audio .unsqueeze (1 ), audio_len )
500503 return sil_codes [0 , - 1 ]
501504
@@ -507,7 +510,7 @@ def get_codec_silence_frame(self):
507510 audio_len = torch .tensor ([audio .size (- 1 )]).long ()
508511 audio , audio_len = self .pad_audio_to_factor (audio , audio_len , self .target_samples_per_frame )
509512
510- with fp32_precision ( ), torch .no_grad ():
513+ with ensures_target_precision ( self . audio_codec_run_dtype ), torch .no_grad ():
511514 sil_codes , _ = self .audio_codec .encode (audio .unsqueeze (1 ), audio_len ) # [1, T, C]
512515 sil_codes = sil_codes [0 ] # [T, C]
513516
@@ -693,10 +696,10 @@ def prepare_inputs(self, batch: dict):
693696 aligned_position_ids = batch ["aligned_position_ids" ]
694697
695698 # extract target audio codes
696- with fp32_precision (), torch . no_grad ():
697- target_audio , target_audio_lens = self .pad_audio_to_factor (
698- target_audio , target_audio_lens , self . target_samples_per_frame , 1
699- )
699+ target_audio , target_audio_lens = self . pad_audio_to_factor (
700+ target_audio , target_audio_lens , self .target_samples_per_frame , 1
701+ )
702+ with ensures_target_precision ( self . audio_codec_run_dtype ), torch . no_grad ():
700703 target_codes , target_codes_lens = self .audio_codec .encode (target_audio .unsqueeze (1 ), target_audio_lens )
701704
702705 # ToDo: consider use the source audio
@@ -708,8 +711,8 @@ def prepare_inputs(self, batch: dict):
708711 source_audio_lens = (source_audio_lens * (self.target_sample_rate/self.source_sample_rate)).to(lengths.dtype)
709712 # ToDo: Add a transformer encoder to help the model to better extract contextual information, replace the code bellow with it
710713 # extract embedding for context audios
711- with fp32_precision(), torch.no_grad():
712- source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1)
714+ source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1)
715+ with ensures_target_precision( self.audio_codec_run_dtype), torch.no_grad():
713716 source_codes, source_codes_lens = self.audio_codec.encode(
714717 source_audio.unsqueeze(1), source_audio_lens
715718 )
@@ -916,7 +919,7 @@ def get_teacher_force_inference_audio(self, batch, guidance_enabled=True):
916919 tf_audio_codes_pred = replace_control_speech_codes (
917920 tf_audio_codes_pred , self ._control_codes , self .codec_silence_tokens
918921 )
919- with fp32_precision ( ), torch .no_grad ():
922+ with ensures_target_precision ( self . audio_codec_run_dtype ), torch .no_grad ():
920923 audio_pred , audio_len = self .audio_codec .decode (tf_audio_codes_pred , inputs ["output_lens" ])
921924
922925 return audio_pred .squeeze (1 ), audio_len
@@ -1429,7 +1432,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None,
14291432 target_audio_len = torch .tensor (
14301433 [target_audio .size (- 1 )] * target_audio .size (0 ), dtype = torch .long , device = self .device
14311434 )
1432- with fp32_precision ( ), torch .no_grad ():
1435+ with ensures_target_precision ( self . audio_codec_run_dtype ), torch .no_grad ():
14331436 code , _ = self .audio_codec .encode (target_audio .unsqueeze (1 ), target_audio_len )
14341437
14351438 # get context hidden
@@ -1783,7 +1786,7 @@ def offline_inference(
17831786 gen_audio_codes = replace_control_speech_codes (
17841787 gen_audio_codes , self ._control_codes , self .codec_silence_tokens
17851788 )
1786- with fp32_precision ( ), torch .no_grad ():
1789+ with ensures_target_precision ( self . audio_codec_run_dtype ), torch .no_grad ():
17871790 audio_pred , audio_pred_len = self .audio_codec .decode (gen_audio_codes , gen_audio_codes_lens )
17881791
17891792 return audio_pred .squeeze (1 ), audio_pred_len
0 commit comments