
Commit f9cff53

rename codec context manager precision function
Signed-off-by: Edresson Casanova <edresson1@gmail.com>
1 parent f49f391 commit f9cff53

1 file changed (+20, −17 lines)
nemo/collections/speechlm2/models/duplex_ear_tts.py

Lines changed: 20 additions & 17 deletions
@@ -65,16 +65,16 @@ def maybe_to(x, dtype):
 
 
 @contextmanager
-def ensures_16_precision(mixed_dtype):
+def ensures_target_precision(target_dtype):
     """
     Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting.
     In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models.
     This context manager restores default float32 precision and runs the computation in float32 autocast context.
     """
     default_dtype = torch.get_default_dtype()
-    torch.set_default_dtype(mixed_dtype)
+    torch.set_default_dtype(target_dtype)
     try:
-        with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=mixed_dtype):
+        with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=target_dtype):
             yield
     finally:
         torch.set_default_dtype(default_dtype)
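
For reference, a minimal standalone sketch of how the renamed context manager is used: the caller passes whichever torch dtype the codec should run in, and computation inside the block executes under autocast at that dtype. The torch.nn.Linear layer below is a hypothetical stand-in for the codec encode path, not part of this diff.

import torch
from contextlib import contextmanager


@contextmanager
def ensures_target_precision(target_dtype):
    # Same logic as the function in this diff: temporarily make target_dtype the
    # default dtype and run the enclosed computation under autocast at that dtype.
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(target_dtype)
    try:
        with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=target_dtype):
            yield
    finally:
        torch.set_default_dtype(default_dtype)


# Hypothetical stand-in for a codec encode call.
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = torch.nn.Linear(16, 16).to(device)
x = torch.randn(2, 16, device=device)
with ensures_target_precision(torch.bfloat16), torch.no_grad():
    y = layer(x)
print(y.dtype)  # torch.bfloat16: the matmul ran under bf16 autocast on `device`
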
@@ -175,7 +175,7 @@ def new_forward(*args, **kwargs):
             for k, v in kwargs.items()
         }
         # with torch.cuda.amp.autocast(enabled=True, dtype=mixed_dtype):
-        with ensures_16_precision(mixed_dtype):
+        with ensures_target_precision(mixed_dtype):
             return module._original_forward(*new_args, **new_kwargs)
 
     module.forward = new_forward
@@ -315,10 +315,10 @@ def setup_rvq_audio_codec(model):
 
     Includes a workaround for PTL auto-downcasting the codec model to bf16 with bf16-true precision.
     """
-    if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == torch.float:
+    if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype:
         return  # skip if already set up and has the right dtype
 
-    with fp32_precision():
+    with ensures_target_precision(model.audio_codec_run_dtype):
         if model.cfg.get("pretrained_ae_dir", None):
             model.audio_codec = (
                 RVQVAEModel.from_pretrained(
@@ -448,6 +448,9 @@ def __init__(self, cfg: dict) -> None:
         # delete llm because we use it only to get the embbeding tokens
         del self.language_model
 
+        # get codec run precision
+        self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "bfloat16"), torch.float32)
+
         # instanciate eartts model and codec
         self._load_tts_model(self.cfg)
         self._codebook_size = self.tts_model.config.codebook_size
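
The new config key is a plain string that gets resolved onto a torch dtype via getattr, falling back to float32 when the name is not an attribute of the torch module. A small sketch of that resolution, using made-up config values for illustration:

import torch

for name in ("bfloat16", "float16", "float32", "not_a_dtype"):
    # Mirrors the lookup in this diff: unknown names fall back to torch.float32.
    dtype = getattr(torch, name, torch.float32)
    print(f"{name} -> {dtype}")
# bfloat16 -> torch.bfloat16
# float16 -> torch.float16
# float32 -> torch.float32
# not_a_dtype -> torch.float32
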
@@ -495,7 +498,7 @@ def get_codec_silence_frame_last_one(self):
         audio_len = torch.tensor([audio.size(-1)]).long()
         audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame)
 
-        with fp32_precision(), torch.no_grad():
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
             sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len)
         return sil_codes[0, -1]
 
@@ -507,7 +510,7 @@ def get_codec_silence_frame(self):
         audio_len = torch.tensor([audio.size(-1)]).long()
         audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame)
 
-        with fp32_precision(), torch.no_grad():
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
             sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1), audio_len)  # [1, T, C]
         sil_codes = sil_codes[0]  # [T, C]
 
@@ -693,10 +696,10 @@ def prepare_inputs(self, batch: dict):
         aligned_position_ids = batch["aligned_position_ids"]
 
         # extract target audio codes
-        with fp32_precision(), torch.no_grad():
-            target_audio, target_audio_lens = self.pad_audio_to_factor(
-                target_audio, target_audio_lens, self.target_samples_per_frame, 1
-            )
+        target_audio, target_audio_lens = self.pad_audio_to_factor(
+            target_audio, target_audio_lens, self.target_samples_per_frame, 1
+        )
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
             target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens)
 
         # ToDo: consider use the source audio
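
In this hunk the padding call moves out of the precision context, so only the codec encode itself runs at the target dtype. A hedged sketch of padding a batch to a multiple of the codec's samples-per-frame follows; the real pad_audio_to_factor may differ in signature and in how it updates the length tensor.

import torch
import torch.nn.functional as F


def pad_to_factor(audio: torch.Tensor, audio_lens: torch.Tensor, factor: int):
    # Hypothetical helper: right-pad the time axis so it is a multiple of `factor`.
    remainder = audio.size(-1) % factor
    if remainder:
        audio = F.pad(audio, (0, factor - remainder))
    return audio, audio_lens


audio = torch.randn(2, 22050)             # [B, T], assumed sample count
audio_lens = torch.tensor([22050, 18000])
audio, audio_lens = pad_to_factor(audio, audio_lens, factor=1024)
print(audio.shape)  # torch.Size([2, 22528]) -- time axis now divisible by 1024
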
@@ -708,8 +711,8 @@ def prepare_inputs(self, batch: dict):
         source_audio_lens = (source_audio_lens * (self.target_sample_rate/self.source_sample_rate)).to(lengths.dtype)
         # ToDo: Add a transformer encoder to help the model to better extract contextual information, replace the code bellow with it
         # extract embedding for context audios
-        with fp32_precision(), torch.no_grad():
-            source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1)
+        source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1)
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
             source_codes, source_codes_lens = self.audio_codec.encode(
                 source_audio.unsqueeze(1), source_audio_lens
             )
@@ -916,7 +919,7 @@ def get_teacher_force_inference_audio(self, batch, guidance_enabled=True):
         tf_audio_codes_pred = replace_control_speech_codes(
             tf_audio_codes_pred, self._control_codes, self.codec_silence_tokens
         )
-        with fp32_precision(), torch.no_grad():
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
            audio_pred, audio_len = self.audio_codec.decode(tf_audio_codes_pred, inputs["output_lens"])
 
         return audio_pred.squeeze(1), audio_len
@@ -1429,7 +1432,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None,
         target_audio_len = torch.tensor(
             [target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device
         )
-        with fp32_precision(), torch.no_grad():
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
             code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len)
 
         # get context hidden
@@ -1783,7 +1786,7 @@ def offline_inference(
         gen_audio_codes = replace_control_speech_codes(
             gen_audio_codes, self._control_codes, self.codec_silence_tokens
         )
-        with fp32_precision(), torch.no_grad():
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
             audio_pred, audio_pred_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens)
 
         return audio_pred.squeeze(1), audio_pred_len
