From d48bd1c48a4d7b4909c4adcc9c4311d41a43fb22 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 11 Nov 2025 05:01:08 -0800 Subject: [PATCH 001/102] Add Duplex EARTTS modules Signed-off-by: Edresson Casanova --- examples/speechlm2/duplex_eartts_train.py | 61 + .../speechlm2/duplex_eartts_train_infer.py | 62 + nemo/collections/common/data/lhotse/cutset.py | 125 + .../speechlm2/data/duplex_ear_tts_dataset.py | 547 ++++ .../speechlm2/models/duplex_ear_tts.py | 2761 +++++++++++++++++ .../speechlm2/modules/ear_tts_commons.py | 484 +++ .../speechlm2/modules/rvq_ear_tts_model.py | 1775 +++++++++++ .../speechlm2/parts/metrics/__init__.py | 7 +- .../speechlm2/parts/metrics/asr_bleu.py | 7 +- .../speechlm2/parts/metrics/bleu.py | 5 +- .../parts/metrics/intelligibility.py | 110 + .../speechlm2/parts/metrics/results_logger.py | 170 + .../speechlm2/parts/metrics/secs.py | 77 + .../speechlm2/parts/metrics/token_accuracy.py | 88 + .../speechlm2/parts/metrics/wer.py | 65 - 15 files changed, 6271 insertions(+), 73 deletions(-) create mode 100644 examples/speechlm2/duplex_eartts_train.py create mode 100644 examples/speechlm2/duplex_eartts_train_infer.py create mode 100644 nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py create mode 100644 nemo/collections/speechlm2/models/duplex_ear_tts.py create mode 100644 nemo/collections/speechlm2/modules/ear_tts_commons.py create mode 100644 nemo/collections/speechlm2/modules/rvq_ear_tts_model.py create mode 100644 nemo/collections/speechlm2/parts/metrics/intelligibility.py create mode 100644 nemo/collections/speechlm2/parts/metrics/results_logger.py create mode 100644 nemo/collections/speechlm2/parts/metrics/secs.py create mode 100644 nemo/collections/speechlm2/parts/metrics/token_accuracy.py delete mode 100644 nemo/collections/speechlm2/parts/metrics/wer.py diff --git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py new file mode 100644 index 000000000000..ea9b4571d34f --- /dev/null +++ b/examples/speechlm2/duplex_eartts_train.py @@ -0,0 +1,61 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
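+# Example launch (illustrative sketch): the script reads LOCAL_RANK and initializes NCCL, so it is
+# typically started through torchrun; the GPU count below is an assumption, and the config name is
+# the default already used by the @hydra_runner decorator.
+#
+#   torchrun --nproc_per_node=8 examples/speechlm2/duplex_eartts_train.py \
+#       --config-path=conf --config-name=s2s_duplex_speech_decoder
+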
+import os
+
+import torch
+from lightning.pytorch import Trainer
+from omegaconf import OmegaConf
+
+from nemo.collections.speechlm2 import DataModule, DuplexEARTTSDataset
+
+from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS
+from nemo.core.config import hydra_runner
+from nemo.utils.exp_manager import exp_manager
+from nemo.utils.trainer_utils import resolve_trainer_cfg
+
+torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+
+@hydra_runner(config_path="conf", config_name="s2s_duplex_speech_decoder")
+def train(cfg):
+    OmegaConf.resolve(cfg)
+    torch.distributed.init_process_group(backend="nccl")
+    torch.set_float32_matmul_precision("medium")
+    torch.backends.cudnn.allow_tf32 = True
+    trainer = Trainer(**resolve_trainer_cfg(cfg.trainer))
+    log_dir = exp_manager(trainer, cfg.get("exp_manager", None))
+    OmegaConf.save(cfg, log_dir / "exp_config.yaml")
+
+    with trainer.init_module():
+        model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True))
+
+    dataset = DuplexEARTTSDataset(
+        tokenizer=model.tokenizer,
+        frame_length=cfg.data.frame_length,
+        source_sample_rate=cfg.data.source_sample_rate,
+        target_sample_rate=cfg.data.target_sample_rate,
+        input_roles=cfg.data.input_roles,
+        output_roles=cfg.data.output_roles,
+        add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True),
+        add_audio_prompt_after_description=cfg.data.add_audio_prompt_after_description,
+        audio_prompt_duration=cfg.data.audio_prompt_duration,
+        num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2),
+    )
+    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)
+
+    trainer.fit(model, datamodule)
+
+
+if __name__ == "__main__":
+    train()
diff --git a/examples/speechlm2/duplex_eartts_train_infer.py b/examples/speechlm2/duplex_eartts_train_infer.py
new file mode 100644
index 000000000000..15ec634a6fbf
--- /dev/null
+++ b/examples/speechlm2/duplex_eartts_train_infer.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
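+# Example launch (illustrative sketch; the GPU count is an assumption): this entry point builds the
+# same model and dataset as training but only runs validation via trainer.validate.
+#
+#   torchrun --nproc_per_node=1 examples/speechlm2/duplex_eartts_train_infer.py
+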
+import os
+
+import torch
+from lightning.pytorch import Trainer
+from omegaconf import OmegaConf
+
+from nemo.collections.speechlm2 import DataModule, DuplexEARTTSDataset
+
+from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS
+from nemo.core.config import hydra_runner
+from nemo.utils.exp_manager import exp_manager
+from nemo.utils.trainer_utils import resolve_trainer_cfg
+
+torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+@hydra_runner(config_path="conf", config_name="s2s_duplex_speech_decoder")
+def inference(cfg):
+    OmegaConf.resolve(cfg)
+    torch.distributed.init_process_group(backend="nccl")
+    torch.set_float32_matmul_precision("medium")
+    torch.backends.cudnn.allow_tf32 = True
+    trainer = Trainer(**resolve_trainer_cfg(cfg.trainer))
+    log_dir = exp_manager(trainer, cfg.get("exp_manager", None))
+    OmegaConf.save(cfg, log_dir / "exp_config.yaml")
+
+    with trainer.init_module():
+        model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True))
+
+    model.eval()
+
+    dataset = DuplexEARTTSDataset(
+        tokenizer=model.tokenizer,
+        frame_length=cfg.data.frame_length,
+        source_sample_rate=cfg.data.source_sample_rate,
+        target_sample_rate=cfg.data.target_sample_rate,
+        input_roles=cfg.data.input_roles,
+        output_roles=cfg.data.output_roles,
+        add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True),
+        add_audio_prompt_after_description=cfg.data.add_audio_prompt_after_description,
+        audio_prompt_duration=cfg.data.audio_prompt_duration,
+        num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2),
+    )
+    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)
+
+    trainer.validate(model, datamodule)
+
+
+if __name__ == "__main__":
+    inference()
diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py
index 4ba644c97002..76825e612ba7 100644
--- a/nemo/collections/common/data/lhotse/cutset.py
+++ b/nemo/collections/common/data/lhotse/cutset.py
@@ -535,6 +535,131 @@ def cut_to_conversation(
     )
 
 
+@data_type_parser(["lhotse_magpietts_data_as_continuation"])
+def read_lhotse_magpietts_data_as_continuation(config) -> tuple[CutSet, bool]:
+    def convert_lhotse_magpietts_data_as_cont(cut):
+        # create a copy of the agent supervision and keep the original target audio duration
+        orig_agent_sup = fastcopy(cut.supervisions[0])
+        target_audio_org_dur = cut.target_audio.duration
+
+        # Resample both to match sample_rate
+        cut.target_audio = cut.target_audio.resample(sample_rate)
+        cut.context_audio = cut.context_audio.resample(sample_rate)
+
+        # Compute total duration
+        total_duration = cut.target_audio.duration
+
+        # Convert target_audio (Recording) into MonoCut so we can pad it
+        cut_target = MonoCut(
+            id=f"{cut.id}_target",
+            start=0.0,
+            duration=cut.target_audio.duration,
+            channel=0,
+            recording=cut.target_audio,
+            supervisions=[],
+        )
+
+        # create silence audio
+        num_samples = int(total_duration * sample_rate)
+        zero_audio = np.zeros((1, num_samples), dtype=np.float32)
+        source_recording = create_recording_from_array(
+            zero_audio,
+            sampling_rate=sample_rate,
+            recording_id=f"{cut.id}_source",
+        )
+
+        cut_source = MonoCut(
+            id=f"{cut.id}_source",
+            start=0.0,
+            duration=cut.target_audio.duration,
+            channel=0,
+            recording=source_recording,
+            supervisions=[],
+        )
+
+        # Save both to memory
+        cut_source = cut_source.move_to_memory(audio_format='wav')
+        cut_target = cut_target.move_to_memory(audio_format='wav')
+
+        # the user turn starts on zeros with dummy text
+        user_sup = fastcopy(
+            orig_agent_sup,
+            start=0.0,
+            duration=0.08,  # keep only one frame for the user
+            speaker="user",
+            text="dummy text",
+        )
+        # the agent starts when the user turn finishes and spans the original target audio duration
+        agent_sup = fastcopy(
+            orig_agent_sup,
+            start=0.0,
+            duration=target_audio_org_dur - 0.08,
+            speaker="agent",
+        )
+
+        # Add extra silence at the end of the audio to force the model to produce silence once it receives zeros and the text has been fully processed
+        if ADD_EXTRA_END_SIL:
+            sil_duration = random.uniform(*SILENCE_RANGE)
+            # pad audios
+            cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right")
+            cut_source = cut_source.pad(duration=total_duration + sil_duration, direction="right")
+            # Save both to memory
+            cut_source = cut_source.to_mono().move_to_memory(audio_format='wav')
+            cut_target = cut_target.to_mono().move_to_memory(audio_format='wav')
+            # 1.0 second is added so this dataset never emits a text EOS, avoiding conflicts with S2S where text EOS acts as the interruption token in duplex mode
+            agent_sup.duration = agent_sup.duration + sil_duration + 1.0
+            user_sup.duration = user_sup.duration + sil_duration
+
+        # Assemble final cut
+        cut_source.supervisions = [user_sup, agent_sup]
+        cut_source.recording = cut_source.recording  # the source (input) channel keeps the silent recording
+        cut_source.target_audio = cut_target.recording
+        cut_source.duration = cut_target.duration
+        cut_source.formatter = "lhotse_magpietts_data_as_continuation"
+        cut_source.context_audio = cut.context_audio
+        return cut_source
+
+    def filter_cer(example):
+        if isinstance(example, Cut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("cer"):
+            return example.supervisions[0].cer <= MAX_CER
+        else:
+            return True
+
+    def filter_val_flag(example):
+        if isinstance(example, Cut) and example.has_custom("validation_status") and example.validation_status != KEEP_FLAG:
+            return False
+        else:
+            return True
+
+    def filter_secs(example):
+        if isinstance(example, Cut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("context_speaker_similarity"):
+            return example.supervisions[0].context_speaker_similarity >= MIN_SECS
+        else:
+            return True
+
+    # load lhotse cuts
+    cuts, is_tarred = read_cutset_from_config(config)
+
+    ADD_EXTRA_END_SIL = config.get("add_extra_end_silence", False)
+    SILENCE_RANGE = config.get("extra_end_silence_range", [0.5, 6.0])
+
+    # sample rate used when resampling and assembling the cuts
+    sample_rate = 22050
+
+    # filter dataset by CER
+    MAX_CER = config.get("max_cer", 0.03)
+    cuts = cuts.filter(filter_cer)
+    # filter invalid samples
+    KEEP_FLAG = "pass"
+    cuts = cuts.filter(filter_val_flag)
+    # filter based on context speaker similarity
+    MIN_SECS = config.get("min_context_speaker_similarity", 0.6)
+    cuts = cuts.filter(filter_secs)
+
+    # convert cuts
+    cuts = cuts.map(convert_lhotse_magpietts_data_as_cont)
+    return cuts, is_tarred
+
+
 @data_type_parser(["lhotse_as_conversation"])
 def read_lhotse_as_conversation(config) -> tuple[CutSet, bool]:
     cuts, is_tarred = read_cutset_from_config(config)
diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py
new file mode 100644
index 000000000000..cf511539bd9d
--- /dev/null
+++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py
@@ -0,0 +1,547 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +import torch +import torch.nn.functional as F +import torch.utils.data +import torchaudio +import random + +from lhotse import CutSet, MonoCut, Recording, Seconds, SupervisionSegment, compute_num_frames +from lhotse.cut import Cut +from lhotse.dataset.collation import collate_audio, collate_vectors +from lhotse.utils import ifnone + +from nemo.collections.common.tokenizers import TokenizerSpec +from nemo.collections.speechlm2.data.utils import get_pad_id +from nemo.utils import logging +from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER + + +def sample_audio_segments_repeat( + prompt_audio: torch.Tensor, + prompt_audio_lens: torch.Tensor, + n_sample: int, + sample: bool = True, +) -> torch.Tensor: + """ + Extract audio segments of length n_sample. + If sample=True: randomly sample segments (repeating if shorter). + If sample=False: always take from the beginning (repeating if shorter). + + Args: + prompt_audio: Tensor [B, T] + prompt_audio_lens: Tensor [B] with valid lengths + n_sample: int, target length per segment + sample: bool, whether to randomly sample (True) or take first seconds (False) + + Returns: + Tensor [B, n_sample] + """ + B, T = prompt_audio.shape + device = prompt_audio.device + out = torch.zeros(B, n_sample, device=device, dtype=prompt_audio.dtype) + + for b in range(B): + length = min(prompt_audio_lens[b].item(), T) + + # Case: empty audio + if length <= 0: + continue + + if length >= n_sample: + if sample: + # Random start (safe bounds) + max_start = max(1, length - n_sample + 1) + start = torch.randint(0, max_start, (1,), device=device).item() + else: + # Deterministic: take from start + start = 0 + out[b] = prompt_audio[b, start:start + n_sample] + + else: + # Audio shorter than target → repeat + if sample: + # Random start position + start = torch.randint(0, length, (1,), device=device).item() + else: + start = 0 + segment = prompt_audio[b, start:length] + + repeat_times = (n_sample + (length - start) - 1) // (length - start) + repeated = segment.repeat(repeat_times)[:n_sample] + out[b] = repeated + + return out + + +def get_mask_from_lengths( + lengths: torch.Tensor = None, + x: torch.Tensor = None, +) -> torch.Tensor: + """Constructs binary mask from a 1D torch tensor of input lengths + Args: + lengths: torch.tensor (torch.tensor): 1D tensor with lengths + x: torch.tensor = tensor to be used on, last dimension is for mask + Returns: + mask (torch.tensor): num_sequences x max_length binary tensor + """ + if lengths is None: + assert x is not None + return torch.ones(x.shape[-1], dtype=torch.bool, device=x.device) + else: + if x is None: + max_len = torch.max(lengths) + else: + max_len = x.shape[-1] + + ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) + mask = ids < lengths.unsqueeze(1) + return mask + + +class DuplexEARTTSDataset(torch.utils.data.Dataset): + """ + A dataset for duplex speech-to-speech models that handles bidirectional conversations. + + This dataset processes Lhotse CutSet objects containing recordings with supervision segments + from different speakers (roles). 
It creates aligned representations of audio and text for + both source (input) and target (output) channels, preserving temporal alignment between + audio frames and text tokens. + + Args: + tokenizer (TokenizerSpec): + Tokenizer for converting text to token IDs and vice versa. Must support BOS and EOS tokens. + It's expected to support PAD token as well, otherwise we will use 0 as the pad token + and emit a warning. + + frame_length (Seconds): + Duration of a single frame in seconds. Used to calculate frame positions for token alignment. + + source_sample_rate (int): + Sample rate for source audio (e.g., 16000 Hz). + + target_sample_rate (int): + Sample rate for target audio (e.g., 22050 Hz). + + input_roles (list[str], optional): + List of speaker roles (cut.supervisions[:].speaker) to consider as inputs. Defaults to ["user"]. + + output_roles (list[str], optional): + List of speaker roles (cut.supervisions[:].speaker) to consider as outputs. Defaults to ["agent"]. + + Returns: + A dictionary with the following keys: + - source_audio: Tensor of source waveform samples [B, T] + - source_audio_lens: Tensor of source audio lengths [B] + - target_audio: Tensor of target waveform samples [B, T] + - target_audio_lens: Tensor of target audio lengths [B] + - input_text_tokens: Tensor of target text tokens [B, T], with special tokens (BOS/EOS/PAD) + at positions aligned with audio frames + - target_token_lens: Tensor of target token sequence lengths [B] + - source_tokens: Tensor of source text tokens [B, T], with special tokens (BOS/EOS/PAD) + at positions aligned with audio frames + - source_token_lens: Tensor of source token sequence lengths [B] + - target_texts: List of full target texts joined from output_roles supervisions [B] + + Notes: + - The dataset ensures frame-level alignment between audio and text by inserting tokens at + specific frame positions based on the timing of supervision segments. + - PAD tokens (typically 0) are used to fill gaps where there's no text. + - BOS tokens mark the beginning of each speech segment. + - EOS tokens mark the end of each speech segment. + - Text tokens from each speaker are placed at frame positions corresponding to their + timestamp in the original recording, preserving the temporal relationship. + This is a segment-level alignment only, not word-level alignment. + """ + + def __init__( + self, + tokenizer, + frame_length: Seconds, + source_sample_rate: int, + target_sample_rate: int, + input_roles: list[str] = None, + output_roles: list[str] = None, + add_description: bool = True, + p_drop_description: float = 0.0, + add_text_bos_and_eos_in_each_turn: bool = False, + add_audio_prompt_after_description: bool = False, + audio_prompt_duration: float = 3.0, + num_delay_speech_tokens: int = 0 + ): + self.tokenizer = tokenizer + self.frame_length = frame_length + self.source_sample_rate = source_sample_rate + self.target_sample_rate = target_sample_rate + self.input_roles = set(ifnone(input_roles, ["user"])) + self.output_roles = set(ifnone(output_roles, ["agent"])) + self.add_description = add_description + self.p_drop_description = p_drop_description + self.add_text_bos_and_eos_in_each_turn = add_text_bos_and_eos_in_each_turn + self.add_audio_prompt_after_description = add_audio_prompt_after_description + self.audio_prompt_duration = audio_prompt_duration + self.num_delay_speech_tokens = num_delay_speech_tokens + + assert tokenizer.bos is not None, "BOS support in the tokenizer is required for S2S models." 
+ assert tokenizer.eos is not None, "EOS support in the tokenizer is required for S2S models." + + def generate_prompt_description(self, device): + messages = [] + if random.random() > self.p_drop_description: + # ToDo: add extra system prompts + system_prompt = ( + "You engage in conversation with the user. When delivering your response as speech, " + "if the user provides a description such as emotions, scene details, " + "or speaker style, you adjust your speaking style accordingly when delivering the response. " + "However, this description should influence only the delivery of your response, not its content. " + "Your response should remain independent of any stylistic instructions." + ) + messages.append({"role": "system", "content": system_prompt}) + else: + messages.append({"role": "system", "content": ""}) + + # given that descriptions are currently not supported, only added the user prompt + # ToDo: add extra user prompts or completly remove it as it is not used in NanoV2 + user_prompt = "Can you tell me something interesting?" + messages.append({"role": "user", "content": user_prompt}) + messages.append({"role": "assistant", "content": SCRIPT_PLACEHOLDER}) + non_script_list = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ).split(SCRIPT_PLACEHOLDER + self.tokenizer.eos_token)[:-1] + + input_ids = [] + for i, non_script in enumerate(non_script_list): + desc_ids = self.tokenizer.text_to_ids(non_script) + input_ids.extend(desc_ids) + + input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).view(1, -1) + return input_ids + + def __getitem__(self, cuts: CutSet) -> dict: + cuts = cuts.transform_text(_strip_timestamps) + source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) + target_audio, target_audio_lens = collate_audio( + cuts.resample(self.target_sample_rate), recording_field="target_audio" + ) + input_text_tokens, target_token_lens = collate_token_channel( + cuts, self.tokenizer, self.frame_length, roles=self.output_roles, add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + ) + source_tokens, source_token_lens = collate_token_channel( + cuts, self.tokenizer, self.frame_length, roles=self.input_roles, add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn + ) + + # if context audio is available use it, otherwise use a random agent turn as speaker reference + if hasattr(cuts[0], "context_audio"): + speaker_reference_audio = [] + speaker_reference_audio_lens = [] + for cut in cuts: + ref_audio = torch.tensor(cut.context_audio.resample(self.target_sample_rate).load_audio()).float() + ref_audio_len = torch.tensor(ref_audio.shape[1]).long() + speaker_reference_audio.append(ref_audio.squeeze(0)) + speaker_reference_audio_lens.append(ref_audio_len) + + speaker_reference_audio = collate_vectors( + speaker_reference_audio, padding_value=0 + ).float() + speaker_reference_audio_lens = torch.tensor(speaker_reference_audio_lens).long() + else: + # extract target speaker reference from a random audio audio + speaker_reference_audio, speaker_reference_audio_lens = collate_random_turn_audio( + cuts.resample(self.target_sample_rate), roles=self.output_roles, recording_field="target_audio" + ) + + # ensures that input_text_tokens is not longer than its duration + input_text_tokens = input_text_tokens[:, :target_token_lens.max()] + + source_fps = self.source_sample_rate / ( + self.source_sample_rate * self.frame_length + ) + source_samples_per_frame = 
int(self.source_sample_rate//source_fps) + target_fps = self.target_sample_rate / ( + self.target_sample_rate * self.frame_length + ) + target_samples_per_frame = int(self.target_sample_rate//target_fps) + + # one is default and we add BOS on speech channel to ensures it, inside of the model class, so if we want bigger than that we can add padding in the audio here + if self.num_delay_speech_tokens: + # compute the padding need in target audio for the number of delay tokens + extra_frames = int(self.num_delay_speech_tokens * target_samples_per_frame) + # left pad target audio to create the delay and make the model to predict silence while consuming self.num_delay_speech_tokens text tokens + target_audio = F.pad(target_audio, (extra_frames, 0)) + target_audio_lens = target_audio_lens + extra_frames + + # right pad the source audio to avoid size mismatch + extra_frames = int(self.num_delay_speech_tokens * source_samples_per_frame) + source_audio = F.pad(source_audio, (0, extra_frames)) + source_audio_lens = source_audio_lens + extra_frames + + if self.add_description: + text_pad_id = get_pad_id(self.tokenizer) + input_text_tokens_ = [] + source_tokens_ = [] + source_audio_ = [] + target_audio_ = [] + desc_lens = [] + desc_plus_audio_prompt_lens = [] + # for each sample in the batch + for i in range(input_text_tokens.size(0)): + desc_tokens_ids = self.generate_prompt_description(device=input_text_tokens[i].device).squeeze(0) + if self.add_audio_prompt_after_description: + prompt_audio_size = int(((self.audio_prompt_duration * self.target_sample_rate) // target_samples_per_frame) * target_samples_per_frame) + prompt_audio = sample_audio_segments_repeat(speaker_reference_audio, speaker_reference_audio_lens, prompt_audio_size, sample=True) + # add a silence in the end to smooth the transition between prompt and audio tokens, keep one extra pad token due shift on subword_ids + prompt_audio[:, -int(target_samples_per_frame * 2):] = 0 + + # create tensor to pad text channels with the same amount of frames added in audio channel (audio prompt) + prompt_audio_text_pad_size = (prompt_audio_size // target_samples_per_frame) + prompt_audio_text_pad = torch.ones(prompt_audio_text_pad_size, device=input_text_tokens.device, dtype=input_text_tokens.dtype) * text_pad_id + # set last prompt frame with eos in text channel + prompt_audio_text_pad[-1] = self.tokenizer.eos + + # Add eos to simulate the end of a turn as in EAR-TTS inference + desc_tokens_ids = torch.cat([desc_tokens_ids, torch.tensor([self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device)]) + # Add padding equivalent to the audio prompt size in number of tokens + new_input_text_tokens = torch.cat([desc_tokens_ids.to(input_text_tokens.dtype), prompt_audio_text_pad.to(input_text_tokens.dtype), input_text_tokens[i]]) + # append to list and update lens + input_text_tokens_.append(new_input_text_tokens) + target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size + + # add description to source text tokens + source_tokens_.append(torch.cat([desc_tokens_ids, prompt_audio_text_pad, source_tokens[i]])) + source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size + # add silence in the source audio while the prompt is being processed + pad_size = (len(desc_tokens_ids) * source_samples_per_frame) + prompt_audio.size(1) + pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) + source_audio_.append(torch.cat([pad_audio, 
source_audio[i]])) + source_audio_lens[i] = source_audio_lens[i] + pad_size + # add silence in the target audio while the prompt is being processed + pad_size = len(desc_tokens_ids) * target_samples_per_frame + pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) + target_audio_.append(torch.cat([pad_audio, prompt_audio[i], target_audio[i]])) + target_audio_lens[i] = target_audio_lens[i] + pad_size + prompt_audio.size(1) + # desc duration + desc_lens.append(len(desc_tokens_ids)) + desc_plus_audio_prompt_lens.append(len(desc_tokens_ids) + prompt_audio_text_pad_size - 1) # -1 due the shift done in subword_ids + else: + # add description to target text tokens + input_text_tokens_.append(torch.cat([desc_tokens_ids, input_text_tokens[i]])) + target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) + # add description to source text tokens + source_tokens_.append(torch.cat([desc_tokens_ids, source_tokens[i]])) + source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) + # add silence in the source audio while the prompt is being processed + pad_size = len(desc_tokens_ids) * source_samples_per_frame + pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) + source_audio_.append(torch.cat([pad_audio, source_audio[i]])) + source_audio_lens[i] = source_audio_lens[i] + pad_size + # add silence in the target audio while the prompt is being processed + pad_size = len(desc_tokens_ids) * target_samples_per_frame + pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) + target_audio_.append(torch.cat([pad_audio, target_audio[i]])) + target_audio_lens[i] = target_audio_lens[i] + pad_size + + # des duration + desc_lens.append(len(desc_tokens_ids)) + desc_plus_audio_prompt_lens.append(len(desc_tokens_ids)) + + # collate tensors + input_text_tokens = collate_vectors(input_text_tokens_, padding_value=text_pad_id) + source_tokens = collate_vectors(source_tokens_, padding_value=text_pad_id) + source_audio = collate_vectors(source_audio_, padding_value=0) + target_audio = collate_vectors(target_audio_, padding_value=0) + + # recreate audio mask + audio_mask = get_mask_from_lengths(target_token_lens) + # ignore desc len in audio mask + for i, frame in enumerate(desc_lens): + audio_mask[i, :frame] = 0.0 + + # desc mask is totally the oposite of audio mask + desc_mask = ~ audio_mask + + # create non_prompt_mask that should mask desc plus audio prompt if used + non_prompt_mask = get_mask_from_lengths(target_token_lens) + for i, frame in enumerate(desc_plus_audio_prompt_lens): + non_prompt_mask[i, :frame-1] = 0.0 + else: + # create a mask for audio using target tokens that suppose to have the same size of the tokenized audio + audio_mask = get_mask_from_lengths(target_token_lens) + # create a full zero desc mask + desc_mask = torch.zeros_like(audio_mask) + # keep text mask as audio_mask + non_prompt_mask = audio_mask + + batch_size = len(target_token_lens) + max_len = max(target_token_lens) + + # Segment IDs per sequence (padded) + aligned_segment_ids = torch.stack([ + torch.nn.functional.pad(torch.full((l,), i), (0, max_len - l), value=-1) # -1 for padding + for i, l in enumerate(target_token_lens) + ], dim=0) # [B, max_len] + + # Attention mask: same-segment & causal + aligned_attention_mask = ( + (aligned_segment_ids.unsqueeze(-2) == aligned_segment_ids.unsqueeze(-1)) # [B, max_len, max_len] + & (torch.arange(max_len).unsqueeze(0).unsqueeze(1) + <= 
torch.arange(max_len).unsqueeze(0).unsqueeze(-1)) # causal tril + ) + + aligned_attention_mask = aligned_attention_mask.unsqueeze(1) # [B, 1, max_len, max_len] + + # create pos ids from the aligned lenght + aligned_position_ids = torch.stack([ + torch.nn.functional.pad(torch.arange(l), (0, max(target_token_lens) - l), value=0) # value=0 is safe for padding + for l in target_token_lens + ], dim=0) + + return { + "sample_id": [str(cut.id) for cut in cuts], + "audio_mask": audio_mask.bool(), + "non_prompt_mask": non_prompt_mask.bool(), + "desc_mask": desc_mask.bool(), + "desc_lens": desc_lens, + "desc_plus_audio_prompt_lens": desc_plus_audio_prompt_lens, + "aligned_attention_mask": aligned_attention_mask.bool(), + "aligned_position_ids": aligned_position_ids, + "source_audio": source_audio, + "source_audio_lens": source_audio_lens, + "target_audio": target_audio, + "target_audio_lens": target_audio_lens, + "input_text_tokens": input_text_tokens, + "target_token_lens": target_token_lens, + "source_tokens": source_tokens, + "source_token_lens": source_token_lens, + "target_texts": [ + " ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts + ], + "speaker_reference_audio": speaker_reference_audio, + "speaker_reference_audio_lens": speaker_reference_audio_lens, + "formatter": [getattr(cut, "formatter", "s2s_duplex") for cut in cuts], + } + + +def collate_random_turn_audio( + cuts: CutSet, + roles: set[str], + recording_field: str = "target_audio", +) -> tuple[torch.Tensor, torch.Tensor]: + selected_turn_audios = [] + selected_turn_audios_lens = [] + for cut in cuts: + # Filter supervisions matching roles + matching_supervisions = [s for s in cut.supervisions if s.speaker in roles] + + # Randomly select one supervision + selected_supervision = random.choice(matching_supervisions) + + # Truncate audio according to supervision + truncated_audio = cut.truncate( + offset=max(0, selected_supervision.start), + duration=selected_supervision.duration + ).load_custom(recording_field) + + selected_turn_audios.append(truncated_audio.squeeze(0)) + selected_turn_audios_lens.append(truncated_audio.shape[-1]) + + return collate_vectors(selected_turn_audios, padding_value=0), torch.tensor(selected_turn_audios_lens) + + +def collate_token_channel( + cuts: CutSet, + tokenizer: TokenizerSpec, + frame_length: Seconds, + roles: set[str], + add_text_bos_and_eos_in_each_turn: bool = True +) -> tuple[torch.Tensor, torch.Tensor]: + pad_id = get_pad_id(tokenizer) + tokens = [ + build_token_channel(c, tokenizer=tokenizer, frame_length=frame_length, roles=roles, pad_id=pad_id, add_text_bos_and_eos_in_each_turn=add_text_bos_and_eos_in_each_turn) + for c in cuts + ] + token_lens = torch.tensor([len(tt) for tt in tokens]) + tokens = collate_vectors(tokens, padding_value=pad_id) + return tokens, token_lens + + +def build_token_channel( + cut: Cut, + tokenizer: TokenizerSpec, + frame_length: Seconds, + roles: set[str], + pad_id: int = -1, + add_text_bos_and_eos_in_each_turn: bool = True, +) -> torch.Tensor: + diagnostic = f"Extra info: {cut.id=}" + if getattr(cut, "shard_origin", None) is not None: + diagnostic = f"{diagnostic} {cut.shard_origin=}" + + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + tokens = torch.ones(total, dtype=torch.long) * pad_id + for supervision in cut.supervisions: + if supervision.speaker in roles: + text = supervision.text + if add_text_bos_and_eos_in_each_turn: + text_ids = torch.as_tensor([tokenizer.bos] + tokenizer.text_to_ids(text)) 
+ else: + text_ids = torch.as_tensor(tokenizer.text_to_ids(text)) + + # Determine the frame offset for the start of the supervision to insert the text tokens. + pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + if pos > len(tokens): + logging.warning( + f"Ill-constructed example: the beginning offset of a supervision {pos} is larger than the example's length {len(tokens)}. {diagnostic}" + ) + continue + + # Determine the frame offset for the last non-EOS text token to form a valid range for insertion; + # Note that EOS will be placed possibly much later, at the frame that coincides with end of speech, + # rather than end of text. The gap between last non-EOS token and EOS token will be filled with `pad_id`. + endpos = pos + len(text_ids) + if endpos > len(tokens): + trunc_len = len(tokens) - pos + logging.warning( + f"Truncating training example's text_ids of length {len(text_ids)} by {trunc_len} because {endpos=} > {len(tokens)=}. {diagnostic}" + ) + text_ids = text_ids[:trunc_len] + try: + tokens[pos:endpos] = text_ids + except Exception as e: + raise RuntimeError(f"{tokens.shape=} {pos=} {endpos=} {text_ids.shape=} {diagnostic}") from e + + # Insert EOS at the end of the supervision segment. + if add_text_bos_and_eos_in_each_turn: + eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + if eospos < len(tokens): # skip otherwise - unfinished turn + tokens[eospos] = tokenizer.eos + + return tokens + + +def _strip_timestamps( + text: str, _TIMESTAMP_PATTERN=re.compile(r"<\|\d+\|>"), _SPACE_PATTERN=re.compile(r"\s+") +) -> str: + """ + Strips timestamp tokens from text, e.g. turns: + '<|0|> Hey <|3|> <|3|> how <|5|> <|7|> are <|8|> <|8|> <|10|> you? <|12|>' + into: + 'Hey how are you?' + """ + # Regexp pattern args are cached compiled patterns (micro-optimization). + text = _TIMESTAMP_PATTERN.sub("", text) # strip timestamp tokens if present + return _SPACE_PATTERN.sub(" ", text).strip() # strip multi-whitespaces diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py new file mode 100644 index 000000000000..872068d487d9 --- /dev/null +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -0,0 +1,2761 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
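+# Duplex EAR-TTS model and its supporting utilities: selective mixed-precision patching of the
+# speech decoder, speaking-mask generation from text BOS/EOS tokens, audio codec wrappers
+# (Mimi and RVQ), and batched subword-to-character expansion helpers.
+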
+import os +import random +import tempfile +import numpy as np +import time + +import glob +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchaudio +from lightning import LightningModule +from omegaconf import DictConfig, OmegaConf +from peft import PeftModel +from torch import Tensor, nn +from torch.distributed.fsdp import fully_shard +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + loss_parallel, + parallelize_module, +) +from transformers import DynamicCache +import math + +from nemo.collections.asr.models import EncDecSpeakerLabelModel + +from transformers import AutoModelForCausalLM + +from nemo.collections.audio.parts.utils.resampling import resample +from nemo.core.classes.module import NeuralModule +from nemo.collections.common.tokenizers import AutoTokenizer +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.speechlm2.data.utils import get_pad_id +from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str +from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin +from nemo.collections.speechlm2.parts.lora import maybe_install_lora +from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU +from nemo.collections.speechlm2.parts.metrics.bleu import BLEU +from nemo.collections.speechlm2.parts.metrics.intelligibility import Intelligibility +from nemo.collections.speechlm2.parts.metrics.results_logger import ResultsLogger +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.metrics.token_accuracy import TokenAccuracy +from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.speechlm2.parts.pretrained import ( + load_pretrained_hf, + set_model_dict_for_partial_init, + setup_speech_encoder, +) +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.utils import logging + +from nemo.collections.tts.modules import transformer_2501 +from nemo.collections.tts.modules.mimi_codec_modules import ReshapeTransformerEncoder +from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER + +from nemo.collections.speechlm2.modules.cfm import MatchaTTSCFM +from types import SimpleNamespace + + +from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSModel, RVQEARTTSConfig, build_vocabs, SubwordFlagEmbedding, RMSNorm +from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEModel +from nemo.collections.speechlm2.data.duplex_ear_tts_dataset import normalize_text_fn + +import torch +import torch.nn as nn +import copy + +def maybe_to(x, dtype): + if x is None: + return None + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + +from collections import Counter +from contextlib import contextmanager +import torch + +@contextmanager +def ensures_16_precision(mixed_dtype): + """ + Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting. + In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models. + This context manager restores default float32 precision and runs the computation in float32 autocast context. 
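+    In this variant the target dtype is configurable: the default dtype and an autocast context are
+    temporarily switched to `mixed_dtype`, and the previous default dtype is restored on exit.
+
+    Example (illustrative):
+        with ensures_16_precision(torch.bfloat16):
+            out = module(x)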
+ """ + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(mixed_dtype) + try: + with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=mixed_dtype): + yield + finally: + torch.set_default_dtype(default_dtype) + + +def make_tts_model_mixed_precision_definite(model, inputs, + mixed_dtype=torch.bfloat16, + bf16_min=1e-2, bf16_max=1e2, + safety_factor=1.0): + safe_min = bf16_min * safety_factor + safe_max = bf16_max * safety_factor + + # 1️⃣ Collect activation stats in FP32 + model_fp32 = copy.deepcopy(model).eval().to(torch.float32) + stats = {} + hooks = [] + + def _activation_hook(name): + def hook(_, __, out): + if isinstance(out, tuple): + out = out[0] + if torch.is_tensor(out): + stats[name] = {"min": float(out.detach().min()), "max": float(out.detach().max())} + return hook + + for name, module in model_fp32.named_modules(): + if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)): + hooks.append(module.register_forward_hook(_activation_hook(name))) + + with torch.no_grad(): + _ = model_fp32( + code=inputs["code"], + audio_mask=maybe_to(inputs["audio_mask"], torch.float32), + attention_mask=maybe_to(inputs["attention_mask"], torch.float32), + position_ids=inputs["position_ids"], + context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), + subword_ids=inputs["subword_ids"], + subword_mask=maybe_to(inputs["subword_mask"], torch.float32), + non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32) + ) + + for h in hooks: + h.remove() + + # 2️⃣ Patch model for mixed precision with safe propagation + model_patched = copy.deepcopy(model).eval() + bf16_layers, fp32_layers = [], [] + + all_modules = list(model_patched.named_modules()) + num_modules = len(all_modules) + + # flag to propagate FP32 to next safe layers + propagate_fp32 = False + + for idx, (name, module) in enumerate(all_modules): + if name not in stats: + continue + mn, mx = stats[name]["min"], stats[name]["max"] + safe = (abs(mn) < safe_max and abs(mx) < safe_max + and not (abs(mn) < safe_min and abs(mx) < safe_min)) + + is_sensitive = False + if isinstance(module, (nn.LayerNorm, nn.Embedding)): + is_sensitive = True + elif isinstance(module, nn.Linear): + if not safe: + is_sensitive = True + + # mark this layer + if is_sensitive: + if name not in fp32_layers: + fp32_layers.append(name) + propagate_fp32 = True # propagate FP32 to next layers if safe + else: + if propagate_fp32: + # next layer is safe but preceded by FP32-sensitive -> still FP32 + fp32_layers.append(name) + propagate_fp32 = False # stop propagation after one safe layer + else: + # layer itself is safe and no FP32 propagation -> use BF16/FP16 + if isinstance(module, nn.Linear): + bf16_layers.append(name) + + # 3️⃣ Wrap forwards to enforce precision + def wrap_forward(module, is_fp32_sensitive): + if hasattr(module, "_original_forward"): + return + module._original_forward = module.forward + + def new_forward(*args, **kwargs): + if is_fp32_sensitive: + with fp32_precision(): + return module._original_forward(*args, **kwargs) + else: + new_args = tuple(a.to(mixed_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a for a in args) + new_kwargs = {k: v.to(mixed_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in kwargs.items()} + # with torch.cuda.amp.autocast(enabled=True, dtype=mixed_dtype): + with ensures_16_precision(mixed_dtype): + return module._original_forward(*new_args, **new_kwargs) + + module.forward = new_forward + + 
for name, module in model_patched.named_modules():
+        if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)):
+            wrap_forward(module, name in fp32_layers)
+
+    # 4️⃣ Count actual running dtype
+    running_dtypes = Counter()
+    hook_handles = []
+
+    def dtype_counter_hook(module, inputs, outputs):
+        for x in inputs:
+            if isinstance(x, torch.Tensor):
+                running_dtypes[str(x.dtype)] += 1
+        outputs_list = outputs if isinstance(outputs, (tuple, list)) else [outputs]
+        for x in outputs_list:
+            if isinstance(x, torch.Tensor):
+                running_dtypes[str(x.dtype)] += 1
+
+    for name, module in model_patched.named_modules():
+        if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)):
+            hook_handles.append(module.register_forward_hook(dtype_counter_hook))
+
+    with torch.no_grad():
+        _ = model_patched(
+            code=inputs["code"],
+            audio_mask=maybe_to(inputs["audio_mask"], torch.float32),
+            attention_mask=maybe_to(inputs["attention_mask"], torch.float32),
+            position_ids=inputs["position_ids"],
+            context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32),
+            subword_ids=inputs["subword_ids"],
+            subword_mask=maybe_to(inputs["subword_mask"], torch.float32),
+            non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32)
+        )
+
+    for h in hook_handles:
+        h.remove()
+
+    num_bf16_fp16 = running_dtypes.get("torch.bfloat16", 0) + running_dtypes.get("torch.float16", 0)
+    num_fp32 = running_dtypes.get("torch.float32", 0)
+
+    summary = {
+        "bf16_layers": bf16_layers,
+        "fp32_layers": fp32_layers,
+        "num_bf16_fp16": num_bf16_fp16,
+        "num_fp32": num_fp32,
+        "stats": stats,
+        "safe_min": safe_min,
+        "safe_max": safe_max,
+        "safety_factor": safety_factor,
+    }
+
+    # print("Num. BF16/FP16 activations:", num_bf16_fp16)
+    # print("Num. FP32 activations:", num_fp32)
+    print("Num. BF16/FP16 candidate layers:", len(bf16_layers))
+    print("Num. FP32 layers (sensitive + propagated):", len(fp32_layers))
+
+    return model_patched, summary
+
+
+def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int = 0, eos_token_id: int = 1):
+    """
+    Efficient, batched speaking-mask generator that marks 1 between BOS and EOS token pairs.
+    If EOS is missing after a BOS, the mask continues to the end of the sequence. Handles multiple turns.
+
+    Args:
+        input_ids (torch.Tensor): LongTensor of shape (B, T)
+        bos_token_id (int): Token ID for BOS
+        eos_token_id (int): Token ID for EOS
+
+    Returns:
+        torch.Tensor: LongTensor of shape (B, T), with 1 for speaking frames and 0 for silence.
+
+    Note: BOS positions are considered speaking (1) and EOS positions non-speaking (0).
+    """
+    B, T = input_ids.shape
+    device = input_ids.device
+    bos_mask = (input_ids == bos_token_id).to(torch.int32).to(device)
+    eos_mask = (input_ids == eos_token_id).to(torch.int32).to(device)
+    bos_cumsum = torch.cumsum(bos_mask, dim=1)
+    eos_cumsum = torch.cumsum(eos_mask, dim=1)
+    speaking_mask = (bos_cumsum > eos_cumsum).to(torch.float32)
+    return speaking_mask.long()
+
+
+def replace_control_speech_codes(speech_codes: torch.Tensor, control_codes: torch.Tensor, silence_tokens: torch.Tensor = None) -> torch.Tensor:
+    """
+    Replaces control codes (speech BOS, EOS, etc.) in `speech_codes` with the first frame, which is
+    assumed to consist of 'valid' codes representing silence.
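+    If `silence_tokens` is provided, control-code positions are instead overwritten with those
+    silence codes (broadcast across the batch).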
+ """ + if silence_tokens is not None: + # Expand to [B, 1, 74] + silence_tokens_expanded = silence_tokens.unsqueeze(0).unsqueeze(1).expand(speech_codes.shape[0], 1, -1) + return torch.where(torch.isin(speech_codes, control_codes), silence_tokens_expanded, speech_codes) + + if torch.isin(speech_codes[:, :1], control_codes).any(): + return torch.where(torch.isin(speech_codes, control_codes), torch.zeros_like(speech_codes[:, :1]), speech_codes) + else: + return torch.where(torch.isin(speech_codes, control_codes), speech_codes[:, :1], speech_codes) + + +def get_mask_from_lengths( + lengths: torch.Tensor = None, + x: torch.Tensor = None, + pad_to_factor: int = None +) -> torch.Tensor: + """Constructs binary mask from a 1D torch tensor of input lengths + Args: + lengths: torch.tensor (torch.tensor): 1D tensor with lengths + x: torch.tensor = tensor to be used on, last dimension is for mask + Returns: + mask (torch.tensor): num_sequences x max_length binary tensor + """ + if lengths is None: + assert x is not None + return torch.ones(x.shape[-1], dtype=torch.bool, device=x.device) + else: + if x is None: + max_len = torch.max(lengths) + else: + max_len = x.shape[-1] + + if pad_to_factor is not None: + with fp32_precision(): + max_len = torch.ceil(max_len / pad_to_factor) * pad_to_factor + + ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) + mask = ids < lengths.unsqueeze(1) + return mask + + +from transformers import MimiModel, AutoFeatureExtractor +class MimiCodec(NeuralModule): + def __init__(self, model_path_or_name="kyutai/mimi", num_codebooks=12): + super().__init__() + self.codec = MimiModel.from_pretrained(model_path_or_name) + self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path_or_name) + self.num_codebooks = num_codebooks + + @property + def device(self): + return next(self.codec.parameters()).device + + @property + def _codebook_size(self): + return self.codec.config.codebook_size + + @property + def _num_codebooks(self): + return self.num_codebooks + + @property + def samples_per_frame(self): + return int(self.feature_extractor.sampling_rate // self.codec.config.frame_rate) + + def encode(self, audio, audio_len): + audio = audio.squeeze(1) + with fp32_precision(): + # make the audio divisible by frame rate and also by self.frame_stacking_factor with extra frames of 1 to avoid issues because we are removing a audio frame to shift target and input for TF + audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.samples_per_frame, extra_frames=0) + # explicitly encode then decode the audio inputs + encoder_outputs = self.codec.encode(audio.unsqueeze(1).to(self.device), num_quantizers=self.num_codebooks) + codes = encoder_outputs.audio_codes + tokens_len = audio_len // self.samples_per_frame + return codes.transpose(1, 2), tokens_len + + def decode(self, tokens, tokens_len): + with fp32_precision(): + tokens = tokens.transpose(1, 2) + # tokens: B, T', C' + audio = self.codec.decode(tokens).audio_values.squeeze(1) + audio_len = tokens_len * self.samples_per_frame + return audio, audio_len + + def forward(self, audio, audio_len): + tokens, tokens_len = self.encode(audio, audio_len) + audio, audio_len = self.decode(tokens, tokens_len) + return audio, audio_len + + + def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, extra_frames: int = 0): + """ + Zero pad the end of the audio so that we do not have a partial end frame. 
+ The output will be zero-padded to have an integer number of frames of + length `samples_per_frame * frame_stacking_factor`. + + Args: + audio: input time-domain signal (B, T) + audio_len: valid length for each example in the batch (B,) + samples_per_frame: number of samples per frame + + Returns: + padded_audio: Padded time-domain signal (B, T') + padded_len: Adjusted valid lengths (B,) + """ + with fp32_precision(): + padded_len = (samples_per_frame * torch.ceil(audio_len / samples_per_frame).int()) + (extra_frames * samples_per_frame) + max_len = padded_len.max().int().item() + num_padding = (max_len - audio.shape[1]) + padded_audio = F.pad(audio, (0, num_padding)) + return padded_audio, padded_len + + + +def setup_rvq_audio_codec(model): + """ + Sets up an ``AudioCodecModel``, initializing it from pretrained weights. + The result is assigned to ``model.audio_codec`` attribute. + + Includes a workaround for PTL auto-downcasting the codec model to bf16 with bf16-true precision. + """ + if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == torch.float: + return # skip if already set up and has the right dtype + with fp32_precision(): + model.audio_codec = RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, strict=False).eval().to(model.device) + for p in model.audio_codec.parameters(): + p.requires_grad = False + +def setup_audio_codec(self): + setup_rvq_audio_codec(self) + assert callable(self.tts_model.set_rvq_embs) + self.tts_model.set_rvq_embs(torch.stack([x.detach() for x in self.audio_codec.prvq.mus_list], 0)) + self.tts_model.rvq_embs = self.tts_model.rvq_embs.to(next(self.tts_model.parameters()).dtype) + # compute target fps + self.target_fps = self.target_sample_rate / self.audio_codec.config.wav_to_token_ratio + self.target_samples_per_frame = self.audio_codec.config.wav_to_token_ratio + +def subwords_to_chars(subword_ids: torch.Tensor, + subword_id_to_char_ids: dict[int, tuple[int, ...]], + bos_id: int, + eos_id: int, + pad_id: int): + """ + Fully vectorized subword->char expansion across all BOS..EOS spans: + - Handles multiple spans per batch + - Preserves BOS/EOS + - Truncates expansions to fit each span + - Very fast on GPU + """ + device = subword_ids.device + B, T = subword_ids.shape + + # Build LUT + max_subword_id = int(subword_ids.max().item()) + max_chars = max(len(v) for v in subword_id_to_char_ids.values()) if subword_id_to_char_ids else 0 + if max_chars == 0: + return subword_ids.clone() + + char_expansion = torch.full((max_subword_id + 1, max_chars), + fill_value=pad_id, device=device, dtype=subword_ids.dtype) + expansion_len = torch.zeros(max_subword_id + 1, dtype=torch.long, device=device) + for k, v in subword_id_to_char_ids.items(): + if k <= max_subword_id: + v_t = torch.tensor(v, device=device, dtype=subword_ids.dtype) + char_expansion[k, :len(v_t)] = v_t + expansion_len[k] = len(v_t) + + # Output initialized with PAD + output = torch.full_like(subword_ids, fill_value=pad_id) + special_mask = (subword_ids == bos_id) | (subword_ids == eos_id) + output[special_mask] = subword_ids[special_mask] + + # Find next EOS for each position + pos = torch.arange(T, device=device) + bos_mask = (subword_ids == bos_id) + eos_mask = (subword_ids == eos_id) + eos_pos_tensor = torch.where(eos_mask, pos.unsqueeze(0).expand(B, T), + torch.full((B, T), T, device=device)) + next_eos_idx = torch.flip(torch.cummin(torch.flip(eos_pos_tensor, [1]), dim=1).values, [1]) + + # Collect all BOS coordinates + bos_coords = torch.nonzero(bos_mask, 
as_tuple=False) + if bos_coords.numel() == 0: + return output + + batch_ids = bos_coords[:, 0] + span_starts = bos_coords[:, 1] + 1 + span_ends = next_eos_idx[batch_ids, bos_coords[:, 1]] + span_lens = (span_ends - span_starts).clamp(min=0) + S = span_lens.numel() + if S == 0: + return output + + # Max span length + max_span_len = int(span_lens.max().item()) + + # Gather subwords for all spans [S, max_span_len] + rel = torch.arange(max_span_len, device=device).unsqueeze(0).expand(S, -1) + span_idx = span_starts.unsqueeze(1) + rel + span_idx_clamped = span_idx.clamp(0, T-1) + batch_idx_expand = batch_ids.unsqueeze(1).expand(-1, max_span_len) + sub_span = subword_ids[batch_idx_expand, span_idx_clamped] + + # Mask positions beyond actual span length + valid_pos_mask = rel < span_lens.unsqueeze(1) + sub_span = torch.where(valid_pos_mask, sub_span, torch.full_like(sub_span, pad_id)) + + # Expand subwords -> chars + expanded = char_expansion[sub_span] # [S, max_span_len, max_chars] + S_len = max_span_len * max_chars + expanded_flat = expanded.view(S, S_len) + valid_char_mask = expanded_flat != pad_id + valid_cumsum = torch.cumsum(valid_char_mask.long(), dim=1) + span_lens_exp = span_lens.unsqueeze(1).expand(-1, S_len) + keep_mask = valid_char_mask & (valid_cumsum <= span_lens_exp) + + # Compute flattened indices to scatter + rank_flat = (valid_cumsum - 1).clamp(min=0).view(-1) + values_flat = expanded_flat.view(-1) + keep_flat = keep_mask.view(-1) + kept_values = values_flat[keep_flat] + + target_positions = (span_starts.unsqueeze(1).repeat(1, S_len).view(-1))[keep_flat] + rank_flat[keep_flat] + target_batches = batch_ids.unsqueeze(1).repeat(1, S_len).view(-1)[keep_flat] + + # Safety clamp + within_T = target_positions < T + kept_values = kept_values[within_T] + target_positions = target_positions[within_T] + target_batches = target_batches[within_T] + + # Scatter in one shot + output[target_batches, target_positions] = kept_values + + return output + + +def subwords_to_chars_batched(subword_ids: torch.Tensor, + subword_id_to_char_ids: dict[int, tuple[int, ...]], + bos_id: int, + eos_id: int, + pad_id: int, + silence_id: int = 0): + """ + Batched subword->char expansion per BOS..EOS span. 
+ - Multiple spans per batch + - Fully vectorized (no Python loop over spans) + - BOS/EOS exact + - Silences between spans + """ + B, T = subword_ids.shape + device = subword_ids.device + + # Build LUT + max_subword_id = int(subword_ids.max().item()) + max_chars = max(len(v) for v in subword_id_to_char_ids.values()) if subword_id_to_char_ids else 0 + if max_chars == 0: + return subword_ids.clone() + + char_expansion = torch.full((max_subword_id + 1, max_chars), + fill_value=pad_id, device=device, dtype=subword_ids.dtype) + expansion_len = torch.zeros(max_subword_id + 1, dtype=torch.long, device=device) + for k, v in subword_id_to_char_ids.items(): + if k <= max_subword_id: + v_t = torch.tensor(v, device=device, dtype=subword_ids.dtype) + char_expansion[k, :len(v_t)] = v_t + expansion_len[k] = len(v_t) + + # Output initialized with PAD + output = torch.full_like(subword_ids, fill_value=pad_id) + special_mask = (subword_ids == bos_id) | (subword_ids == eos_id) + output[special_mask] = subword_ids[special_mask] + + # Masks + bos_mask = (subword_ids == bos_id) + eos_mask = (subword_ids == eos_id) + + # Compute next EOS per position + pos = torch.arange(T, device=device) + eos_pos_tensor = torch.where(eos_mask, pos.unsqueeze(0).expand(B, T), + torch.full((B, T), T, device=device)) + next_eos_idx = torch.flip(torch.cummin(torch.flip(eos_pos_tensor, [1]), dim=1).values, [1]) + + # Collect all spans + bos_coords = torch.nonzero(bos_mask, as_tuple=False) + if bos_coords.numel() == 0: + return output + + batch_ids = bos_coords[:, 0] + span_starts = bos_coords[:, 1] + 1 + span_ends = next_eos_idx[batch_ids, bos_coords[:, 1]] + span_lens = (span_ends - span_starts).clamp(min=0) + S = span_lens.numel() + if S == 0: + return output + + # Gather subwords for all spans + max_span_len = int(span_lens.max().item()) + rel = torch.arange(max_span_len, device=device).unsqueeze(0).expand(S, -1) + span_idx = span_starts.unsqueeze(1) + rel + span_idx_clamped = span_idx.clamp(0, T - 1) + batch_idx_expand = batch_ids.unsqueeze(1).expand(-1, max_span_len) + sub_span = subword_ids[batch_idx_expand, span_idx_clamped] + + # Mask positions beyond actual span length + valid_pos_mask = rel < span_lens.unsqueeze(1) + sub_span = torch.where(valid_pos_mask, sub_span, torch.full_like(sub_span, pad_id)) + + # Expand subwords -> chars + expanded = char_expansion[sub_span] # [S, max_span_len, max_chars] + S_len = max_span_len * max_chars + expanded_flat = expanded.view(S, S_len) + valid_char_mask = expanded_flat != pad_id + valid_cumsum = torch.cumsum(valid_char_mask.long(), dim=1) + span_lens_exp = span_lens.unsqueeze(1).expand(-1, S_len) + keep_mask = valid_char_mask & (valid_cumsum <= span_lens_exp) + + # Compute target positions + rank_flat = (valid_cumsum - 1).clamp(min=0).view(-1) + values_flat = expanded_flat.view(-1) + keep_flat = keep_mask.view(-1) + kept_values = values_flat[keep_flat] + + target_positions = (span_starts.unsqueeze(1).repeat(1, S_len).view(-1))[keep_flat] + rank_flat[keep_flat] + target_batches = batch_ids.unsqueeze(1).repeat(1, S_len).view(-1)[keep_flat] + + # Safety clamp + within_T = target_positions < T + kept_values = kept_values[within_T] + target_positions = target_positions[within_T] + target_batches = target_batches[within_T] + + # Scatter in one shot + output[target_batches, target_positions] = kept_values + + return output + + +def build_char_expansion_lut(subword_id_to_char_ids: dict[int, tuple[int, ...]], + pad_id: int, + device: str = "cuda"): + """ + Prebuild the LUT once for training. 
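+    The returned tensors are intended to be passed to `subwords_to_chars_batched_fast` on every
+    step, so the lookup table is not rebuilt per batch.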
+ Returns: + char_expansion: [max_subword_id+1, max_chars] + expansion_len: number of chars per subword + """ + if not subword_id_to_char_ids: + return None, None + + max_subword_id = max(subword_id_to_char_ids.keys()) + max_chars = max(len(v) for v in subword_id_to_char_ids.values()) + char_expansion = torch.full((max_subword_id + 1, max_chars), + fill_value=pad_id, device=device, dtype=torch.long) + expansion_len = torch.zeros(max_subword_id + 1, device=device, dtype=torch.long) + + for k, v in subword_id_to_char_ids.items(): + if k <= max_subword_id: + v_t = torch.tensor(v, device=device, dtype=torch.long) + char_expansion[k, :len(v_t)] = v_t + expansion_len[k] = len(v_t) + + return char_expansion, expansion_len + + +def subwords_to_chars_batched_fast(subword_ids: torch.Tensor, + char_expansion: torch.Tensor, + expansion_len: torch.Tensor, + bos_id: int, + eos_id: int, + pad_id: int): + """ + Fast batched subword->char expansion using prebuilt LUT. + Fully vectorized, multiple spans per batch, no autograd overhead. + """ + with torch.no_grad(): + if char_expansion is None: + return subword_ids.clone() + + B, T = subword_ids.shape + device = subword_ids.device + + # Initialize output + output = torch.full_like(subword_ids, pad_id) + special_mask = (subword_ids == bos_id) | (subword_ids == eos_id) + output[special_mask] = subword_ids[special_mask] + + # Masks + bos_mask = (subword_ids == bos_id) + eos_mask = (subword_ids == eos_id) + + # Next EOS per position + pos = torch.arange(T, device=device) + eos_pos_tensor = torch.where(eos_mask, pos.unsqueeze(0).expand(B, T), + torch.full((B, T), T, device=device)) + next_eos_idx = torch.flip(torch.cummin(torch.flip(eos_pos_tensor, [1]), dim=1).values, [1]) + + # Collect all spans + bos_coords = torch.nonzero(bos_mask, as_tuple=False) + if bos_coords.numel() == 0: + return output + + batch_ids = bos_coords[:, 0] + span_starts = bos_coords[:, 1] + 1 + span_ends = next_eos_idx[batch_ids, bos_coords[:, 1]] + span_lens = (span_ends - span_starts).clamp(min=0) + S = span_lens.numel() + if S == 0: + return output + + # Gather subwords + max_span_len = int(span_lens.max().item()) + rel = torch.arange(max_span_len, device=device).unsqueeze(0).expand(S, -1) + span_idx = span_starts.unsqueeze(1) + rel + span_idx_clamped = span_idx.clamp(0, T-1) + batch_idx_expand = batch_ids.unsqueeze(1).expand(-1, max_span_len) + sub_span = subword_ids[batch_idx_expand, span_idx_clamped] + + valid_pos_mask = rel < span_lens.unsqueeze(1) + sub_span = torch.where(valid_pos_mask, sub_span, torch.full_like(sub_span, pad_id)) + + # Expand using prebuilt LUT + expanded = char_expansion[sub_span] # [S, max_span_len, max_chars] + S_len = max_span_len * char_expansion.shape[1] + expanded_flat = expanded.view(S, S_len) + + valid_char_mask = expanded_flat != pad_id + valid_cumsum = torch.cumsum(valid_char_mask.long(), dim=1) + span_lens_exp = span_lens.unsqueeze(1).expand(-1, S_len) + keep_mask = valid_char_mask & (valid_cumsum <= span_lens_exp) + + rank_flat = (valid_cumsum - 1).clamp(min=0).view(-1) + values_flat = expanded_flat.view(-1) + keep_flat = keep_mask.view(-1) + kept_values = values_flat[keep_flat] + + target_positions = (span_starts.unsqueeze(1).repeat(1, S_len).view(-1))[keep_flat] + rank_flat[keep_flat] + target_batches = batch_ids.unsqueeze(1).repeat(1, S_len).view(-1)[keep_flat] + + within_T = target_positions < T + kept_values = kept_values[within_T] + target_positions = target_positions[within_T] + target_batches = target_batches[within_T] + + 
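+        # scatter the kept char ids into their absolute (batch, position) targets in one write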
output[target_batches, target_positions] = kept_values + + return output + + +class WordSepTokenizer(AutoTokenizer): + """ + Tokenizer wrapper that inserts a special word-separator token before each token + that starts a new word. This is useful for Speech-LLM and TTS pipelines + that require explicit word boundaries in the token sequence. + + Supported models: + - LLaMA-3.1-family + - NVIDIA Nemotron Nano-9B-v2 + + Attributes: + word_sep_token (str): The special token used to mark word boundaries. + word_boundary_prefix (str): The token prefix indicating a word boundary. + word_sep_id (int): The token ID corresponding to `word_sep_token`. + """ + + def __init__(self, model_name: str, *args, **kwargs): + """ + Initializes the WordSepTokenizer. + + Args: + model_name (str): Name of the model to load. Determines the special + word-separator token and word boundary prefix. + *args: Additional positional arguments passed to the base `AutoTokenizer`. + **kwargs: Additional keyword arguments passed to the base `AutoTokenizer`. + + Raises: + ValueError: If `model_name` is not supported. + """ + super().__init__(model_name, *args, **kwargs) + + model_name_lower = model_name.lower() + if "llama-3.1" in model_name_lower: + self.word_sep_token = "<|reserved_special_token_0|>" + self.word_boundary_prefix = "Ġ" + elif "qwen2.5" in model_name_lower: + self.word_sep_token = "<|box_start|>" + self.word_boundary_prefix = "Ġ" + elif "nvidia-nemotron-nano-9b-v2" in model_name_lower: + self.word_sep_token = "" + self.word_boundary_prefix = "Ġ" + else: + raise ValueError( + f"WordSepTokenizer does not support model '{model_name}'. " + "Supported: LLaMA-3.1-family, NVIDIA Nemotron Nano-9B-v2." + ) + + self.word_sep_id = self.tokenizer.convert_tokens_to_ids(self.word_sep_token) + + def text_to_ids(self, text: str): + """ + Converts input text into token IDs, inserting the word-separator ID + before tokens that start a new word. + + Args: + text (str): Input string to tokenize. + + Returns: + List[int]: Token IDs with word-separator IDs inserted. + + Notes: + - If `text` is empty or tokenization returns no tokens, returns an empty list. + - The first token separator (if any) is removed to avoid leading separators. + """ + if not text: + return [] + + # ensures that first word has a space to avoid different tokens for the first word + if text[0] != " ": + text = " " + text + + # Original token IDs + ids = super().text_to_ids(text) + if not ids: + return [] + + # Convert IDs to tokens safely (must be CPU Python list, no separator IDs yet) + tokens = self.tokenizer.convert_ids_to_tokens(list(ids)) + + # Mask for tokens starting with word boundary + mask = [t.startswith(self.word_boundary_prefix) for t in tokens] + + # Prepare result + result = [] + for tid, m in zip(ids, mask): + if m: + result.append(self.word_sep_id) + result.append(tid) + + # Remove leading separator if present + if result and result[0] == self.word_sep_id: + result = result[1:] + + return result + + def ids_to_text(self, ids): + """ + Converts token IDs back to text, replacing word-separator tokens with spaces. + + Args: + ids (List[int]): List of token IDs. + + Returns: + str: Decoded text with word separators converted to spaces. 
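+
+        Example (illustrative sketch; the model name and exact ids depend on your setup):
+
+            tok = WordSepTokenizer("meta-llama/Llama-3.1-8B-Instruct")  # assumed model id
+            ids = tok.text_to_ids("hello world")   # word-sep id inserted before "world"
+            text = tok.ids_to_text(ids)            # separator tokens come back as spaces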
+ """ + text = super().ids_to_text(ids) + return text.replace(self.word_sep_token, " ") + + +class NeMoGroupedCodec(NeuralModule): + def __init__(self, codec, frame_stacking_factor=1): + super().__init__() + self.codec = codec + self.frame_stacking_factor = frame_stacking_factor + + @property + def device(self): + return self.codec.device + + @property + def _codebook_size(self): + return self.codec.vector_quantizer.codebook_size_per_group + + @property + def _num_codebooks(self): + return self.codec.vector_quantizer.num_groups * self.frame_stacking_factor + + @property + def samples_per_frame(self): + return self.codec.samples_per_frame * self.frame_stacking_factor + + def encode(self, audio, audio_len): + audio = audio.squeeze(1) + with fp32_precision(): + # make the audio divisible by frame rate and also by self.frame_stacking_factor with extra frames of 1 to avoid issues because we are removing a audio frame to shift target and input for TF + audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.samples_per_frame, extra_frames=0) + # encodes audio using the codec + tokens, tokens_len = self.codec.encode(audio=audio, audio_len=audio_len) # B, C, T + tokens = tokens.transpose(1, 2) # → B, T, C + B, T, C = tokens.shape + assert T % self.frame_stacking_factor == 0 + grouped = tokens.reshape(B, T // self.frame_stacking_factor, C * self.frame_stacking_factor) + tokens_len = tokens_len // self.frame_stacking_factor + # grouped = grouped.transpose(1, 2) + + return grouped, tokens_len + + def decode(self, tokens, tokens_len): + with fp32_precision(): + # tokens = tokens.transpose(1, 2) + # tokens: B, T', C' + B, T, Cg = tokens.shape + assert Cg % self.frame_stacking_factor == 0 + C = Cg // self.frame_stacking_factor + ungrouped = tokens.reshape(B, T * self.frame_stacking_factor, C) # → [B, T, C] + ungrouped = ungrouped.transpose(1, 2) # → [B, C, T] for decode + tokens_len = torch.ceil(tokens_len * self.frame_stacking_factor).to(tokens_len.dtype) + audio, audio_len = self.codec.decode(tokens=ungrouped, tokens_len=tokens_len) + return audio, audio_len + + def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor): + """Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation. + + Args: + inputs: encoded signal + input_len: valid length for each example in the batch + + Returns: + Decoded output `audio` in the time domain and its length in number of samples `audio_len`. + Note that `audio_len` will be a multiple of `self.samples_per_frame`. + """ + with fp32_precision(): + if self.frame_stacking_factor > 1: + inputs = inputs.transpose(1, 2) + B, T, Cg = inputs.shape + C = Cg // self.frame_stacking_factor + inputs = inputs.reshape(B, T * self.frame_stacking_factor, C) # → [B, T, C] + input_len = torch.ceil(input_len * self.frame_stacking_factor).to(input_len.dtype) + inputs = inputs.transpose(1, 2) + + audio, audio_len = self.codec.audio_decoder(inputs=inputs, input_len=input_len) + return audio, audio_len + + def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor: + """Convert the discrete tokens into a continuous encoded representation. + + Args: + tokens: discrete tokens for each codebook for each time frame + tokens_len: valid length of each example in the batch + + Returns: + Continuous encoded representation of the discrete input representation. 
+ """ + with fp32_precision(): + # reshape to dequantize + if self.frame_stacking_factor > 1: + tokens = tokens.transpose(1, 2) + # tokens: B, T', C' + B, T, Cg = tokens.shape + assert Cg % self.frame_stacking_factor == 0 + C = Cg // self.frame_stacking_factor + tokens = tokens.reshape(B, T * self.frame_stacking_factor, C) # → [B, T, C] + tokens = tokens.transpose(1, 2) # → [B, C, T] for decode + tokens_len = torch.ceil(tokens_len * self.frame_stacking_factor).to(tokens_len.dtype) + dequantized = self.codec.dequantize(tokens=tokens, tokens_len=tokens_len) + # reshape back to the compress form if needed + if self.frame_stacking_factor > 1: + dequantized = dequantized.transpose(1, 2) # → B, T, C + B, T, C = dequantized.shape + assert T % self.frame_stacking_factor == 0 + dequantized = dequantized.reshape(B, T // self.frame_stacking_factor, C * self.frame_stacking_factor) + dequantized = dequantized.transpose(1, 2) # → B, C, T + + return dequantized + + def forward(self, audio, audio_len): + tokens, tokens_len = self.encode(audio, audio_len) + audio, audio_len = self.decode(tokens, tokens_len) + return audio, audio_len + + def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, extra_frames: int = 0): + """ + Zero pad the end of the audio so that we do not have a partial end frame. + The output will be zero-padded to have an integer number of frames of + length `samples_per_frame * frame_stacking_factor`. + + Args: + audio: input time-domain signal (B, T) + audio_len: valid length for each example in the batch (B,) + samples_per_frame: number of samples per frame + + Returns: + padded_audio: Padded time-domain signal (B, T') + padded_len: Adjusted valid lengths (B,) + """ + with fp32_precision(): + padded_len = (samples_per_frame * torch.ceil(audio_len / samples_per_frame).int()) + (extra_frames * samples_per_frame) + max_len = padded_len.max().int().item() + num_padding = (max_len - audio.shape[1]) + padded_audio = F.pad(audio, (0, num_padding)) + return padded_audio, padded_len + +import math + +def compare_dicts(dict_a, dict_b): + all_keys = set(dict_a.keys()).union(dict_b.keys()) + equal = True + differing_keys = [] + + for key in sorted(all_keys): + a_val = dict_a.get(key, None) + b_val = dict_b.get(key, None) + + # Skip if value is None in either dict + if a_val is None or b_val is None: + continue + + # Handle both being NaN (float) + if (isinstance(a_val, float) and math.isnan(a_val)) and \ + (isinstance(b_val, float) and math.isnan(b_val)): + continue + + # Handle both being tensors + if isinstance(a_val, torch.Tensor) and isinstance(b_val, torch.Tensor): + # Shape mismatch + if a_val.shape != b_val.shape: + print(f"❌ Shape mismatch at key '{key}': {a_val.shape} vs {b_val.shape}") + equal = False + differing_keys.append(key) + continue + + # Compare tensors elementwise (treating NaNs as equal) + diff_mask = ~(torch.isclose(a_val, b_val, equal_nan=True)) + if diff_mask.any(): + equal = False + differing_keys.append(key) + idx = torch.nonzero(diff_mask, as_tuple=False) + print(f"❌ Tensor mismatch at key '{key}': {idx.shape[0]} differing positions, shape: ", a_val.shape, b_val.shape) + # Print up to first 10 differences + for i, pos in enumerate(idx[:10]): + pos_tuple = tuple(pos.tolist()) + a_item = a_val[pos_tuple].item() + b_item = b_val[pos_tuple].item() + print(f" Position {pos_tuple}: {a_item} vs {b_item}") + if idx.shape[0] > 10: + print(f" ... 
and {idx.shape[0] - 10} more differences") + continue + + # Fallback: direct comparison + if a_val != b_val: + print(f"❌ Value mismatch at key '{key}': {a_val} vs {b_val}") + equal = False + differing_keys.append(key) + + if equal: + print("✅ All comparable keys and values match!") + else: + print("⚠️ Some keys/values differ (see above).") + + return equal, differing_keys + +import copy +def extract_first_tensor(x): + """Recursively find the first tensor in nested structures.""" + if isinstance(x, torch.Tensor): + return x + if isinstance(x, (list, tuple)): + for v in x: + t = extract_first_tensor(v) + if t is not None: + return t + if isinstance(x, dict): + for v in x.values(): + t = extract_first_tensor(v) + if t is not None: + return t + return None + +def compare_tts_model_fp32_bf16_old(tts_model, inputs, atol=1e-3, topk=15): + model_fp32 = copy.deepcopy(tts_model).eval().to(torch.float32) + model_bf16 = copy.deepcopy(tts_model).eval().to(torch.bfloat16) + + diffs = {} + + def make_hook(name, tag): + def hook_fn(module, inp, out): + tensor = extract_first_tensor(out) + if tensor is not None: + tensor = tensor.detach().float().cpu() + if name not in diffs: + diffs[name] = {} + diffs[name][tag] = tensor + return hook_fn + + # Register hooks independently + hooks_fp32 = [] + for name, module in model_fp32.named_modules(): + hooks_fp32.append(module.register_forward_hook(make_hook(name, "fp32"))) + + hooks_bf16 = [] + for name, module in model_bf16.named_modules(): + hooks_bf16.append(module.register_forward_hook(make_hook(name, "bf16"))) + + def maybe_to(x, dtype): + if x is None: + return None + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + with torch.no_grad(): + # BF16 forward + with torch.autocast('cuda', dtype=torch.bfloat16): + _ = model_bf16( + code=maybe_to(inputs["code"], torch.bfloat16), + audio_mask=inputs["audio_mask"], + attention_mask=inputs["attention_mask"], + position_ids=inputs["position_ids"], + context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.bfloat16), + subword_ids=inputs["subword_ids"], + subword_mask=inputs["subword_mask"], + non_prompt_mask=inputs["non_prompt_mask"] + ) + + # FP32 forward + _ = model_fp32( + code=maybe_to(inputs["code"], torch.float32), + audio_mask=inputs["audio_mask"], + attention_mask=inputs["attention_mask"], + position_ids=inputs["position_ids"], + context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), + subword_ids=inputs["subword_ids"], + subword_mask=inputs["subword_mask"], + non_prompt_mask=inputs["non_prompt_mask"] + ) + + # Compute diffs for matching layers + diff_list = [] + for name, val in diffs.items(): + if "fp32" in val and "bf16" in val: + delta = (val["fp32"] - val["bf16"]).abs().mean().item() + diff_list.append((name, delta)) + + diff_list.sort(key=lambda x: x[1], reverse=True) + + print(f"\nTop {topk} layers with largest FP32 vs BF16 diff:") + if not diff_list: + print("⚠️ No matching tensor outputs found. Try increasing atol or check nested outputs.") + else: + for name, delta in diff_list[:topk]: + print(f"{name:<60} mean abs diff = {delta:.6f}") + + for h in hooks_fp32 + hooks_bf16: + h.remove() + + return diff_list + +def compare_tts_model_fp32_bf16_mixed(tts_model, inputs, topk=15): + """ + Compare FP32 vs BF16-safe (with fp32_precision layers) outputs. + tts_model can have patched FP32 layers; these will run in FP32. 
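+    Returns a list of (module_name, mean_abs_diff) pairs sorted by descending difference.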
+ """ + import copy + diffs = {} + + def extract_first_tensor(x): + if isinstance(x, (tuple, list)): + for y in x: + if torch.is_tensor(y): + return y + return None + if torch.is_tensor(x): + return x + return None + + def make_hook(name, tag): + def hook_fn(module, inp, out): + tensor = extract_first_tensor(out) + if tensor is not None: + tensor = tensor.detach().float().cpu() + if name not in diffs: + diffs[name] = {} + diffs[name][tag] = tensor + return hook_fn + + # FP32 reference model + model_fp32 = copy.deepcopy(tts_model).eval().to(torch.float32) + + hooks_fp32 = [m.register_forward_hook(make_hook(n, "fp32")) for n, m in model_fp32.named_modules()] + hooks_bf16 = [m.register_forward_hook(make_hook(n, "bf16")) for n, m in tts_model.named_modules()] + + def maybe_to(x, dtype): + if x is None: + return None + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + with torch.no_grad(): + # BF16-safe forward (patched FP32 layers run in FP32) + with torch.autocast("cuda", dtype=torch.bfloat16): + _ = tts_model( + code=maybe_to(inputs["code"], torch.bfloat16), + audio_mask=inputs["audio_mask"], + attention_mask=inputs["attention_mask"], + position_ids=inputs["position_ids"], + context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.bfloat16), + subword_ids=inputs["subword_ids"], + subword_mask=inputs["subword_mask"], + non_prompt_mask=inputs["non_prompt_mask"] + ) + + # FP32 forward + _ = model_fp32( + code=maybe_to(inputs["code"], torch.float32), + audio_mask=inputs["audio_mask"], + attention_mask=inputs["attention_mask"], + position_ids=inputs["position_ids"], + context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), + subword_ids=inputs["subword_ids"], + subword_mask=inputs["subword_mask"], + non_prompt_mask=inputs["non_prompt_mask"] + ) + + # Compute diffs + diff_list = [] + for name, val in diffs.items(): + if "fp32" in val and "bf16" in val: + delta = (val["fp32"] - val["bf16"]).abs().mean().item() + diff_list.append((name, delta)) + + diff_list.sort(key=lambda x: x[1], reverse=True) + print(f"\nTop {topk} layers with largest FP32 vs BF16 diff:") + for name, delta in diff_list[:topk]: + print(f"{name:<60} mean abs diff = {delta:.6f}") + + for h in hooks_fp32 + hooks_bf16: + h.remove() + + return diff_list + +def rescale_state_dict( + state_dict, + target_std=0.02, + first_n_layers=None, + layer_prefix="tts_model.backbone.layers." +): + """ + Rescale trainable weights in a state_dict for BF16 stability. 
+ + Args: + state_dict: PyTorch state_dict + target_std: desired target std for weights + first_n_layers: if not None, rescale only the first N transformer blocks + layer_prefix: prefix for layer names (default: "tts_model.backbone.layers.") + Returns: + new_state_dict + """ + weight_tensors = [] + + # Compute which prefixes to match if first_n_layers is set + prefixes_to_match = [] + if first_n_layers is not None: + prefixes_to_match = [f"{layer_prefix}{i}" for i in range(first_n_layers)] + + for name, param in state_dict.items(): + if not torch.is_tensor(param): + continue + + if "rvq_embs" in name: + continue + + # Skip biases & 1-dim params (norm weights/gates) + if param.ndim <= 1: + continue + + # Skip layers not in the first N + if first_n_layers is not None and not any(name.startswith(pfx) for pfx in prefixes_to_match): + continue + + weight_tensors.append(param.float()) + + if not weight_tensors: + if first_n_layers is not None: + print(f"⚠️ No weights found for first {first_n_layers} layers with prefix '{layer_prefix}'.") + else: + print("⚠️ No weights found to rescale in state_dict.") + return state_dict + + # Compute global std across selected weights (on CPU) + cpu_weights = [p.detach().cpu() for p in weight_tensors] + flat = torch.cat([p.flatten() for p in cpu_weights]) + current_std = float(torch.std(flat)) + scale = target_std / (current_std + 1e-8) + + print( + f"📦 Rescaling state_dict " + f"{'(first N layers)' if first_n_layers else '(all layers)'}: " + f"current std = {current_std:.6f}, target = {target_std}, scale = {scale:.6f}" + ) + + # Apply scaling + new_state_dict = {} + for name, param in state_dict.items(): + if ( + torch.is_tensor(param) + and param.ndim > 1 + and (first_n_layers is None or any(name.startswith(pfx) for pfx in prefixes_to_match)) + ): + new_state_dict[name] = param * scale + else: + new_state_dict[name] = param + + print("✅ Done: weights rescaled.") + return new_state_dict + + +class DuplexEARTTS(LightningModule, HFHubMixin): + def __init__(self, cfg: dict) -> None: + assert isinstance(cfg, dict), ( + "You must pass the config to DuplexEARTTS as a Python dict to support hyperparameter serialization " + f"in PTL checkpoints (we got: '{type(cfg)=}')." 
+ ) + super().__init__() + self.save_hyperparameters() + # convert dict to config + cfg = DictConfig(cfg) + self.trainer_config = cfg.trainer + self.data_cfg = cfg.data + self.cfg = cfg.model + self.target_sample_rate = cfg.data.target_sample_rate + self.source_sample_rate = cfg.data.source_sample_rate + self.normalize_text = cfg.data.get("normalize_text", False) + self.model_16_precision_safe = None + + self.validation_save_path = os.path.join(cfg.exp_manager.explicit_log_dir, "validation_logs") + + # move back text channel by x, in inference it advance the text channel prediction by x frames + self.advance_text_channel_by = self.cfg.get("advance_text_channel_by", None) + + # Load ForCausalLM + if self.cfg.tts_config.context_hidden_size is not None: + self.language_model = self._load_language_model(self.cfg) + self.embed_tokens = self._load_embed_tokens(self.cfg) + # delete llm because we use it only to get the embbeding tokens + del self.language_model + if self.cfg.tts_config.get("use_subword_flag_emb", False): + self.subword_flag_emb = SubwordFlagEmbedding(self.cfg.pretrained_lm_name, self.cfg.tts_config.context_hidden_size) + + # instanciate eartts model and codec + self._load_tts_model(self.cfg) + self._codebook_size = self.tts_model.config.codebook_size + + # compute source fps + self.source_fps = self.source_sample_rate / ( + self.source_sample_rate * cfg.data.frame_length + ) # conver frame rate in fps + self.source_samples_per_frame = int(self.source_sample_rate//self.source_fps) + + # get codec silence tokens + self.codec_silence_tokens = self.get_codec_silence_frame() + + # Load tokenizer + if self.cfg.get("use_word_sep_tokenizer", False): + self.tokenizer = WordSepTokenizer(self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True) + else: + self.tokenizer = AutoTokenizer(self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True) # Note that we are using fast tokenizer + + if 'Qwen2.5' in self.cfg.pretrained_lm_name: + # For Qwen, '<|im_start|>' is a common choice for a BOS token. + # You can check your tokenizer's vocabulary for the best candidate. + logging.warning("Tokenizer does not have a `bos_token`. 
Setting it to '<|im_start|>'.") + self.tokenizer.bos_token = '<|im_start|>' + self.tokenizer.eos_token = '<|im_end|>' + + elif 'Nemotron' in self.cfg.pretrained_lm_name: + # ====== NEMOTRON-SPECIFIC HANDLING ====== + self.tokenizer.bos_token = '' + self.tokenizer.eos_token = '' + self.tokenizer.pad_token = '' + + # cached for quicker audio decoding + self.register_buffer( + "_control_codes", + torch.tensor([self.speech_bos_id, self.speech_eos_id, self.speech_pad_id], device=self.device), + ) + + self._use_fsdp = False + self._use_tp = False + if self.cfg.get("pretrained_model", None): + self.init_model_from_another_checkpoint(self.cfg.pretrained_model) + + def get_codec_silence_frame_last_one(self): + audio = torch.zeros(1, 10*self.target_sample_rate).float().to(self.device) + audio_len = torch.tensor([audio.size(-1)]).long() + audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) + + with fp32_precision(), torch.no_grad(): + sil_codes, sil_codes_lens = self.audio_codec.encode( + audio.unsqueeze(1), audio_len + ) + return sil_codes[0, -1] + + def get_codec_silence_frame(self): + from collections import Counter + + # Generate long zero waveform (silence) + audio = torch.zeros(1, 10 * self.target_sample_rate).float().to(self.device) + audio_len = torch.tensor([audio.size(-1)]).long() + audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) + + with fp32_precision(), torch.no_grad(): + sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1), audio_len) # [1, T, C] + sil_codes = sil_codes[0] # [T, C] + + # Convert each frame (C tokens) into a tuple + combos = [tuple(row.tolist()) for row in sil_codes] + + # Count frequencies + counter = Counter(combos) + + # Pick the most common combination + most_common_combo, freq = counter.most_common(1)[0] + + # Return as tensor [C] + return torch.tensor(most_common_combo, device=self.device, dtype=torch.long) + + def _load_embed_tokens(self, cfg) -> nn.Embedding: + """Load token embedding layer for RVQ-EAR-TTS.""" + if self.language_model: + assert callable(self.language_model.get_input_embeddings) + embed_tokens: nn.Embedding = self.language_model.get_input_embeddings() + else: + embed_tokens_state_dict = torch.load( + cfg.pretrained_lm_embedding_path, map_location="cpu", weights_only=True + ) + + # Create token embedding layer + vocab_size, hidden_size = embed_tokens_state_dict["weight"].size() + embed_tokens = nn.Embedding(vocab_size, hidden_size, dtype=torch.bfloat16) + embed_tokens.load_state_dict(embed_tokens_state_dict) + return embed_tokens + + def _load_tts_model(self, cfg) -> nn.Module: + """Load TTS model for RVQ-EAR-TTS.""" + if self.cfg.get("pretrained_tts_model", None): + self.tts_model = RVQEARTTSModel.from_pretrained(cfg.pretrained_tts_model, RVQEARTTSConfig(**cfg.tts_config), strict=False) + else: + # start the model from scratch + self.tts_model = RVQEARTTSModel(RVQEARTTSConfig(**cfg.tts_config)) + + setup_audio_codec(self) + + def _load_language_model(self, cfg): + """Load language model for RVQ-EAR-TTS.""" + if cfg.pretrained_lm_name: + language_model = load_pretrained_hf(self.cfg.pretrained_lm_name, pretrained_weights=True, trust_remote_code=True).eval() + else: + language_model = None + return language_model + + def setup_speaker_encoder(self): + with fp32_precision(): + self.speaker_encoder = EncDecSpeakerLabelModel.from_pretrained(model_name=self.speaker_encoder_model_name) + + # freeze the pretrained speaker encoder + self.speaker_encoder.eval() + 
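+        # eval(), freeze(), and the requires_grad loop below all keep the speaker encoder fixed during training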
self.speaker_encoder.freeze() + + for p in self.speaker_encoder.parameters(): + p.requires_grad = False + + def init_model_from_another_checkpoint(self, checkpoint_path): + if checkpoint_path is not None: + if '.nemo' in checkpoint_path: + with tempfile.TemporaryDirectory() as tmpdir: + NLPSaveRestoreConnector._unpack_nemo_file(checkpoint_path, tmpdir) + checkpoint_path = f"{tmpdir}/model_weights.ckpt" + checkpoint_state = torch.load(checkpoint_path, map_location='cpu') + else: + checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location='cpu')['state_dict'] + + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict()) + + if self.cfg.get("rescale_pretrained_weights", None): + checkpoint_state = rescale_state_dict(checkpoint_state, first_n_layers=self.cfg.get("rescale_first_n_layers", None)) + + self.load_state_dict(checkpoint_state, strict=True) + + @property + def device(self): + return next(self.parameters()).device + + @property + def speech_vocab_size(self): + """Return the size of the audio codec codebook including extra speech BOS and EOS tokens.""" + if self.use_local_transformer and self.local_transformer_type == "nar": # add extra token for mask + return self._codebook_size + 4 + return self._codebook_size + 3 + + @property + def speech_bos_id(self) -> int: + """Indicates start of utterance generation (not start of inference!).""" + if self.cfg.get("custom_speech_bos_id", None): + return self.cfg.get("custom_speech_bos_id") + return self._codebook_size + 2 + + @property + def speech_eos_id(self) -> int: + """Indicates end of utterance generation.""" + if self.cfg.get("custom_speech_eos_id", None): + return self.cfg.get("custom_speech_eos_id") + return self._codebook_size + 1 + + @property + def speech_pad_id(self) -> int: + """Indicates start of inference (the very first frame).""" + if self.cfg.get("custom_speech_pad_id", None): + return self.cfg.get("custom_speech_pad_id") + return self._codebook_size + + @property + def text_vocab_size(self): + """Return the size of the text tokenizer.""" + return self.tokenizer.vocab_size + + @property + def text_bos_id(self) -> int: + return self.tokenizer.bos_id + + @property + def text_zstts_task_id(self) -> int: + return self.tokenizer.text_to_ids("<|box_start|>") # uses <|box_start|> special token as zstts task id token + + @property + def text_cont_task_id(self) -> int: + return self.tokenizer.text_to_ids("<|object_ref_start|>") # uses <|object_ref_start|> special token as cont task id token + + @property + def text_eos_id(self) -> int: + return self.tokenizer.eos_id + + @property + def text_pad_id(self) -> int: + """ + Text pad ID is used as a 'blank' for frames when the model is not speaking + and for frames where the model is speaking but has already predicted the + entire text channel's content. + + Example: + + flow: |---user---||-------assistant--------||-user-| + text channel: 0000000000 1xxxxxxx0000000000000002 000000 + + Where 0 indicates PAD ID, 1 indicates BOS ID, 2 indacates EOS ID, + and x indicates tokens corresponding to actual text + + """ + return get_pad_id(self.tokenizer) + + def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, downsampling_factor: int = 1): + """ + Zero pad the end of the audio so that we do not have a partial end frame. + The output will be zero-padded to have an integer number of frames of + length `samples_per_frame * downsampling_factor`. 
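+        For example (illustrative numbers), with `samples_per_frame=1920` and `downsampling_factor=1`,
+        an 8000-sample input is padded up to 9600 samples, i.e. 5 full frames.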
+ + Args: + audio: input time-domain signal (B, T) + audio_len: valid length for each example in the batch (B,) + samples_per_frame: number of samples per frame + downsampling_factor: how much each frame is downsampled in later processing + + Returns: + padded_audio: Padded time-domain signal (B, T') + padded_len: Adjusted valid lengths (B,) + """ + with fp32_precision(): + total_factor = samples_per_frame * downsampling_factor + padded_len = total_factor * torch.ceil(audio_len / total_factor).int() + max_len = padded_len.max().int().item() + num_padding = max_len - audio.shape[1] + padded_audio = F.pad(audio, (0, num_padding)) + return padded_audio, padded_len + + def prepare_inputs(self, batch: dict): + """ + """ + """ + import hashlib + import torch + + def hash_texts(text_list): + hashes = [] + for t in text_list: + norm = t.strip().lower() + h = hashlib.sha256(norm.encode("utf-8")).hexdigest() + hashes.append(h) + return hashes + + # --- Safe batch filtering function --- + def filter_batch_by_indices(batch, keep_indices): + if not keep_indices: + # No common samples: empty all fields + new_batch = {} + for k, v in batch.items(): + if isinstance(v, list): + new_batch[k] = [] + elif hasattr(v, "__getitem__") and not isinstance(v, str): + try: + new_batch[k] = v[0:0] # empty tensor + except Exception: + new_batch[k] = v + else: + new_batch[k] = v + return new_batch + + new_batch = {} + for k, v in batch.items(): + try: + if isinstance(v, list): + new_batch[k] = [v[i] for i in keep_indices if i < len(v)] + elif hasattr(v, "__getitem__") and not isinstance(v, str): + slices = [i for i in keep_indices if i < v.shape[0]] + if slices: + new_batch[k] = v[slices] + else: + new_batch[k] = v[0:0] # empty tensor + else: + new_batch[k] = v # keep metadata as-is + except Exception: + new_batch[k] = v # fallback if indexing fails + return new_batch + + # --- Compute sample IDs --- + target_texts = batch["target_texts"] + batch["sample_id"] = target_texts # using text itself as unique ID + print("Sample ids:", batch["sample_id"]) + + if self.training: + # --- Track sample IDs and store full batch --- + if not hasattr(self, "train_sample_ids"): + self.train_sample_ids = set(batch["sample_id"]) + self.train_batches_by_hash = dict() + else: + self.train_sample_ids.update(batch["sample_id"]) + + # Save the full batch per sample + for i, sid in enumerate(batch["sample_id"]): + self.train_batches_by_hash[sid] = { + k: (v[i] if isinstance(v, list) else v[i:i+1]) + for k, v in batch.items() + } + + else: + # --- Validation: keep only common samples --- + if not hasattr(self, "eval_common_ids"): + self.eval_common_ids = set() + + # Only consider validation samples that exist in training + keep_indices = [i for i, sid in enumerate(batch["sample_id"]) + if sid in self.train_batches_by_hash] + + # Safe filtering + batch = filter_batch_by_indices(batch, keep_indices) + + if keep_indices: + print(f"Keeping only {len(keep_indices)} common samples from validation!") + # Update eval_common_ids + self.eval_common_ids.update(batch["sample_id"]) + print( + f"total_common={len(self.eval_common_ids)}, " + f"train_total={len(self.train_sample_ids)}" + ) + + # --- Compare the first common sample --- + first_sid = batch["sample_id"][0] + train_sample = self.train_batches_by_hash[first_sid] + val_sample = {k: (v[0] if isinstance(v, list) else v[0:1]) + for k, v in batch.items()} + + # --- Slice tensors to minimal overlapping shape --- + for k in val_sample.keys(): + t_val = val_sample[k] + t_train = train_sample.get(k, t_val) 
+ if isinstance(t_val, torch.Tensor) and isinstance(t_train, torch.Tensor): + min_shape = tuple(min(s1, s2) for s1, s2 in zip(t_val.shape, t_train.shape)) + if all(s > 0 for s in min_shape): + slices = tuple(slice(0, s) for s in min_shape) + val_sample[k] = t_val[slices] + train_sample[k] = t_train[slices] + + print(f"Comparing first common sample (sid={first_sid})") + compare_dicts(train_sample, val_sample) + exit() + else: + print("No common samples found in this validation batch!") + + """ + # check if audios has the same batch size + assert batch["source_audio"].size(0) == batch["target_audio"].size(0) + assert batch["speaker_reference_audio"].size(0) == batch["target_audio"].size(0) + + target_audio = batch["target_audio"] + target_audio_lens = batch["target_audio_lens"] + input_text_tokens = batch["input_text_tokens"] + audio_mask = batch["audio_mask"] + desc_mask = batch["desc_mask"] + non_prompt_mask = batch["non_prompt_mask"] + aligned_attention_mask = batch["aligned_attention_mask"] + aligned_position_ids = batch["aligned_position_ids"] + + # extract target audio codes + with fp32_precision(), torch.no_grad(): + target_audio, target_audio_lens = self.pad_audio_to_factor(target_audio, target_audio_lens, self.target_samples_per_frame, 1) + target_codes, target_codes_lens = self.audio_codec.encode( + target_audio.unsqueeze(1), target_audio_lens + ) + + # ToDo: consider use the source audio + """ + # resample source audio if needed + if self.source_sample_rate != self.target_sample_rate: + source_audio = resample(source_audio, self.source_sample_rate, self.target_sample_rate) + with fp32_precision(): + source_audio_lens = (source_audio_lens * (self.target_sample_rate/self.source_sample_rate)).to(lengths.dtype) + # ToDo: Add a transformer encoder to help the model to better extract contextual information, replace the code bellow with it + # extract embedding for context audios + with fp32_precision(), torch.no_grad(): + source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1) + source_codes, source_codes_lens = self.audio_codec.encode( + source_audio.unsqueeze(1), source_audio_lens + ) + source_codes = source_codes.transpose(1, 2) # (B, K, T) -> (B, T, K) + """ + + with fp32_precision(): + target_len = target_codes.shape[1] + + # Pad or truncate sequence variables + def pad_or_truncate(x, pad_value=0): + if x.dim() == 2: # [B, T] + L = x.shape[1] + if L < target_len: + return F.pad(x, (0, target_len - L), value=pad_value) + else: + return x[:, :target_len] + return x # leave others for now + + input_text_tokens = pad_or_truncate(input_text_tokens, pad_value=self.text_pad_id) + audio_mask = pad_or_truncate(audio_mask, pad_value=0) + desc_mask = pad_or_truncate(desc_mask, pad_value=0) + non_prompt_mask = pad_or_truncate(non_prompt_mask, pad_value=0) + aligned_position_ids = pad_or_truncate(aligned_position_ids, pad_value=0) + + # Correct attention mask padding/truncation + B, H, L1, L2 = aligned_attention_mask.shape + new_len = target_len + if L1 < new_len or L2 < new_len: + pad_rows = new_len - L1 + pad_cols = new_len - L2 + aligned_attention_mask = F.pad(aligned_attention_mask, (0, pad_cols, 0, pad_rows)) + elif L1 > new_len or L2 > new_len: + aligned_attention_mask = aligned_attention_mask[:, :, :new_len, :new_len] + + if self.cfg.get("disable_speech_pad", False): + target_codes_aligned = target_codes + else: + # ToDo: desc_mask is one for the end of the sequence, this is what cause the artifact issue in the end, 
fix it.
+            # set the pad token wherever there is a description, as in https://gitlab-master.nvidia.com/jaehyeonk/easy-ar-tts/-/blame/simple-bq/scripts/train_tts_with_rvqvae.py#L69
+            target_codes_aligned = torch.where(
+                desc_mask.unsqueeze(-1),  # (B, T, 1) for broadcasting
+                torch.full_like(target_codes, self.speech_pad_id),  # fill with pad id
+                target_codes
+            )
+
+        if self.cfg.get("ignore_audio_prompt_on_loss", False):
+            # use non_prompt_mask as audio_mask to exclude the audio prompt from the loss computation
+            audio_mask = non_prompt_mask
+
+        if self.cfg.get("add_pad_speech_token_in_last_prompt_frame", False) and not self.cfg.get("disable_speech_pad", False):
+            # set a special token in the last audio prompt frame (it will work as a BOS token)
+            pos = non_prompt_mask.float().argmax(dim=1)  # shape: [B]
+            row_idx = torch.arange(B, device=self.device)
+            # set the extra self.speech_pad_id at the first 1 position in non_prompt_mask
+            target_codes_aligned[row_idx, pos] = self.speech_pad_id
+
+        B, T = input_text_tokens.shape
+
+        # shift text tokens
+        subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1])
+        # note: the text mask ignores the description + audio prompt frames but stays 1 until the audio ends, to support duplex
+        subword_mask = F.pad(non_prompt_mask[:, 1:], [0, 1])
+
+        # ToDo: implement context from the llm
+        # detach embedding as in eartts
+        if self.cfg.tts_config.context_hidden_size is not None:
+            context_hidden_state = self.embed_tokens(input_text_tokens).detach()
+            if self.cfg.tts_config.get("use_subword_flag_emb", False):
+                context_hidden_state = self.subword_flag_emb(context_hidden_state, input_text_tokens)
+        else:
+            context_hidden_state = None
+
+        if self._use_tp:
+            tp_world_size = self.device_mesh["tensor_parallel"].size()
+            if (remainder := (input_text_tokens.shape[1] - 1) % tp_world_size) != 0:
+                input_text_tokens = input_text_tokens[:, :-remainder]
+                target_codes_aligned = target_codes_aligned[:, :-remainder]
+                audio_mask = audio_mask[:, :-remainder]
+                desc_mask = desc_mask[:, :-remainder]
+                subword_ids = subword_ids[:, :-remainder]
+                subword_mask = subword_mask[:, :-remainder]
+
+        return {
+            "code": target_codes_aligned,
+            "audio_mask": audio_mask,
+            "attention_mask": aligned_attention_mask,
+            "position_ids": aligned_position_ids,
+            "subword_ids": subword_ids,
+            "subword_mask": subword_mask,
+            "context_hidden_state": context_hidden_state,
+            "output_lens": target_codes_lens,
+            "non_prompt_mask": non_prompt_mask,
+            "input_text_tokens": input_text_tokens
+        }
+
+    def training_step(self, batch: dict, batch_idx: int):
+        for m in (self.tts_model, ):
+            if is_frozen(m):
+                m.eval()
+
+        inputs = self.prepare_inputs(batch)
+
+        tts_output = self.tts_model(
+            code=inputs["code"],
+            audio_mask=inputs["audio_mask"],
+            attention_mask=inputs["attention_mask"],
+            position_ids=inputs["position_ids"],
+            context_hidden_state=inputs["context_hidden_state"],
+            subword_ids=inputs["subword_ids"],
+            subword_mask=inputs["subword_mask"],
+            non_prompt_mask=inputs["non_prompt_mask"]
+        )
+        loss_dict = {"lm_loss": tts_output.lm_loss, "c_loss": tts_output.c_loss, "k_loss": tts_output.k_loss}
+        backbone_out = tts_output.hidden_states
+        loss = sum(loss_dict.values())
+
+        num_frames = inputs["output_lens"].sum()
+        B, T = inputs["code"].shape[:2]
+        ans = {
+            "loss": loss,
+            "learning_rate": (
+                torch.as_tensor(self.trainer.optimizers[0].param_groups[0]['lr'] if self._trainer is not None else 0)
+            ),
+            "batch_size": B,
+            "sequence_length": T,
+            "num_frames": 
num_frames.to(torch.float32), # avoid warning + "padding_ratio": num_frames / (B * T), + **loss_dict, + } + + self.log_dict(ans, on_step=True) + return ans + + def on_train_epoch_start(self) -> None: + setup_audio_codec(self) # potentially reloads the audio codec to make sure it's in fp32 + + def on_train_epoch_end(self) -> None: + # log model stats to debug gradient weights issues + self.log_model_stats() + + def log_model_stats(self): + total_w_sq = 0.0 + total_w_params = 0 + max_abs_w = 0.0 + sum_w = 0.0 + + total_g_sq = 0.0 + total_g_params = 0 + + for p in self.parameters(): + if not p.requires_grad: + continue + + # ----- weights ----- + w = p.detach().cpu().float() # ✅ safe offline copy + total_w_sq += (w * w).sum().item() + total_w_params += w.numel() + max_abs_w = max(max_abs_w, w.abs().max().item()) + sum_w += w.sum().item() + + # ----- grads (optional, disabled for speed) ----- + if p.grad is not None: + g = p.grad.detach().cpu().float() + total_g_sq += (g * g).sum().item() + total_g_params += g.numel() + + # L2 norms + weight_l2 = (total_w_sq ** 0.5) if total_w_sq > 0 else 0.0 + grad_l2 = (total_g_sq ** 0.5) if total_g_sq > 0 else 0.0 + + # RMS (global) + weight_rms = ((total_w_sq / total_w_params) ** 0.5) if total_w_params > 0 else 0.0 + grad_rms = ((total_g_sq / total_g_params) ** 0.5) if total_g_params > 0 else 0.0 + + # Mean + weight_mean = sum_w / total_w_params if total_w_params > 0 else 0.0 + + # direct float logging avoids device sync penalty + self.log("weights/L2", weight_l2, on_epoch=True, sync_dist=True) + self.log("weights/RMS", weight_rms, on_epoch=True, sync_dist=True) + self.log("weights/max_abs", max_abs_w, on_epoch=True, sync_dist=True) + self.log("weights/mean", weight_mean, on_epoch=True, sync_dist=True) + + # ignore the grads stats for now + # self.log("grads/L2", grad_l2, on_epoch=True, sync_dist=True) + # self.log("grads/RMS", grad_rms, on_epoch=True, sync_dist=True) + + def on_validation_epoch_start(self) -> None: + setup_audio_codec(self) + self.results_logger = ResultsLogger(self.validation_save_path).reset() + self.asr_bleu = ASRBLEU(self.cfg.scoring_asr).reset() + self.intelligibility = Intelligibility(self.cfg.scoring_asr, reuse_asr_hyps=True).reset() + self.secs = SECS(self.cfg.get("scoring_se", "titanet_large")).reset() + + def on_validation_epoch_end(self, prefix="val") -> None: + asr_bleu = self.asr_bleu.compute() + for k, m in asr_bleu.items(): + self.log(f"{prefix}_{k}", m.to(self.device), on_epoch=True, sync_dist=True) + cer_wer = self.intelligibility.compute() + for k, m in cer_wer.items(): + self.log(f"{prefix}_{k}", m.to(self.device), on_epoch=True, sync_dist=True) + secs = self.secs.compute() + for k, m in secs.items(): + self.log(f"{prefix}_{k}", m.to(self.device), on_epoch=True, sync_dist=True) + + def get_teacher_force_inference_audio(self, batch, guidance_enabled=True): + inputs = self.prepare_inputs(batch) + + tts_output = self.tts_model( + code=inputs["code"], + audio_mask=inputs["audio_mask"], + attention_mask=inputs["attention_mask"], + position_ids=inputs["position_ids"], + context_hidden_state=inputs["context_hidden_state"], + subword_ids=inputs["subword_ids"], + subword_mask=inputs["subword_mask"], + non_prompt_mask=inputs["non_prompt_mask"], + generation_config=self._get_generation_config(guidance_enabled=guidance_enabled), + teacher_forcing_inference=True, + guidance_enabled=guidance_enabled + ) + tf_audio_codes_pred = tts_output["codes"].squeeze(2) + + # decode audio + tf_audio_codes_pred = 
replace_control_speech_codes(tf_audio_codes_pred, self._control_codes, self.codec_silence_tokens) + with fp32_precision(), torch.no_grad(): + audio_pred, audio_len = self.audio_codec.decode( + tf_audio_codes_pred, inputs["output_lens"] + ) + + return audio_pred.squeeze(1), audio_len + + def _get_generation_config(self, guidance_enabled: bool = False): + """Get default generation config for EAR-TTS.""" + return { + "num_iter": 8, + "guidance_scale": self.cfg.get("inference_guidance_scale", 0.5) if guidance_enabled else None, + "top_p_or_k": self.cfg.get("inference_top_p_or_k", 0.8), + "noise_scale": self.cfg.get("inference_noise_scale", 0.8), + "eos_threshold": -3.0, + } + + def offline_inference_with_custom_sentences(self, test_sentences: torch.Tensor, inference_speaker_reference: torch.Tensor, speech_text_ratio: float = 3.5): + B = len(test_sentences) + # load and get speaker reference + speaker_audio, sr = torchaudio.load(inference_speaker_reference) + speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) + speaker_audio = speaker_audio.repeat(B, 1).to(self.device) + # lengths -> [B] + speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) + + # Tokenize sentences + if self.normalize_text: + tokenized = [ + torch.as_tensor([self.tokenizer.bos] + self.tokenizer.text_to_ids(normalize_text_fn(text)), dtype=torch.long, device=self.device) + for text in test_sentences + ] + else: + tokenized = [ + torch.as_tensor([self.tokenizer.bos] + self.tokenizer.text_to_ids(text), dtype=torch.long, device=self.device) + for text in test_sentences + ] + + # Get max length and target length + max_len = max(len(t) for t in tokenized) + # Pad each to double length + target_len = int(speech_text_ratio * max_len) # make text longer to ensures that we have enough steps for speech gen + next_subword_ids = torch.stack([ + torch.cat([ + torch.tensor([self.text_pad_id], dtype=torch.long, device=self.device), # shift right adding one padding token + t, + torch.full((target_len - len(t) - 1,), self.text_pad_id, dtype=torch.long, device=self.device) # remaining padding + ]) + for t in tokenized + ]) + + audio, audio_len = self.offline_inference( + speaker_audio=speaker_audio, + speaker_audio_lens=speaker_audio_lens, + next_subword_ids=next_subword_ids, + guidance_enabled=self.cfg.get("inference_guidance_enabled", True) + ) + return audio, audio_len, speaker_audio, speaker_audio_lens + + def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=False): + results = {} + inputs = self.prepare_inputs(dataset_batch) + + # + # exit() + # first evaluation, make the model bf16 safe + if not self.model_16_precision_safe and self.cfg.get("ensures_16_safe", False) and str(self.trainer_config.precision) != str(32): + self.tts_model, summary = make_tts_model_mixed_precision_definite(self.tts_model, inputs, safety_factor=1.0, mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16) + # self.tts_model, summary = make_tts_model_mixed_precision_safe(self.tts_model, inputs, safety_factor=1.0) + self.model_16_precision_safe = True + + print("Current FP32 layers:", summary["fp32_layers"]) + # compare_tts_model_fp32_bf16_mixed(self.tts_model, inputs) + # exit() + + results["audio_tf"], results["audio_tf_len"] = self.get_teacher_force_inference_audio(dataset_batch) + if use_dataloader_init: + # cut it on prompt + init_inputs = { + "code": inputs["code"], + "audio_mask": inputs["audio_mask"], + "non_prompt_mask": 
inputs["non_prompt_mask"], + "context_hidden_state": inputs["context_hidden_state"], + "subword_ids": inputs["subword_ids"], + "subword_mask": inputs["subword_mask"] + } + # cut init_inputs to consider only the prompt + for key in init_inputs: + if init_inputs[key] is not None: + init_inputs[key] = torch.stack([ + init_inputs[key][i, :l] + for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + ]) + + # drop items without description to avoid issues + """ + lens = dataset_batch["desc_plus_audio_prompt_lens"] # list of lengths + + # Example condition: keep only those with the maximum length + max_len = max(lens) + keep_indices = [i for i, l in enumerate(lens) if l == max_len] + + # Convert indices to tensor for indexing torch tensors + keep_indices = torch.tensor(keep_indices, dtype=torch.long) + + # Now filter every key in dataset_batch + for k, v in dataset_batch.items(): + if isinstance(v, torch.Tensor): + dataset_batch[k] = v[keep_indices] + elif isinstance(v, list): + dataset_batch[k] = [v[i] for i in keep_indices] + + # Do the same for inputs + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + inputs[k] = v[keep_indices] + elif isinstance(v, list): + inputs[k] = [v[i] for i in keep_indices] + """ + + # remove the prompt from the input_text_tokens to emulate S2S connected inference + next_subword_ids = torch.stack([ + inputs["subword_ids"][i, l:] # slice each element + for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + ]) + + if self.cfg.get("use_asr_speech_tokens", False) and self.cfg.get("only_semantic_to_speech", False): + inp_asr_speech_tokens = torch.stack([ + inputs["target_asr_speech_tokens"][i, l:] # slice each element + for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + ]) + else: + inp_asr_speech_tokens = None + + results["audio"], results["audio_len"] = self.offline_inference( + speaker_audio=dataset_batch["speaker_reference_audio"], + speaker_audio_lens=dataset_batch["speaker_reference_audio_lens"], + next_subword_ids=next_subword_ids, + formatter=dataset_batch["formatter"][0], + inp_asr_speech_tokens=inp_asr_speech_tokens, + init_inputs=init_inputs if use_dataloader_init else None, + ) + + # remove prompt padding from the user audio as autoregressive inference does not return the prompt + dataset_batch["source_audio"] = dataset_batch["source_audio"][:, -int(next_subword_ids.size(-1)*self.source_samples_per_frame):] + + # clean prompt from the audio + results["audio_tf"] = results["audio_tf"][:, -int(next_subword_ids.size(-1)*self.target_samples_per_frame):] + # remove prompt from target audio + target_audio_no_prompt = dataset_batch["target_audio"][:, -int(next_subword_ids.size(-1)*self.target_samples_per_frame):] + target_audio_no_prompt_lens = dataset_batch["target_audio_lens"] - (torch.tensor(dataset_batch["desc_plus_audio_prompt_lens"], dtype=torch.long, device=dataset_batch["target_audio_lens"].device) * self.target_samples_per_frame) + # for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]): + # results["audio_tf"][i, :l*self.target_samples_per_frame] = 0.0 + + with fp32_precision(): # resample is fragile to bfloat16 default dtype + metric_audio_pred = results["audio"] + metric_audio_pred_lens = results["audio_len"] + + # resample audio to the asr sampling rate + metric_audio_pred = resample(metric_audio_pred, self.target_sample_rate, 16000) + metric_audio_pred_lens = (metric_audio_pred_lens / self.target_sample_rate * 16000).to(torch.long) + # reshape target audio without prompt + 
target_audio_no_prompt_16khz = resample(target_audio_no_prompt, self.target_sample_rate, 16000) + target_audio_no_prompt_lens_16khz = (target_audio_no_prompt_lens / self.target_sample_rate * 16000).to(torch.long) + if self.cfg.get("use_GT_transcriptions_for_metrics", True): + # use target audio transcription for metrics + target_asr_texts = self.asr_bleu.asr.transcribe( + [audio[:alen] for audio, alen in zip(target_audio_no_prompt_16khz, target_audio_no_prompt_lens_16khz)], + batch_size=target_audio_no_prompt_16khz.shape[0], + verbose=False, + ) + metric_text = [asr_hyp.text for asr_hyp in target_asr_texts] + else: + metric_text = dataset_batch["target_texts"] + + asr_hyps = self.asr_bleu.update( + name=name, + refs=metric_text, + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + ) + + self.intelligibility.update( + name=name, + refs=metric_text, + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + asr_hyps=asr_hyps, + ) + + # add ground truth intelligibility metrics + self.intelligibility.update( + name=name+"_gt", + refs=dataset_batch["target_texts"], + pred_audio=target_audio_no_prompt_16khz, + pred_audio_lens=target_audio_no_prompt_lens_16khz, + asr_hyps=metric_text if self.cfg.get("use_GT_transcriptions_for_metrics", True) else None, # reuse GT transcription + ) + + self.secs.update( + name=name, + target_audio=resample(dataset_batch["target_audio"], self.target_sample_rate, 16000), + target_audio_lens=(dataset_batch["target_audio_lens"] / self.target_sample_rate * 16000).to(torch.long), + pred_audio=resample(results["audio"], self.target_sample_rate, 16000), + pred_audio_lens=(results["audio_len"] / self.target_sample_rate * 16000).to(torch.long), + ) + + eou_labels = generate_multiturn_speaking_mask( + next_subword_ids, bos_token_id=self.text_bos_id, eos_token_id=self.text_eos_id + ) + + self.results_logger.update( + name=name, + refs=dataset_batch["target_texts"], + hyps=metric_text, + asr_hyps=asr_hyps, + samples_id=dataset_batch['sample_id'], + pred_audio=results["audio"].float(), + pred_audio_tf=results["audio_tf"].float(), + pre_audio_trimmed=None, + reference_audio=dataset_batch["speaker_reference_audio"].float(), + target_audio=target_audio_no_prompt.float(), + pred_audio_sr=self.target_sample_rate, + user_audio=dataset_batch["source_audio"].float(), + user_audio_sr=self.source_sample_rate, + eou_pred=eou_labels, + fps=self.target_fps, + results=results if self.cfg.get("dump_tokens_text", False) else None, + tokenizer=self.tokenizer, + ) + + def validation_step(self, batch: dict, batch_idx: int): + if self.cfg.get("test_sentences", None) and self.cfg.get("inference_speaker_reference", None): + for name in self.cfg.test_sentences.keys(): + logging.info(f"Generating {name} custom sentences.") + test_sentences = self.cfg.test_sentences[name] + results = {} + results["audio"], results["audio_len"], speaker_audio, speaker_audio_lens = self.offline_inference_with_custom_sentences(test_sentences, self.cfg.inference_speaker_reference) + with fp32_precision(): # resample is fragile to bfloat16 default dtype + metric_audio_pred = results["audio"] + metric_audio_pred_lens = results["audio_len"] + + # resample audio to the asr sampling rate + metric_audio_pred = resample(metric_audio_pred, self.target_sample_rate, 16000) + metric_audio_pred_lens = (metric_audio_pred_lens / self.target_sample_rate * 16000).to(torch.long) + + asr_hyps = self.asr_bleu.update( + name=name, + refs=test_sentences, + pred_audio=metric_audio_pred, + 
pred_audio_lens=metric_audio_pred_lens,
+                )
+
+                self.intelligibility.update(
+                    name=name,
+                    refs=test_sentences,
+                    pred_audio=metric_audio_pred,
+                    pred_audio_lens=metric_audio_pred_lens,
+                    asr_hyps=asr_hyps,
+                )
+
+                self.secs.update(
+                    name=name,
+                    target_audio=resample(speaker_audio, self.target_sample_rate, 16000),
+                    target_audio_lens=(speaker_audio_lens / self.target_sample_rate * 16000).to(torch.long),
+                    pred_audio=resample(results["audio"], self.target_sample_rate, 16000),
+                    pred_audio_lens=(results["audio_len"] / self.target_sample_rate * 16000).to(torch.long),
+                )
+
+                self.results_logger.update(
+                    name=name,
+                    refs=test_sentences,
+                    hyps=test_sentences,
+                    asr_hyps=asr_hyps,
+                    samples_id=[str(i) for i in range(len(test_sentences))],
+                    pred_audio=results["audio"].float(),
+                    pred_audio_tf=None,
+                    pre_audio_trimmed=None,
+                    reference_audio=speaker_audio.float(),
+                    target_audio=None,
+                    pred_audio_sr=self.target_sample_rate,
+                    user_audio=None,
+                    user_audio_sr=None,
+                    eou_pred=None,
+                    fps=self.target_fps,
+                    results=None,
+                    tokenizer=self.tokenizer,
+                )
+
+        else:
+            for name, dataset_batch in batch.items():
+                if dataset_batch is None:
+                    continue  # some dataset is exhausted
+                # run inference for multiple references
+                if self.cfg.get("inference_speaker_reference_path", None):
+                    B = len(dataset_batch['sample_id'])
+                    for inference_speaker_reference in glob.glob(os.path.join(self.cfg.inference_speaker_reference_path, "**"), recursive=True):
+                        if not os.path.isfile(inference_speaker_reference):
+                            continue
+                        print("Generating sample for speaker reference:", inference_speaker_reference)
+                        new_dataset_batch = copy.deepcopy(dataset_batch)
+                        # Get only the file name
+                        ref_name = os.path.basename(inference_speaker_reference)
+                        # Append to each sample_id
+                        new_dataset_batch['sample_id'] = [
+                            f"{sid}_{ref_name}" for sid in dataset_batch['sample_id']
+                        ]
+                        speaker_audio, sr = torchaudio.load(inference_speaker_reference)
+                        speaker_audio = resample(speaker_audio, sr, self.target_sample_rate)
+                        speaker_audio = speaker_audio.repeat(B, 1).to(self.device)
+                        # lengths -> [B]
+                        speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B)
+                        new_dataset_batch["speaker_reference_audio"] = speaker_audio
+                        new_dataset_batch["speaker_reference_audio_lens"] = speaker_audio_lens
+                        self.run_evaluation_one_batch(name, new_dataset_batch, use_dataloader_init=False)
+                # run inference for a custom speaker reference
+                elif self.cfg.get("inference_speaker_reference", None):
+                    B = len(dataset_batch['sample_id'])
+                    inference_speaker_reference = self.cfg.inference_speaker_reference
+                    new_dataset_batch = copy.deepcopy(dataset_batch)
+                    speaker_audio, sr = torchaudio.load(inference_speaker_reference)
+                    speaker_audio = resample(speaker_audio, sr, self.target_sample_rate)
+                    speaker_audio = speaker_audio.repeat(B, 1).to(self.device)
+                    # lengths -> [B]
+                    speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B)
+                    new_dataset_batch["speaker_reference_audio"] = speaker_audio
+                    new_dataset_batch["speaker_reference_audio_lens"] = speaker_audio_lens
+                    self.run_evaluation_one_batch(name, new_dataset_batch, use_dataloader_init=False)
+                # run inference using dataloader speaker references
+                else:
+                    self.run_evaluation_one_batch(name, dataset_batch, use_dataloader_init=False)
+
+    def on_test_epoch_start(self) -> None:
+        return self.on_validation_epoch_start()
+
+    def on_test_epoch_end(self) -> None:
+        return self.on_validation_epoch_end(prefix="test")
+
+    def test_step(self, *args, **kwargs):
+        return self.validation_step(*args, **kwargs)
+
+    def get_system_prompt(self, 
system_prompt=None, user_prompt=None): + messages = [] + if system_prompt is None: + system_prompt = ( + "You engage in conversation with the user. When delivering your response as speech, " + "if the user provides a description such as emotions, scene details, " + "or speaker style, you adjust your speaking style accordingly when delivering the response. " + "However, this description should influence only the delivery of your response, not its content. " + "Your response should remain independent of any stylistic instructions." + ) + messages.append({"role": "system", "content": system_prompt}) + + # ToDo: implement dataloading support for descriptions + """for desc in example["descriptions"]: + user_prompt = "" + if random.random() > self.p_drop_description and desc: + user_prompt += f"```\n{desc}\n```" + if random.random() > self.p_drop_description: + if user_prompt: + user_prompt += "\n\n" + user_prompt += self.rng.choice(self.user_prompts) + if user_prompt: + messages.append({"role": "user", "content": user_prompt}) + messages.append({"role": "assistant", "content": SCRIPT_PLACEHOLDER}) + """ + + # given that descriptions are currently not supported, only added the user prompt + if user_prompt is None: + user_prompt = "Can you tell me something interesting?" + messages.append({"role": "user", "content": user_prompt}) + messages.append({"role": "assistant", "content": SCRIPT_PLACEHOLDER}) + non_script_list = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ).split(SCRIPT_PLACEHOLDER + self.tokenizer.eos_token)[:-1] + + input_ids = [] + for i, non_script in enumerate(non_script_list): + desc_ids = self.tokenizer.text_to_ids(non_script) + input_ids.extend(desc_ids) + + input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).view(1, -1) + return input_ids + + def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, user_prompt=None): + # compute prompt audio size and slice it + with fp32_precision(): + """ + # old pad that can add long silences in the end + prompt_audio_size = int(((self.data_cfg.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) * self.target_samples_per_frame) + B, T = speaker_audio.shape # [batch, time] + if T >= prompt_audio_size: + # Just crop if longer + prompt_audio = speaker_audio[:, :prompt_audio_size] + else: + # Repeat along time until we have enough, then crop + repeat_factor = (prompt_audio_size + T - 1) // T # ceil division + expanded = speaker_audio.repeat(1, repeat_factor) + prompt_audio = expanded[:, :prompt_audio_size] + """ + # compute the exact number of samples for the prompt duration + prompt_audio_size = int( + ((self.data_cfg.audio_prompt_duration * self.target_sample_rate) + // self.target_samples_per_frame) + * self.target_samples_per_frame + ) + + B, T = speaker_audio.shape + device = speaker_audio.device + dtype = speaker_audio.dtype + + # allocate result + prompt_audio = torch.zeros(B, prompt_audio_size, device=device, dtype=dtype) + + # process each example independently + for b in range(B): + valid_len = min(speaker_audio_lens[b].item(), T) + + # handle empty + if valid_len <= 0: + continue + + # valid (non-padded) segment + valid_segment = speaker_audio[b, :valid_len] + + if valid_len >= prompt_audio_size: + # enough valid audio → crop from start (no silence) + prompt_audio[b] = valid_segment[:prompt_audio_size] + else: + # too short → repeat and crop + repeat_factor = (prompt_audio_size + valid_len - 1) // valid_len 
# ceil division + expanded = valid_segment.repeat(repeat_factor) + prompt_audio[b] = expanded[:prompt_audio_size] + + # add a silence in the end to smooth the transition between prompt and audio tokens + prompt_audio[:, -int(self.target_samples_per_frame * 2):] = 0 + + # get prompt audio size + with fp32_precision(): + prompt_audio_text_pad_size = int(prompt_audio_size // self.target_samples_per_frame) + + # get description tokens + desc_tokens_ids = self.get_system_prompt(system_prompt=system_prompt, user_prompt=user_prompt) + + # create a padding tensor + prompt_audio_text_pad = torch.ones(prompt_audio_text_pad_size, device=self.device, dtype=desc_tokens_ids.dtype) * self.text_pad_id + prompt_audio_text_pad[-1] = self.tokenizer.eos + + # Add eos to simulate the end of a turn as in EAR-TTS inference + desc_tokens_ids = torch.cat([desc_tokens_ids.squeeze(), torch.tensor([self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device)]) + # Add padding equivalent to the audio prompt size in number of tokens + input_text_tokens = torch.cat([desc_tokens_ids.to(desc_tokens_ids.dtype), prompt_audio_text_pad.to(desc_tokens_ids.dtype)]) + + # create pad audio for the description + pad_size = desc_tokens_ids.size(-1) * self.target_samples_per_frame + pad_audio = torch.zeros(pad_size, device=prompt_audio.device, dtype=prompt_audio.dtype).unsqueeze(0).repeat(prompt_audio.size(0), 1) + + # repeat to reaches the batch size + input_text_tokens = input_text_tokens.unsqueeze(0).repeat(prompt_audio.size(0), 1) + target_audio = torch.cat([pad_audio, prompt_audio], dim=1) + + # extract code codes + target_audio_len = torch.tensor([target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device) + with fp32_precision(), torch.no_grad(): + code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len) + + # get context hidden + if self.cfg.tts_config.context_hidden_size is not None: + context_hidden_state = self.embed_tokens(input_text_tokens) + if self.cfg.tts_config.get("use_subword_flag_emb", False): + context_hidden_state = self.subword_flag_emb(context_hidden_state, input_text_tokens) + else: + context_hidden_state = None + + # create masks + # non_prompt_mask is all zeros, because all processed is prompt + non_prompt_mask = torch.zeros_like(input_text_tokens) + non_prompt_mask[:, -2:] = 1 # set last valid prompt frame as 1 to allow the addition of BOS in the right place + subword_mask = torch.zeros_like(input_text_tokens) # subword_mask is almost all zeros because on the warmup there is only the prompt + subword_mask[:, -3:] = 1 # -3 because of the it start right after the first valid prompt token and it is shifted by 1 + # audio mask is all ones except for description + audio_mask = torch.ones_like(input_text_tokens) + audio_mask[:, :desc_tokens_ids.size(-1)] = 0 + # desc mask is all zeros except the description + desc_mask = torch.zeros_like(input_text_tokens) + desc_mask[:, :desc_tokens_ids.size(-1)] = 1 + + + if not self.cfg.get("disable_speech_pad", False): + # add special tokens on audio codes + code = torch.where( + desc_mask.unsqueeze(-1).bool(), # (B, T, 1) for broadcasting + torch.full_like(code, self.speech_pad_id), # fill with pad id + code + ) + + # shift subword_ids + # subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1], value=current_subword_id) + subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1], value=0.0) + + if self.cfg.get("ignore_audio_prompt_on_loss", False): + # set audio_mask as non_prompt_mask to avoid the audio prompt in 
loss computation
+            audio_mask = non_prompt_mask
+
+        if self.cfg.get("add_pad_speech_token_in_last_prompt_frame", False) and not self.cfg.get("disable_speech_pad", False):
+            # set a special token in the last audio prompt frame (it will work as a BOS token)
+            pos = non_prompt_mask.float().argmax(dim=1)  # shape: [B]
+            row_idx = torch.arange(B, device=self.device)
+            # set the extra self.speech_pad_id at the first 1 position in non_prompt_mask
+            code[row_idx, pos] = self.speech_pad_id
+
+        init_inputs = {
+            "code": code[:, :-1],
+            "audio_mask": audio_mask.bool()[:, :-1],
+            "context_hidden_state": context_hidden_state[:, :-1] if context_hidden_state is not None else None,
+            "subword_ids": subword_ids[:, :-1],
+            "subword_mask": subword_mask.bool()[:, :-1],
+            "non_prompt_mask": non_prompt_mask.bool()[:, :-1]
+        }
+
+        return init_inputs
+
+    @torch.no_grad()
+    def offline_inference(
+        self,
+        next_subword_ids: torch.Tensor,
+        speaker_audio: torch.Tensor,
+        speaker_audio_lens: torch.Tensor,
+        formatter: str = "",
+        system_prompt: str = None,
+        user_prompt: str = None,
+        guidance_enabled: bool = True,
+        generation_config: dict = None,
+        init_inputs: dict = None,
+        inp_asr_speech_tokens: torch.Tensor = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Autoregressive TTS inference driven by the text tokens in `next_subword_ids`.
+
+        Args:
+            next_subword_ids: text (subword) tokens of shape (B, T) that drive generation; T bounds the number of decoding steps.
+            speaker_audio: a batch of speaker reference waveforms of shape (B, T_ref) at the target sample rate.
+            speaker_audio_lens: reference lengths as number of samples, of shape (B,).
+            formatter: currently unused.
+            system_prompt: optional system prompt text; a default prompt is used when None.
+            user_prompt: optional user prompt text; a default prompt is used when None.
+            guidance_enabled: whether to apply classifier-free guidance during generation.
+            generation_config: optional sampling configuration; built via `_get_generation_config` when None.
+            init_inputs: optional precomputed warmup inputs; computed via `get_init_inputs` when None.
+            inp_asr_speech_tokens: optional ASR speech tokens, used when `only_semantic_to_speech` is enabled.
+
+        Returns:
+            A tuple `(audio, audio_len)`:
+                * "audio": generated waveform of shape (B, T_samples).
+                * "audio_len": output lengths as number of waveform samples of shape (B,).
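+
+        Example (illustrative sketch; variable names are assumptions):
+            >>> audio, audio_len = model.offline_inference(
+            ...     next_subword_ids=subword_ids,        # (B, T) text tokens to speak
+            ...     speaker_audio=reference_wav,         # (B, T_ref) speaker reference at the target sample rate
+            ...     speaker_audio_lens=reference_wav_lens,
+            ... )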
+ """ + B = next_subword_ids.size(0) + + # init_inputs, code, past_key_values = self.init_model_for_ar_inference(speaker_audio=speaker_audio, speaker_audio_lens=speaker_audio_lens, system_prompt=system_prompt, user_prompt=user_prompt, guidance_enabled=guidance_enabled, generation_config=generation_config) + + # ToDo: verify why codes differ from dataloader init_inputs when using nanocodec + if init_inputs is None: + init_inputs = self.get_init_inputs(speaker_audio, speaker_audio_lens, system_prompt=system_prompt, user_prompt=user_prompt) + # compare_dicts(init_inputs_fn, init_inputs) + + if self.cfg.get("use_asr_speech_tokens", False) and self.cfg.get("only_semantic_to_speech", False): + # set mask to zero and subword ids to self.text_pad_id as in training + init_inputs["subword_mask"] = torch.full_like(init_inputs["subword_mask"], 0.0) + init_inputs["subword_ids"] = torch.full_like(init_inputs["subword_ids"], self.text_pad_id) + next_subword_ids = torch.full_like(next_subword_ids, self.text_pad_id) + + if generation_config is None: + generation_config = self._get_generation_config(guidance_enabled) + logging.info(f"Doing inference using the following config: {generation_config} !") + + init_inputs.update({"use_cache": True, "past_key_values": None, "guidance_enabled": guidance_enabled}) + + # warmup the model and generate the very first audio token + outputs = self.tts_model(**init_inputs) + + if self.cfg.get("inference_skip_first_code_prediction_on_init", True): + # use the last token on init, because we are shifthing it in the model forward, so we dont really need to compute it + code = init_inputs["code"][:, -1:] + else: + code, _, _ = self.tts_model.generate_step(outputs.hidden_states[:, -1:], **generation_config) + + past_key_values = outputs["past_key_values"] + + # get current asr speech token + if self.cfg.get("use_asr_speech_tokens", False): + if self.cfg.get("only_semantic_to_speech", False): + cur_asr_speech_tokens = inp_asr_speech_tokens[:, 0].unsqueeze(-1) + else: + if guidance_enabled and self.cfg.get("asr_speech_tokens_use_guidance", True): + hidden_states, uncond_hidden_states = outputs.hidden_states.chunk(2, dim=0) + logits = self.asr_speech_tokens_head(hidden_states + (generation_config["guidance_scale"] * (hidden_states - uncond_hidden_states))) + else: + hidden_states, _ = outputs.hidden_states.chunk(2, dim=0) + logits = self.asr_speech_tokens_head(hidden_states) + + cur_asr_speech_tokens = logits.argmax(dim=-1)[:, -1].unsqueeze(-1) + + # use the text tokens to stop generation + max_steps = next_subword_ids.size(-1) + # create variable to store the audios + gen_audio_codes = torch.zeros(B, max_steps, self.tts_model.config.num_quantizers, device=self.device, dtype=torch.long) + + # init subwork as all ones + subword_mask = torch.ones(B, max_steps, device=self.device, dtype=torch.bool) + # get first context subword_id, that is the last subword_ids from the warmup + first_context_subword_id = init_inputs["subword_ids"][:, -1].unsqueeze(-1) + + # reset cache of cumulative_word_emb + if self.cfg.tts_config.get("use_cumulative_word_emb", False): + self.tts_model.embed_subword.cumulative_word_emb.reset(B) + + for i in range(max_steps-1): + step_start = time.time() + # current subword id is always seem + current_subword_id = next_subword_ids[:, i].unsqueeze(-1) + + if self.cfg.tts_config.context_hidden_size is not None: + # get context_hidden_state it is always one step behind current_subword_id + # for the first step uses the last step from warmup + if i == 0: + 
context_subword_id = first_context_subword_id + else: + context_subword_id = next_subword_ids[:, i-1].unsqueeze(-1) + + context_hidden_state = self.embed_tokens(context_subword_id) + if self.cfg.tts_config.get("use_subword_flag_emb", False): + context_hidden_state = self.subword_flag_emb(context_hidden_state, context_subword_id) + else: + context_hidden_state = None + + # create subword_mask + current_subword_mask = subword_mask[:, i].unsqueeze(-1) + + # get subword_ids + inputs = { + "code": code, + "context_hidden_state": context_hidden_state, + "subword_ids": current_subword_id, + "subword_mask": current_subword_mask, + "past_key_values": past_key_values, + "use_cache": True, + "guidance_enabled": guidance_enabled, + "generation_config": generation_config, + "ignore_eos_flag_stop": True + } + + outputs = self.tts_model(**inputs) + + code = outputs["codes"] + past_key_values = outputs["past_key_values"] + # ToDo: check why it is -1 + gen_audio_codes[:, i-1] = code.squeeze(1) + + if self.cfg.get("use_asr_speech_tokens", False) and not self.cfg.get("only_semantic_to_speech", False): + if guidance_enabled and self.cfg.get("asr_speech_tokens_use_guidance", True): + hidden_states, uncond_hidden_states = outputs.hidden_states.chunk(2, dim=0) + logits = self.asr_speech_tokens_head(hidden_states + (generation_config["guidance_scale"] * (hidden_states - uncond_hidden_states))) + else: + hidden_states, _ = outputs.hidden_states.chunk(2, dim=0) + logits = self.asr_speech_tokens_head(hidden_states) + + cur_asr_speech_tokens = logits.argmax(dim=-1)[:, -1].unsqueeze(-1) + + # force silence as next token + if self.cfg.get('inference_force_speech_silence_on_eos', None): + silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(code.shape) + code = torch.where( + current_subword_id.unsqueeze(-1) == self.text_eos_id, + silence_codes, # silence + code, # keep original + ) + + step_time = time.time()-step_start + logging.info(f"Autoregressive inference step: {i} of {max_steps} take around {step_time}s") + + + gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * gen_audio_codes.shape[0]).to(self.device) + # decode audio. Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety + gen_audio_codes = replace_control_speech_codes(gen_audio_codes, self._control_codes, self.codec_silence_tokens) + with fp32_precision(), torch.no_grad(): + audio_pred, audio_len = self.audio_codec.decode( + gen_audio_codes, gen_audio_codes_lens + ) + + return audio_pred.squeeze(1), audio_len + + + def backward(self, *args, **kwargs): + with loss_parallel(): + super().backward(*args, **kwargs) + + def configure_optimizers(self): + return configure_optimizers(self) + + @property + def oomptimizer_schema(self) -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. 
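+
+        The returned schema is a dict with a "cls" entry (the batch container type) and an
+        "inputs" list describing each batch field: its name, its NeuralType, and whether its
+        sequence length is tied to the input or the output stream.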
+ """ + return { + "cls": dict, + "inputs": [ + {"name": "source_audio", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + {"name": "source_audio_lens", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, + {"name": "target_audio", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + {"name": "target_audio_lens", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, + { + "name": "input_text_tokens", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + ], + } + + def configure_model(self) -> None: + # TODO(pzelasko): refactor into separate module re-usable across models + device_mesh = self.device_mesh + if device_mesh is None: + return + + llm = self.tts_model.backbone + if isinstance(llm, PeftModel): + llm = llm.base_model.model + + if (tp_mesh := device_mesh["tensor_parallel"]).size() > 1: + self._use_tp = True + + plan = { + "layers.0": PrepareModuleInput( + input_layouts=(Replicate(),), # , None) + desired_input_layouts=(Shard(1),), # , None) + use_local_output=True, + ), + "norm": SequenceParallel(), + } + parallelize_module(llm, tp_mesh, plan) + + for transformer_block in llm.layers: + plan = { + "input_layernorm": SequenceParallel(), + "self_attn.q_proj": ColwiseParallel(), + "self_attn.k_proj": ColwiseParallel(), + "self_attn.v_proj": ColwiseParallel(), + "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), + "post_attention_layernorm": SequenceParallel(), + "mlp": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "mlp.gate_proj": ColwiseParallel(), + "mlp.up_proj": ColwiseParallel(), + "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), + # "pre_feedforward_layernorm": SequenceParallel(), + # "post_feedforward_layernorm": SequenceParallel(), + } + + # Adjust attention module to use the local number of heads + attn_layer = transformer_block.self_attn + for attr in ("num_heads", "num_key_value_heads", "hidden_size"): + val = getattr(attn_layer, attr) + if val % tp_mesh.size() != 0: + logging.warning( + f"attn_layer.{attr}={val} is not divisible by {tp_mesh.size()=}: " + f"set a different tensor parallelism size to avoid errors." 
+ ) + setattr(attn_layer, attr, val // tp_mesh.size()) + + parallelize_module(transformer_block, tp_mesh, plan) + + for m in (self.tts_model.mog_head, self.tts_model.embed_subword, self.tts_model.embed_context, self.tts_model.embed_code, self.tts_model.null_emb, self.tts_model.bos_emb, self.tts_model.lm_head): + parallelize_module( + m, + tp_mesh, + ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1), + use_local_output=False, + ), + ) + + if (dp_mesh := device_mesh["data_parallel"]).size() > 1: + assert dp_mesh.ndim == 1 + self._use_fsdp = True + + fsdp_config = {"mesh": dp_mesh} + + for idx, layer in enumerate(llm.layers): + llm.layers[idx] = fully_shard(layer, **fsdp_config) + + for idx in range(self.tts_model._num_codebooks): + self.tts_model.audio_embeddings[idx] = fully_shard(self.tts_model.audio_embeddings[idx], **fsdp_config) + + if self.tts_model.use_local_transformer: + self.tts_model.local_transformer = fully_shard(self.tts_model.local_transformer, **fsdp_config) + self.tts_model.local_transformer_in_projection = fully_shard(self.tts_model.local_transformer_in_projection, **fsdp_config) + else: + self.embed_text_tokens = fully_shard(self.embed_text_tokens, **fsdp_config) + # self.tts_model = fully_shard(self.tts_model, **fsdp_config) + self.tts_model.mog_head = fully_shard(self.tts_model.mog_head, **fsdp_config) + self.tts_model.embed_subword = fully_shard(self.tts_model.embed_subword, **fsdp_config) + self.tts_model.embed_context = fully_shard(self.tts_model.embed_context, **fsdp_config) + self.tts_model.embed_code = fully_shard(self.tts_model.embed_code, **fsdp_config) + self.tts_model.null_emb = fully_shard(self.tts_model.null_emb, **fsdp_config) + self.tts_model.bos_emb = fully_shard(self.tts_model.bos_emb, **fsdp_config) + self.tts_model.lm_head = fully_shard(self.tts_model.lm_head, **fsdp_config) + + def load_state_dict(self, state_dict, strict: bool = True): + try: + return super().load_state_dict(state_dict, strict=strict) + except RuntimeError as e: + logging.info(f"Error loading model state_dict !! 
Retrying with partial initialization!") + model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) + return super().load_state_dict(model_dict, strict=False) diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py new file mode 100644 index 000000000000..03cdc9da298a --- /dev/null +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -0,0 +1,484 @@ +# Standard library +import glob +import json +import os +import re +import shutil +import subprocess +import sys +import importlib.machinery +from collections.abc import Mapping, MutableMapping +from typing import Any +import argparse + +import torch +from torch import nn +from safetensors import safe_open + +from nemo.utils import logging + +# ============================================================================== +# Contants +# ============================================================================== +PYTHON_CONFIG_GETTER_NAME = "get_config" +CHECKPOINT_FORMAT = "checkpoint_{}/ema.safetensors" +CONFIG_NAME = "config.json" +GIT_HASH_NAME = "githash" +SCRIPT_PLACEHOLDER = "[[[<<>>]]]" + + +# ============================================================================== +# Configuration Class and Utilities +# ============================================================================== + + +class Config(MutableMapping): + """ + A dictionary-like configuration class that uses attributes for storage + and supports both attribute and item-style access. + + This class inherits from `collections.abc.MutableMapping` and stores all + key-value pairs as instance attributes in its internal `__dict__`. + + Nested dictionaries are recursively converted into Config objects upon being set. + """ + + def __init__(self, **kwargs): + """ + Initializes the Config object from keyword arguments. + """ + # __setattr__ will handle the recursive conversion for each item + for key, value in kwargs.items(): + setattr(self, key, value) + + def to_dict(self): + """ + Recursively converts the Config object back into a standard dictionary. + + Returns: + dict: A standard dictionary representation of the configuration. + """ + result = {} + for key, value in self.items(): + if isinstance(value, Config): + # If the value is a Config object, recursively call to_dict() + result[key] = value.to_dict() + else: + result[key] = value + return result + + def to_json(self, indent=2): + """ + Serializes the configuration object to a formatted JSON string. + + Args: + indent (int, optional): The indentation level for the JSON output. + Defaults to 2. + + Returns: + str: The configuration as a JSON-formatted string. + """ + # Leverage the to_dict() method for clean serialization + return json.dumps(self.to_dict(), indent=indent) + + # --- Core MutableMapping Methods --- + + def __setattr__(self, key, value): + """ + Sets an attribute. Recursively converts dicts to Config objects. + This is the primary method for adding/modifying data. + """ + if isinstance(value, Mapping): + value = Config(**value) + # Use object's __setattr__ to avoid infinite recursion + object.__setattr__(self, key, value) + + def __setitem__(self, key, value): + """Allows setting items using dictionary syntax (e.g., `config['key'] = value`).""" + setattr(self, key, value) + + def __getattr__(self, key): + """Allows accessing items as attributes (e.g., `config.key`).""" + # This method is only called for attributes that don't already exist. 
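+        # Raising AttributeError (rather than KeyError) here keeps hasattr()/getattr(..., default)
+        # working as expected; __getitem__ below converts it to KeyError for dict-style access.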
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + + def __getitem__(self, key): + """Allows accessing items using dictionary syntax (e.g., `config['key']`).""" + try: + return getattr(self, key) + except AttributeError as e: + # Convert AttributeError to KeyError for dict-like behavior + raise KeyError(key) from e + + def __delitem__(self, key): + """Allows deleting items using dictionary syntax (e.g., `del config['key']`).""" + try: + delattr(self, key) + except AttributeError as e: + # Convert AttributeError to KeyError for dict-like behavior + raise KeyError(key) from e + + def __iter__(self): + """Returns an iterator over the keys (attributes) of the object.""" + return iter(self.__dict__) + + def __len__(self): + """Returns the number of items (attributes) in the object.""" + return len(self.__dict__) + + # --- Utility Methods --- + + def __repr__(self): + """Returns an informative string representation of the Config object.""" + return f"{self.__class__.__name__}({self.__dict__!r})" + + def __hash__(self): + """Makes the object hashable if its contents are hashable.""" + return hash(tuple(sorted(self.items()))) + + +def get_config_from_file(config_path: str) -> Config: + """ + Loads a configuration from a JSON or Python file. + + - For JSON files (`*.json`), it parses the file directly. + - For Python files (`*.py`), it imports the file as a module and calls a + `get_config()` function within it. + - It also supports a special syntax `path/to/config.py:config_name` to select + a specific configuration from a Python file that returns a dictionary of configs. + + Args: + config_path (str): The path to the configuration file. + + Returns: + Config: The loaded configuration object. + + Raises: + AssertionError: If the file path is invalid, does not exist, or is not in + the expected format. + """ + + match = re.search(r".+\.((json)|(py)|(py:.+))$", config_path) + assert match, f"Only Python (*.py) or JSON (*.json) files are supported, but got {config_path}." + + py_config_name: str | None = None + if not (config_path.endswith(".py") or config_path.endswith(".json")): + config_path_split = config_path.split(":") + config_path = ":".join(config_path_split[:-1]) + py_config_name = config_path_split[-1] + + assert os.path.isfile(config_path), f"Configuration file not found at: {config_path}" + + if config_path.endswith(".json"): + with open(config_path) as f: + config = json.load(f) + else: + config_module = importlib.machinery.SourceFileLoader("_config", config_path).load_module() + assert hasattr(config_module, PYTHON_CONFIG_GETTER_NAME), ( + f"Python config file must define a `{PYTHON_CONFIG_GETTER_NAME}` function." + ) + config = getattr(config_module, PYTHON_CONFIG_GETTER_NAME)(py_config_name) + assert isinstance(config, Mapping), f"`{PYTHON_CONFIG_GETTER_NAME}` must return a dictionary-like object." + cfg = Config(**config) + return cfg + + +def get_config() -> Config: + """ + Parses command-line arguments to load the main configuration for a training run. + + This function implements a hierarchical configuration loading strategy: + 1. It checks if a `config.json` exists in the specified `--workdir`. If so, it loads it. + 2. If a `--config` argument is also provided, it uses that file to update the + configuration loaded from the work directory. + 3. If no config exists in the work directory, it requires the `--config` argument + to be provided as the base configuration. 
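+
+    A typical invocation (script name and paths are illustrative) looks like:
+        python your_training_script.py --config path/to/config.json --workdir path/to/workdir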
+ + This allows for resuming training from a work directory while also being able to + override specific parameters for a new run. + + Returns: + Config: The final, consolidated configuration object. + """ + parser = argparse.ArgumentParser(description="Load training configuration.") + parser.add_argument("-c", "--config", type=str, default=None, help="Path to a Python or JSON configuration file.") + parser.add_argument( + "-w", "--workdir", type=str, required=True, help="Work directory to save logs and checkpoints." + ) + + args = parser.parse_args() + workdir_path = args.workdir + config_save_path = os.path.join(workdir_path, CONFIG_NAME) + + if os.path.exists(config_save_path): + logging.info(f"Resuming from work directory. Loading configuration from {config_save_path}.") + cfg = get_config_from_file(config_save_path) + if args.config and args.config != config_save_path: + logging.info(f"Updating loaded configuration with parameters from {args.config}.") + override_cfg = get_config_from_file(args.config) + cfg.update(override_cfg) + else: + assert args.config is not None, "A configuration file must be specified via `-c` or `--config` for a new run." + logging.info(f"Starting a new run. Loading configuration from {args.config}.") + cfg = get_config_from_file(args.config) + cfg.workdir_path = workdir_path + + return cfg + + +def get_config_from_dir(workdir_path: str) -> Config: + """ + A simple utility to load the configuration directly from a work directory. + + Args: + workdir_path (str): The path to the work directory containing a `config.json`. + + Returns: + Config: The loaded configuration object. + """ + config_save_path = os.path.join(workdir_path, CONFIG_NAME) + cfg = get_config_from_file(config_save_path) + cfg.workdir_path = workdir_path + return cfg + + + + +# ============================================================================== +# Base Model Classes +# ============================================================================== + +class PreTrainedModel(nn.Module): + config_class = Config + + """ + A base class for models to handle loading from pretrained checkpoints. + + This class provides a common interface for initializing a model and loading + weights from a saved checkpoint, following a pattern similar to libraries + like Hugging Face's Transformers. + + Args: + config (Config | dict[str, Any]): A configuration object containing model hyperparameters. + """ + + def __init__(self, config: Config | dict[str, Any], *args, **kwargs): + super().__init__() + self.config = config if isinstance(config, self.config_class) else self.config_class(**config) + + @classmethod + def from_pretrained( + cls, + pretrained_dir: str, + cfg: Config | dict[str, Any] | None = None, + checkpoint_regex: str = "checkpoint_*/ema.safetensors", + strict: bool = False, + **model_kwargs, + ) -> "PreTrainedModel": + """ + Loads a pretrained model from a directory. + + This method first loads the configuration file from the specified directory, + initializes the model with this configuration, and then loads the weights + from the latest checkpoint file found in that directory. + + Args: + cls (type): The model class to instantiate. + pretrained_dir (str): The directory containing the pretrained model + config and checkpoint files. + cfg (Config | dict[str, Any] | None, optional): An optional config object to override + the loaded config. Defaults to None. + checkpoint_regex (str, optional): A regex pattern to find the checkpoint + file. Defaults to "checkpoint_*/ema.safetensors". 
+ strict (bool, optional): Whether to strictly enforce that the keys in + the checkpoint match the keys of the model. + Defaults to False. + **model_kwargs: Additional keyword arguments to pass to the model's + constructor. + + Returns: + PreTrainedModel: An instance of the model with loaded weights. + """ + pretrained_cfg = get_config_from_dir(pretrained_dir).model + if cfg is not None: + pretrained_cfg.update(cfg) + logging.info(f"The loaded config of the pretrained model is updated to: {pretrained_cfg}") + model = cls( + pretrained_cfg, + **model_kwargs, + ) + model_state_dict = {} + with safe_open(latest_checkpoint_path(pretrained_dir, checkpoint_regex), framework="pt", device="cpu") as f: + for key in f.keys(): + model_state_dict[key] = f.get_tensor(key) + model.load_state_dict(model_state_dict, strict=strict) + return model + + def get_optimizer_param_groups(self, weight_decay: float = 0.0) -> list[dict]: + """ + Separates model parameters into two groups: one with weight decay and one without. + + This is a common practice in training deep learning models, where weight decay + is typically applied to the weights of linear and convolutional layers, but not + to biases or normalization layer parameters. + + Args: + weight_decay (float, optional): The weight decay value to apply to the + first group of parameters. Defaults to 0.0. + + Returns: + list[dict]: A list of two dictionaries, each suitable for an optimizer's + parameter groups. The first group has weight decay, and the + second does not. + """ + + def _get_weight_names(module): + """Recursively finds the names of all 'weight' parameters in conv/linear layers.""" + result = [] + is_weight_layer = isinstance( + module, + ( + nn.Linear + | nn.Conv1d + | nn.Conv2d + | nn.Conv3d + | nn.ConvTranspose1d + | nn.ConvTranspose2d + | nn.ConvTranspose3d + ), + ) + if is_weight_layer: + result.append("weight") + else: + for name, child in module.named_children(): + result += [f"{name}.{n}" for n in _get_weight_names(child)] + return result + + # Separate parameters + params_w_decay, params_wo_decay = [], [] + param_names_w_decay = set(_get_weight_names(self)) + + for n, p in self.named_parameters(): + if p.requires_grad: + if n in param_names_w_decay: + params_w_decay.append(p) + else: + params_wo_decay.append(p) + return [ + {"params": params_w_decay, "weight_decay": weight_decay}, + {"params": params_wo_decay, "weight_decay": 0.0}, + ] + +# ============================================================================== +# IO and Checkpointing Utilities +# ============================================================================== + + +def check_git_hash() -> str | None: + """ + Retrieves the current git commit hash of the repository containing this file. + + This is useful for reproducibility, allowing you to track the exact version + of the code used for a particular experiment. + + Returns: + str | None: The git commit hash as a string if successful, otherwise None. + """ + + try: + # Get the directory where this script is located + source_sub_dir = os.path.dirname(os.path.realpath(__file__)) + # Execute the git command to get the current HEAD commit hash + git_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=source_sub_dir, stderr=subprocess.DEVNULL) + .decode(sys.stdout.encoding) + .strip() + ) + except (subprocess.CalledProcessError, FileNotFoundError): + # Handle cases where git is not installed or the directory is not a git repo + logging.warning( + "Could not retrieve git hash. 
This may be because the code is not in a git repository " + "or git is not installed. Git hash checking will be ignored." + ) + return None + return git_hash + + +def write_git_hash(workdir_path: str) -> None: + """ + Writes the current git hash to a file in a specified directory. + + If a hash file already exists, it compares the current hash with the saved one + and logs a warning if they differ. + + Args: + workdir_path (str): The path to the directory where the git hash file will be saved. + """ + git_hash = check_git_hash() + if git_hash is None: + return + + saved_git_hash_path = os.path.join(workdir_path, GIT_HASH_NAME) + if os.path.exists(saved_git_hash_path): + # If hash file exists, compare it with the current hash + with open(saved_git_hash_path) as f: + saved_git_hash = f.read().strip() + if saved_git_hash != git_hash: + logging.warning(f"Git hash has changed. Saved: {saved_git_hash[:8]}, Current: {git_hash[:8]}") + else: + # If no hash file exists, write the current hash + with open(saved_git_hash_path, "w") as f: + f.write(git_hash) + + +def latest_checkpoint_path(dir_path: str, regex: str | None = None) -> str: + """ + Finds the path of the latest checkpoint file or directory in a directory. + + The latest checkpoint is determined by sorting the filenames alphanumerically + and picking the last one. This assumes a naming convention like `checkpoint_1000.pt`, + `checkpoint_2000.pt`, etc. + + Args: + dir_path (str): The directory to search for checkpoints. + regex (str | None, optional): A glob pattern to match checkpoint files. If None, + a default pattern is used. Defaults to None. + + Returns: + str: The full path to the latest checkpoint file. + + Raises: + AssertionError: If no files matching the regex are found in the directory. + """ + if regex is None: + regex = CHECKPOINT_FORMAT.format("*") + + f_list = glob.glob(os.path.join(dir_path, regex)) + if not f_list: + raise FileNotFoundError(f"No checkpoint files or directories found in {dir_path} matching '{regex}'") + + # Sort files based on the integer values in their names + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + + latest_path = f_list[-1] + logging.info(f"Latest checkpoint '{os.path.relpath(latest_path, start=dir_path)}' found in '{dir_path}'.") + return latest_path + + +def manage_checkpoints(dir_path: str, max_checkpoints: int, regex: str | None = None): + """Keeps the most recent checkpoints and deletes older ones.""" + if regex is None: + regex = CHECKPOINT_FORMAT.format("*") + + checkpoints = glob.glob(os.path.join(dir_path, regex)) + + if len(checkpoints) > max_checkpoints: + # Sort files based on the integer values in their names + checkpoints.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + num_to_delete = len(checkpoints) - max_checkpoints + for old_checkpoint in checkpoints[:num_to_delete]: + logging.info(f"Deleting old checkpoint: {old_checkpoint}") + if os.path.isfile(old_checkpoint): + os.remove(old_checkpoint) + else: + shutil.rmtree(old_checkpoint) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py new file mode 100644 index 000000000000..21928e66ad73 --- /dev/null +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -0,0 +1,1775 @@ +# Standard library +import argparse +import glob +import json +import math +import os +import re +import shutil +import sys +from collections.abc import Mapping, MutableMapping +from dataclasses import dataclass, field, fields +from typing import Any 
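+# `FileLock` is used by the vocabulary-building helpers below (e.g. `build_phoneme_vocabs`);
+# importing it at module level avoids relying on a function-local import.
+from filelock import FileLock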
+import unicodedata +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init + +import torch +from torch import Tensor, nn +from torch.nn import functional as F +import transformers +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForTextEncoding, + AutoTokenizer, + Cache, +) +from transformers.generation.logits_process import ( + TopKLogitsWarper, + TopPLogitsWarper, +) +from safetensors import safe_open + +from nemo.utils import logging +from nemo.collections.speechlm2.modules.ear_tts_commons import ( + Config, + PreTrainedModel +) + +# ============================================================================== +# MLP module and Norm +# ============================================================================== + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = nn.GELU(approximate="tanh") + + def forward(self, x: Tensor) -> Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MLPLayer(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + eps: float = 1e-6, + ): + super().__init__() + self.pre_norm = RMSNorm(hidden_size, eps=eps) + self.mlp = MLP(hidden_size, intermediate_size) + self.post_norm = RMSNorm(hidden_size, eps=eps) + + def forward(self, x: Tensor) -> Tensor: + y = self.pre_norm(x) + y = self.mlp(y) + y = self.post_norm(y) + x = x + y + return x + + +# ============================================================================== +# Triton-accelerated and Fallback Functions +# ============================================================================== +try: + # Attempt to import Triton for optimized GPU kernels + import triton + import triton.language as tl + + @triton.jit + def batch_matmul_kernel( + x_ptr, # Pointer to input tensor x: [batch_size, d_in] + w_ptr, # Pointer to weight tensor w: [num_weights, d_out, d_in] + y_ptr, # Pointer to index tensor y: [batch_size] + result_ptr, # Pointer to output tensor result: [batch_size, d_out] + b, + d_in, + d_out, + n, # Dimensions + BLOCK_SIZE_DIN: tl.constexpr, + BLOCK_SIZE_DOUT: tl.constexpr, + ): + """ + Triton kernel for performing a batched matrix multiplication where each row + of the input `x` is multiplied by a different weight matrix selected from `w` + by an index in `y`. 
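+
+        In PyTorch terms, this computes, per batch element b:
+            result[b] = w[y[b]] @ x[b]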
+ """ + # Get the program IDs for the batch and output dimensions + batch_id = tl.program_id(axis=0) + dout_block_id = tl.program_id(axis=1) + + # Early exit for out-of-bounds batch IDs + if batch_id >= b: + return + + # Load the index for the current batch item + idx = tl.load(y_ptr + batch_id) + + # Compute base offsets for the current batch item + x_offset = x_ptr + batch_id * d_in + w_offset = w_ptr + idx * d_out * d_in + + # Define the block of output dimensions to compute + dout_offsets = dout_block_id * BLOCK_SIZE_DOUT + tl.arange(0, BLOCK_SIZE_DOUT) + dout_mask = dout_offsets < d_out + + # Initialize accumulator for the result block + result_block = tl.zeros([BLOCK_SIZE_DOUT], dtype=tl.float32) + + # Loop over the input dimension in blocks + for din_start in range(0, d_in, BLOCK_SIZE_DIN): + din_offsets = din_start + tl.arange(0, BLOCK_SIZE_DIN) + din_mask = din_offsets < d_in + + # Load a block of the input vector x + x_i = tl.load(x_offset + din_offsets, mask=din_mask, other=0.0) + + # Load a block of the selected weight matrix w + w_i_block = tl.load( + w_offset + dout_offsets[:, None] * d_in + din_offsets[None, :], + mask=(dout_mask[:, None] & din_mask[None, :]), + other=0.0, + ) + + # Compute the partial dot product and accumulate + partial = tl.sum(w_i_block * x_i[None, :], axis=1) + result_block += partial + + # Store the final result block + result_offset = result_ptr + batch_id * d_out + dout_offsets + tl.store(result_offset, result_block, mask=dout_mask) + + def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int = 64): + """Wrapper function to launch the Triton kernel for batch_matmul.""" + assert x.is_contiguous() and w.is_contiguous() and y.is_contiguous() + assert math.log2(BLOCK_SIZE_DIN).is_integer() and math.log2(BLOCK_SIZE_DOUT).is_integer() + + b, d_in = x.shape + n, d_out, _ = w.shape + result = torch.empty(b, d_out, device=x.device, dtype=torch.float32) + + batch_matmul_kernel[lambda meta: (b, triton.cdiv(d_out, meta["BLOCK_SIZE_DOUT"]))]( + x.float(), + w.float(), + y, + result, + b, + d_in, + d_out, + n, + BLOCK_SIZE_DIN=BLOCK_SIZE_DIN, + BLOCK_SIZE_DOUT=BLOCK_SIZE_DOUT, + ) + return result.to(dtype=x.dtype) + + # Set batch_matmul to the optimized Triton version + batch_matmul = batch_matmul_triton + logging.info("Triton is available. Using optimized Triton kernel for batch_matmul.") + +except ImportError: + # Fallback to PyTorch implementation if Triton is not available + def batch_matmul_pytorch(x: Tensor, w: Tensor, y: Tensor, *args, **kwargs) -> Tensor: + """ + Performs a batched matrix multiplication using PyTorch's native functions. + + This function serves as a fallback when Triton is not available. It achieves + the same result by gathering the appropriate weight matrices and using `torch.bmm`. + + Args: + x (Tensor): The input tensor of shape `[batch_size, d_in]`. + w (Tensor): The weight tensor of shape `[num_weights, d_out, d_in]`. + y (Tensor): The index tensor of shape `[batch_size]`. + + Returns: + Tensor: The result of the multiplication, shape `[batch_size, d_out]`. + """ + # w[y] gathers the weight matrices for each item in the batch. + # x.unsqueeze(2) reshapes x to [batch_size, d_in, 1] for bmm. + # The result is squeezed to remove the trailing dimension of size 1. + return torch.bmm(w[y], x.unsqueeze(2)).squeeze(2) + + batch_matmul = batch_matmul_pytorch + logging.info("Triton is not available. 
Using PyTorch fallback for batch_matmul.") + + +# ============================================================================== +# Core Mathematical and Masking Functions +# ============================================================================== + + +def gumbel_like(tensor: Tensor, eps: float = 1e-8) -> Tensor: + """ + Generates a tensor of Gumbel noise with the same shape as the input tensor. + + This is used for the Gumbel-Max trick, a technique to sample from a categorical + distribution in a differentiable way (using a straight-through estimator). + + Args: + tensor (torch.Tensor): The input tensor to match the shape of. + eps (float): A small epsilon value for numerical stability. + + Returns: + torch.Tensor: A tensor containing Gumbel noise. + """ + # Sample from a uniform distribution + u = torch.rand_like(tensor) + # Apply the inverse CDF of the Gumbel distribution + return -torch.log(-torch.log(u + eps) + eps) + + +def sequence_mask(lengths: Tensor, max_length: Tensor | int | None = None) -> Tensor: + """ + Creates a boolean mask from a 1D tensor of sequence lengths. + + This function is useful for masking out padding in sequences. Given a tensor + of lengths, it produces a 2D boolean tensor where `mask[i, j]` is `True` if + `j < lengths[i]` and `False` otherwise. + + Args: + lengths (Long Tensor): A 1D tensor of integer lengths. Shape: `[batch_size]`. + max_length (Long Tensor | int | None, optional): The maximum length of the mask. If None, + it is inferred from the maximum value + in `lengths`. Defaults to None. + + Returns: + Tensor: The boolean mask. Shape: `[batch_size, max_length]`. + """ + if max_length is None: + max_length = lengths.max() + + # Create a range tensor from 0 to max_length - 1 + x = torch.arange(max_length, dtype=lengths.dtype, device=lengths.device) # type: ignore[arg-type] + + # Compare each length with the range tensor to create the mask via broadcasting + return x.unsqueeze(0) < lengths.unsqueeze(1) + + +def get_masking_rate(rate: Tensor, exponent: float = 3.0) -> Tensor: + """ + Converts a desired token keep rate to a masking rate using a power function. + + This function is part of a scheduling strategy for masking, where the effective + masking rate changes non-linearly with the desired keep rate. This function is + its own inverse. + + Args: + rate (Tensor): The desired rate of tokens to keep (0 to 1). + exponent (float, optional): The exponent for the transformation. Defaults to 3.0. + + Returns: + Tensor: The corresponding masking rate. + """ + return (1 - rate.pow(exponent)).pow(1 / exponent) + + +# Alias the function for clarity in the inverse context +get_rate = get_masking_rate + + +def get_mask( + code_mask: Tensor, + num_masking: Tensor, + unmasking: bool = False, + validate: bool = False, +) -> Tensor: + """ + Adjusts a boolean mask by masking or unmasking tokens from the end. + + This function operates on a `code_mask` where `True` values represent valid + tokens and are assumed to be contiguous at the start of the sequence. It + calculates a new mask by decreasing (masking) or increasing (unmasking) + the number of `True` values. + + Args: + code_mask (Tensor): The input boolean mask. Shape: `[..., depth]`. + num_masking (Tensor): The number of tokens to mask or unmask. + Shape matching `code_mask`'s batch dimensions. + unmasking (bool, optional): If `True`, increases the number of valid + tokens (unmasking). Defaults to `False`. + validate (bool, optional): If `True`, asserts that the input `code_mask` + is contiguous. 
This adds a slight overhead and + is mainly for debugging. Defaults to `False`. + + Returns: + Tensor: A new boolean mask with the adjusted length of valid tokens. + Shape is identical to `code_mask`. + """ + depth = code_mask.size(-1) + num_valid = code_mask.sum(dim=-1, dtype=torch.long) + + if validate: + # Reconstruct the expected contiguous mask and assert equality. + expected_mask = sequence_mask(num_valid.view(-1), depth).view_as(code_mask) + assert torch.equal(code_mask, expected_mask), ( + "Input `code_mask` must have contiguous `True` values at the beginning." + ) + + # Calculate the target number of valid tokens. + if not unmasking: + # Masking: reduce the number of valid tokens, ensuring it's not negative. + num_to_keep = (num_valid - num_masking).clamp_min(0) + else: + # Unmasking: increase the number of valid tokens, capped by total depth. + num_to_keep = (num_valid + num_masking).clamp_max(depth) + + # Generate the new mask using the final number of tokens to keep. + return sequence_mask(num_to_keep.view(-1), depth).view_as(code_mask) + + + +@dataclass +class CASConfig(Config): + pretrained_tokenizer_name: str = "meta-llama/Llama-3.1-8B-Instruct" + vocab_dir: str | None = None + + # transformer backbone + backbone_type: str | None = "t5gemma" + backbone_model_class: str | None = None + backbone_config_class: str | None = None + backbone_config: Config | None = None + + +@dataclass +class MoGHeadConfig(Config): + intermediate_size: int = 4608 + num_layers: int = 3 + low_rank: int | None = 64 + num_predictions: int = 1024 + min_log_std: float = -4.0 + eps: float = 1e-6 + + +@dataclass +class RVQEARTTSConfig(Config): + model_type = "rvq_ear_tts" + + # transformer backbone + backbone_type: str | None = "gemma3_text" + backbone_model_class: str | None = None + backbone_config_class: str | None = None + backbone_config: Config | None = None + + # model specific configs + latent_size: int = 512 + codebook_size: int = 1024 + num_quantizers: int = 72 + context_hidden_size: int = 4096 + cas_config: CASConfig | None = field(default_factory=lambda: CASConfig()) + mog_head_config: MoGHeadConfig = field(default_factory=lambda: MoGHeadConfig()) + + # extra parameters used for compatibility with S2S + disable_eos_prediction: bool = False + use_subword_flag_emb: bool = True + use_bos_eos_emb: bool = True + pretrained_text_name: str | None = None + use_gated_fusion_for_text_audio: bool = True + + p_uncond: float = 0.1 + label_smoothing: float = 0.01 + max_training_rate: float = 0.8 + quantizer_dropout: float = 0.5 + random_target_masking: bool = False + exponent: float = 3.0 + + def __post_init__(self): + if self.cas_config is not None: + self.cas_config = CASConfig(**self.cas_config) + self.mog_head_config = MoGHeadConfig(**self.mog_head_config) + + +# ============================================================================== +# Model and Vocabulary Utilities +# ============================================================================== + + +@dataclass +class RVQEARTTSOutput: + """ + Output type for the RVQEARTTSModel, providing a structured way to return model outputs. + This class allows accessing outputs by attribute, key, or index. 
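+
+    For an instance `out`, `out.loss`, `out["loss"]` and `out[0]` all refer to the same
+    field (see `__getitem__` below).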
+ """ + + loss: Tensor | None = None + lm_loss: Tensor | None = None + c_loss: Tensor | None = None + k_loss: Tensor | None = None + + hidden_states: Tensor | None = None + past_key_values: Tensor | None = None + + codes: Tensor | None = None + lm_logits: Tensor | None = None + eos_flag: Tensor | None = None + + def __getitem__(self, item: str | int): + """Allows for accessing attributes by key or index.""" + if isinstance(item, str): + return getattr(self, item) + else: + # Access fields in the order they are defined in the dataclass + return getattr(self, fields(self)[item].name) + + +def find_and_delete_module(parent_module: nn.Module, target_module: nn.Module, parent_name: str) -> str | None: + """ + Recursively searches for a specific module instance and deletes it from its parent. + + This is useful for dynamically modifying a model's architecture, such as replacing + an existing embedding layer with a custom one. + + Args: + parent_module (nn.Module): The module to search within. + target_module (nn.Module): The exact module instance to find and delete. + parent_name (str): The initial name of the parent module for constructing the path. + + Returns: + str | None: The full dotted name of the deleted attribute if found, otherwise None. + """ + # Iterate over all direct children of the parent module + for name, module in parent_module.named_children(): + # Use the 'is' operator to check for object identity, not just value equality + if module is target_module: + # If found, delete the attribute from the parent and return its name + delattr(parent_module, name) + return f"{parent_name}.{name}" + + # If not found, recurse into the child module + found_path = find_and_delete_module(module, target_module, parent_name=f"{parent_name}.{name}") + if found_path: + return found_path + return None + + +def build_vocabs( + pretrained_tokenizer_name: str, vocab_dir: str | None = None +) -> tuple[dict[int, tuple[int, ...]], dict[str, int], int]: + """ + Builds or loads a character-level vocabulary derived from a subword tokenizer. + + This function creates a mapping from each subword in a pretrained tokenizer to a + sequence of character IDs. It follows a modern practice of using a directory + to save and load vocabulary files, making the process more robust and extensible. + + The primary source of truth is the `char_vocab.json` file. If it exists, it's + loaded. Otherwise, it's created from the pretrained tokenizer and saved. + + Args: + pretrained_tokenizer_name (str): The name or path of the pretrained Hugging Face tokenizer. + vocab_dir (str | None, optional): The directory to save or load the character + vocabulary from. Defaults to None. + + Returns: + tuple[dict[int, tuple[int, ...]], dict[str, int], int]: A tuple containing: + - A mapping from subword IDs to tuples of character IDs. + - The character-to-ID vocabulary dictionary. + - The ID for the subword padding token. + """ + tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name, trust_remote_code=True) + + def _build_char_vocab() -> dict[str, int]: + # Find all single-character tokens in the original tokenizer's vocabulary + single_chars = {subword: subword_id for subword, subword_id in tokenizer.vocab.items() if len(subword) == 1} + # Create a new, dense character vocabulary sorted by the original token ID + sorted_chars = sorted(single_chars.keys(), key=lambda k: single_chars[k]) + char_vocab = {char: i for i, char in enumerate(sorted_chars)} + return char_vocab + + # 1. 
Load or build the character vocabulary + if vocab_dir: + from filelock import FileLock + + char_vocab_file = os.path.join(vocab_dir, "char_vocab.json") + + os.makedirs(vocab_dir, exist_ok=True) + + with FileLock(char_vocab_file + ".lock", timeout=60): + if not os.path.exists(char_vocab_file): + char_vocab = _build_char_vocab() + + logging.info(f"Saving character vocabulary to {char_vocab_file}") + with open(char_vocab_file, "w", encoding="utf-8") as f: + json.dump(char_vocab, f, ensure_ascii=False, indent=2) + + # All processes can now safely load the file. + logging.info(f"Loading character vocabulary from {char_vocab_file}") + with open(char_vocab_file, encoding="utf-8") as f: + char_vocab = json.load(f) + else: + # No cache directory provided, build in memory. + logging.info(f"Building character vocabulary from tokenizer '{pretrained_tokenizer_name}'.") + char_vocab = _build_char_vocab() + + # 2. Reconstruct the subword-to-character mapping on the fly + subword_id_to_char_ids = { + subword_id: tuple(char_vocab[char] for char in subword if char in char_vocab) + for subword, subword_id in tokenizer.vocab.items() + } + # Filter out subwords that contain characters not in our character vocabulary + subword_id_to_char_ids = {k: v for k, v in subword_id_to_char_ids.items() if v} + + # 3. Define a padding index for subwords + subword_padding_idx = len(tokenizer.vocab) + # The padding subword maps to a new character padding ID + subword_id_to_char_ids[subword_padding_idx] = (len(char_vocab),) + return subword_id_to_char_ids, char_vocab, subword_padding_idx + +def _split_ipa_symbols(text: str) -> list[str]: + """ + Split IPA text into grapheme clusters (true phoneme symbols) + without using regex. Combines base characters with diacritics. + """ + phonemes = [] + cluster = "" + for char in text: + if unicodedata.combining(char) == 0: + # Start a new cluster + if cluster: + phonemes.append(cluster) + cluster = char + else: + # Diacritic, append to current cluster + cluster += char + if cluster: + phonemes.append(cluster) + return phonemes + +def build_phoneme_vocabs( + pretrained_tokenizer_name: str, + vocab_dir: str | None = None, + language: str = "en-us", +) -> tuple[dict[int, tuple[int, ...]], dict[str, int], int]: + """ + Build or load a phoneme-level vocabulary derived from a subword tokenizer, + using phonemizer with espeak-ng backend and IPA transcription. + + Args: + pretrained_tokenizer_name (str): Hugging Face tokenizer name or path. + vocab_dir (str | None, optional): Directory for saving/loading vocab. + language (str, optional): Language code for phonemizer (default: "en-us"). 
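+
+    Note:
+        Requires the `phonemizer` package with an espeak-ng backend available. If phonemization
+        fails, the error is logged and an empty subword-to-phoneme mapping is used.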
+ + Returns: + tuple: + - subword_id_to_phoneme_ids: dict[int, tuple[int, ...]] + - phoneme_vocab: dict[str, int] + - subword_padding_idx: int + """ + from phonemizer import phonemize + tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name) + + def _phonemize_all_subwords() -> dict[str, list[str]]: + """Phonemize all subwords once and return mapping {subword → [IPA phonemes]}.""" + subwords = list(tokenizer.vocab.keys()) + try: + phoneme_strings = phonemize( + subwords, + language=language, + backend="espeak", # use espeak-ng + strip=True, + njobs=1, + preserve_punctuation=True, + with_stress=True, + ) + # split each string into grapheme clusters (IPA symbols) + phoneme_lists = [_split_ipa_symbols(s) for s in phoneme_strings] + return {sw: phs for sw, phs in zip(subwords, phoneme_lists) if phs} + except Exception as e: + logging.error(f"[PHONEME-VOCAB] Failed to phonemize subwords: {e}") + return {} + + def _build_phoneme_vocab(subword_to_phonemes: dict[str, list[str]]) -> dict[str, int]: + phoneme_set = {p for phs in subword_to_phonemes.values() for p in phs} + sorted_phonemes = sorted(phoneme_set) + return {p: i for i, p in enumerate(sorted_phonemes)} + + # --- Load or build vocab --- + vocab_file_name = "phoneme_vocab.json" + if vocab_dir: + os.makedirs(vocab_dir, exist_ok=True) + vocab_file = os.path.join(vocab_dir, vocab_file_name) + + with FileLock(vocab_file + ".lock", timeout=60): + if not os.path.exists(vocab_file): + subword_to_phonemes = _phonemize_all_subwords() + phoneme_vocab = _build_phoneme_vocab(subword_to_phonemes) + cache = {"phoneme_vocab": phoneme_vocab, "subword_to_phonemes": subword_to_phonemes} + logging.info(f"[PHONEME-VOCAB] Saving → {vocab_file}") + with open(vocab_file, "w", encoding="utf-8") as f: + json.dump(cache, f, ensure_ascii=False, indent=2) + + logging.info(f"[PHONEME-VOCAB] Loading from {vocab_file}") + with open(vocab_file, encoding="utf-8") as f: + cache = json.load(f) + phoneme_vocab = cache["phoneme_vocab"] + subword_to_phonemes = cache["subword_to_phonemes"] + else: + logging.info(f"[PHONEME-VOCAB] Building from tokenizer '{pretrained_tokenizer_name}'") + subword_to_phonemes = _phonemize_all_subwords() + phoneme_vocab = _build_phoneme_vocab(subword_to_phonemes) + + # --- Build subword → phoneme ID mapping --- + subword_id_to_phoneme_ids = {} + for subword, subword_id in tokenizer.vocab.items(): + phonemes = subword_to_phonemes.get(subword, []) + phoneme_ids = [phoneme_vocab[p] for p in phonemes if p in phoneme_vocab] + if phoneme_ids: + subword_id_to_phoneme_ids[subword_id] = tuple(phoneme_ids) + + # Define a padding index for subwords + subword_padding_idx = len(tokenizer.vocab) + # The padding subword maps to a new phoneme padding ID + subword_id_to_phoneme_ids[subword_padding_idx] = (len(phoneme_vocab),) + + return subword_id_to_phoneme_ids, phoneme_vocab, subword_padding_idx + + +"""@torch.compile +def depthsum_encoding_step( + embs: Tensor, + r: Tensor, + code: Tensor, + depth_str: int = 0, + k: int = 72, +) -> Tensor: + for i in range(depth_str, depth_str + k): + idx_sel = ( + embs[i].pow(2).sum(-1) # [g?, v] + - 2 + * (r.unsqueeze(-2) @ embs[i].transpose(-1, -2)).squeeze(-2) # [b, ?, g?, h] , [g?, h, v] -> [b, ?, g?, v] + ).argmin(-1) + emb_i = F.embedding(idx_sel, embs[i]) + r = r - emb_i + code[..., i : i + 1] = idx_sel + return code +""" + +@torch.compile +def depthsum_encoding_step( + embs: Tensor, + r: Tensor, + code: Tensor, + depth_str: int = 0, + k: int = 72, +) -> Tensor: + for i in range(depth_str, depth_str + k): + 
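+        # Greedy nearest-neighbour search against codebook level i: ||r - e||^2 is minimised,
+        # but the constant ||r||^2 term is dropped, so only ||e||^2 - 2*<r, e> is evaluated
+        # before the argmin over the codebook entries.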
idx_sel = ( + embs[i].pow(2).sum(-1) # [g?, v] + - 2 + * (r.unsqueeze(-2) @ embs[i].transpose(-1, -2)).squeeze(-2) # [b, ?, g?, h] , [g?, h, v] -> [b, ?, g?, v] + ).argmin(-1) + + emb_i = F.embedding(idx_sel, embs[i]) + r = r - emb_i + + # FIX: assign correctly without shape mismatch + code[..., i] = idx_sel + + return code + +class MoGHead(nn.Module): + """ + A Mixture of Gaussians (MoG) prediction head. + + This module takes a hidden state and predicts the parameters for a mixture of + Gaussian distributions. It's suitable for modeling continuous, multi-modal data. + + Args: + hidden_size (int): The dimensionality of the input hidden state. + intermediate_size (int): The dimensionality of the MLP layers. + out_size (int): The dimensionality of the output vectors (the mean of each Gaussian). + num_layers (int): The number of MLP layers in the stack. + num_predictions (int): The number of Gaussian components in the mixture. + low_rank (int | None): The dimensionality used for compressing the hidden states. + min_log_std (float): The minimum value for the logarithm of the standard deviation. + eps (float): A small epsilon value for the RMSNorm layers. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + out_size: int, + num_layers: int, + num_predictions: int, + low_rank: int | None = 64, + min_log_std: float = -4.0, + eps: float = 1e-6, + ): + super().__init__() + self.out_size = out_size + self.low_rank = low_rank + self.num_predictions = num_predictions + self.min_log_std = min_log_std + + self.mlp_stack = nn.Sequential( + *[MLPLayer(hidden_size, intermediate_size, eps=eps) for _ in range(num_layers)], + RMSNorm(hidden_size, eps=eps), + ) + + if low_rank is None: + self.proj_logits = nn.Linear(hidden_size, num_predictions, bias=False) # Predicts mixture weights + self.proj_mus = nn.Linear(hidden_size, num_predictions * out_size, bias=False) # Predicts means + self.proj_logs = nn.Linear(hidden_size, 1, bias=False) # Predicts log standard deviations + else: + assert low_rank < out_size + self.proj_logits = nn.Linear(hidden_size, num_predictions, bias=False) # Predicts mixture weights + self.proj_mus = nn.Linear(hidden_size, num_predictions * low_rank, bias=False) # Predicts means + self.proj_logs = nn.Linear(hidden_size, 1, bias=False) # Predicts log standard deviations + self.proj_else = nn.Linear(hidden_size, out_size, bias=False) + self.low_mat = nn.Parameter(torch.randn(num_predictions, out_size, low_rank) * (low_rank**-0.5)) + + def infer(self, x: Tensor, guidance_scale: float = 0.0, top_p_or_k: float | int = 1.0) -> tuple[Tensor, Tensor]: + """ + Performs inference by sampling from the predicted mixture distribution. + + Args: + x (Tensor): The input hidden state. + guidance_scale (float): The weight for classifier-free guidance. + top_p_or_k (float | int): The value for top-p (nucleus) or top-k sampling of the mixture components. + + Returns: + tuple[Tensor, Tensor]: A tuple containing the mean of the chosen component, + and the log standard deviations. 
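+
+        Example (illustrative sketch; sizes and sampling values are assumptions, not part of this patch):
+            head = MoGHead(hidden_size=1024, intermediate_size=4096, out_size=128,
+                           num_layers=2, num_predictions=8, low_rank=64)
+            x = torch.randn(2, 10, 1024)                       # [batch, time, hidden]
+            mu, logs = head.infer(x, guidance_scale=0.0, top_p_or_k=0.9)
+            # mu: [2, 10, 128] sampled latents, logs: [2, 10, 1] log standard deviations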
+ """ + b, t, _ = x.size() + n, d = self.num_predictions, self.low_rank or self.out_size + x = self.mlp_stack(x) + if guidance_scale > 0: + b //= 2 + x_cond, x_uncond = x.chunk(2, dim=0) + x = x_cond + guidance_scale * (x_cond - x_uncond) + + logits = self.proj_logits(x) + + # Apply top-p or top-k filtering to the mixture logits + if top_p_or_k is not None: + + logits = ( + TopPLogitsWarper(top_p_or_k)( + None, + logits.view(-1, n), + ).view_as(logits) + if isinstance(top_p_or_k, float) + else TopKLogitsWarper(top_p_or_k)( + None, + logits.view(-1, n), + ).view_as(logits) + ) + + # Sample a mixture component using the Gumbel-Max trick + mixture_indices = (F.log_softmax(logits, dim=-1) + gumbel_like(logits)).argmax(-1) + + # Select the mean corresponding to the sampled component + mu = batch_matmul( + x.view(b * t, -1), + self.proj_mus.weight.detach().view(n, d, -1), + mixture_indices.view(b * t), + ).view(b, t, d) + if self.proj_mus.bias is not None: + mu += self.proj_mus.bias.detach().view(n, d)[mixture_indices] + + if self.low_rank: + assert math.log2(d).is_integer() and math.log2(self.out_size).is_integer() + mu = batch_matmul( + mu.view(b * t, -1), + self.low_mat.detach().view(n, self.out_size, -1), + mixture_indices.view(b * t), + BLOCK_SIZE_DIN=d, + BLOCK_SIZE_DOUT=self.out_size, + ).view(b, t, self.out_size) + + mu_res = self.proj_else(x) + else: + mu_res = torch.zeros((b, t, d), device=x.device) + + logs = self.proj_logs(x).clamp_min(self.min_log_std) + return mu * torch.exp(logs) + mu_res, logs + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """ + Performs a forward pass for training. + + Args: + x (Tensor): The input hidden state. + + Returns: + tuple[Tensor, Tensor, Tensor]: A tuple containing the mixture logits, + the means for all components, and the + log standard deviations. + """ + b, t, _ = x.size() + d = self.low_rank or self.out_size + x = self.mlp_stack(x) + logits = self.proj_logits(x) + mus = self.proj_mus(x).view(b, t, self.num_predictions, d) + logs = self.proj_logs(x).clamp_min(self.min_log_std) + + if self.low_rank: + mu_res = self.proj_else(x) + else: + mu_res = torch.zeros((b, t, d), device=x.device) + return logits, mus, mu_res, logs + + def dist(self, mus: Tensor, mu: Tensor) -> Tensor: + """ + mus: [b, t, n, d] + mu: [b, t, d] + + return: [b, t, n] + """ + if self.low_rank is None: + return (mus - mu.unsqueeze(-2)).pow(2).sum(-1) + else: + low_mat_sq = self.low_mat.transpose(-1, -2) @ self.low_mat + x, y = mus, mu + b, t, n, d_l = x.size() + wx_sq = ( + x + * torch.einsum( + "btni,nij->btnj", + x, + low_mat_sq.to(x), + ) + ).sum(-1) # [b, t, n] + y_sq = y.pow(2).sum(-1, keepdim=True) # [b, t, 1] + xwy = (x * torch.einsum("bti,nij->btnj", y, self.low_mat.to(y))).sum( + -1 + ) # [b, t, n, d_l], [n, d_i, d_l], [b, t, d_i] -> [b, t, n] + + dist = wx_sq + y_sq - 2 * xwy + return torch.abs(dist) + + +class NeMoSubwordFlagEmbedding(nn.Module): + """ + Adds a tiny embedding table for continuation tokens + (subwords that do NOT start with Ġ or the word-boundary marker). + Compatible with NeMo AutoTokenizer. 
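+
+    Example (illustrative sketch; the model name is an assumption):
+        flag_emb = NeMoSubwordFlagEmbedding("Qwen/Qwen2.5-1.5B", d_model=1024)
+        # "Ġhello" / "▁hello" (word starts)   -> flag 0, zero offset
+        # "ing"               (continuation)  -> flag 1, learned offset added
+        out = flag_emb(subword_embeds, token_ids)   # same shape as `subword_embeds`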
+ """ + def __init__(self, model_name: str, d_model: int): + super().__init__() + # Load tokenizer from NeMo + # self.tokenizer_hf = AutoTokenizer.from_pretrained(model_name) + from nemo.collections.common.tokenizers import AutoTokenizer as NeMoAutoTokenizer + self.tokenizer = NeMoAutoTokenizer(model_name, use_fast=True, trust_remote_code=True) + self.vocab_size = self.tokenizer.vocab_size + self.d_model = d_model + + # Precompute continuation flags + tokens = [self.tokenizer.ids_to_tokens(i) for i in range(self.vocab_size)] + self.register_buffer( + 'is_continuation', + torch.tensor([1 if not (tok.startswith("Ġ") or tok.startswith("▁")) else 0 for tok in tokens], + dtype=torch.long) + ) + + # Tiny embedding table: 0 = word-start, 1 = continuation + init_std = self.d_model ** -0.5 + self.cont_emb = nn.Embedding(2, self.d_model) + nn.init.normal_(self.cont_emb.weight, mean=0.0, std=init_std) + + # Force word-start embedding to zero so only continuation tokens get shifted + self.cont_emb.weight.data[0].zero_() + + def forward(self, subword_embeds: torch.Tensor, token_ids: torch.LongTensor): + # Continuation flags + cont_flags = self.is_continuation[token_ids] + + # Add continuation embedding + cont_emb = self.cont_emb(cont_flags) + return subword_embeds + cont_emb + + +class SubwordFlagEmbedding(nn.Module): + """ + Adds a small continuation embedding for subwords (tokens without word-boundary marker). + Automatically adds a custom padding token at index vocab_size. + Ignores special tokens (starting with '<') when computing continuation flags. + """ + def __init__(self, model_name: str, d_model: int): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.vocab_size = self.tokenizer.vocab_size + self.d_model = d_model + + # Custom pad token at vocab_size + self.pad_id = self.vocab_size + # register pad_id as a tensor buffer to avoid device issues + self.register_buffer("pad_tensor", torch.tensor(self.pad_id, dtype=torch.long)) + + # Precompute continuation flags + tokens = [self.tokenizer.convert_ids_to_tokens(i) for i in range(self.vocab_size)] + cont_flags = [ + 1 if not (tok.startswith("Ġ") or tok.startswith("▁") or tok.startswith("<")) else 0 + for tok in tokens + ] + cont_flags.append(0) # for the custom pad token + self.register_buffer("is_continuation", torch.tensor(cont_flags, dtype=torch.long)) + + # Continuation embedding + init_std = self.d_model ** -0.5 + self.cont_emb = nn.Embedding(2, self.d_model) + nn.init.normal_(self.cont_emb.weight, mean=0.0, std=init_std) + self.cont_emb.weight.data[0].zero_() + + def forward(self, subword_embeds: torch.Tensor, token_ids: torch.LongTensor): + # Replace OOV token IDs with pad_id safely + token_ids_clamped = torch.where(token_ids >= self.vocab_size, + self.pad_tensor, + token_ids) + # Continuation flags + cont_flags = self.is_continuation[token_ids_clamped] + # Add continuation embedding + cont_emb = self.cont_emb(cont_flags) + return subword_embeds + cont_emb + +class BOSEOSEmbedding(nn.Module): + """ + Adds independent embeddings for BOS and EOS tokens using a single embedding table. + Index 0 = regular token (ignored), 1 = BOS, 2 = EOS. + Compatible with Hugging Face tokenizers that may or may not have BOS/EOS. 
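+
+    Example (illustrative sketch; model name and shapes are assumptions):
+        bos_eos = BOSEOSEmbedding("Qwen/Qwen2.5-1.5B", d_model=1024)
+        out = bos_eos(token_embeds, token_ids)
+        # positions whose id matches the BOS/EOS token receive an extra learned embedding;
+        # all other positions pass through (the index-0 embedding is initialised to zero)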
+ """ + def __init__(self, model_name: str, d_model: int): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + # vocab size that includes special tokens + vocab_dict = self.tokenizer.get_vocab() + self.vocab_size = max(vocab_dict.values()) + self.d_model = d_model + + # Custom pad token for OOVs + self.pad_id = self.vocab_size + self.register_buffer("pad_tensor", torch.tensor(self.pad_id, dtype=torch.long)) + + # Identify BOS and EOS tokens (may be None) + tokens = [self.tokenizer.convert_ids_to_tokens(i) for i in range(self.vocab_size)] + + if 'Qwen2.5' in model_name: + # For Qwen, '<|im_start|>' is a common choice for a BOS token. + # You can check your tokenizer's vocabulary for the best candidate. + logging.warning("Tokenizer does not have a `bos_token`. Setting it to '<|im_start|>'.") + self.tokenizer.bos_token = '<|im_start|>' + self.tokenizer.eos_token = '<|im_end|>' + + special_flags = [] + for tok in tokens: + if self.tokenizer.bos_token is not None and tok == self.tokenizer.bos_token: + special_flags.append(1) + elif self.tokenizer.eos_token is not None and tok == self.tokenizer.eos_token: + special_flags.append(2) + else: + special_flags.append(0) + special_flags.append(0) # for custom pad token + self.register_buffer("special_flags", torch.tensor(special_flags, dtype=torch.long)) + # Embedding table: 0 = regular, 1 = BOS, 2 = EOS + init_std = self.d_model ** -0.5 + self.special_emb = nn.Embedding(3, d_model) + nn.init.normal_(self.special_emb.weight, mean=0.0, std=init_std) + self.special_emb.weight.data[0].zero_() # regular tokens ignored + + def forward(self, token_embeds: torch.Tensor, token_ids: torch.LongTensor): + """ + token_embeds: (B, T, d_model) + token_ids: (B, T) + """ + # Clamp OOVs to custom pad token + safe_ids = torch.where(token_ids >= self.vocab_size, self.pad_tensor, token_ids) + + # Lookup flags (0=regular, 1=BOS, 2=EOS) + flags = self.special_flags[safe_ids] + return token_embeds + self.special_emb(flags) + +class SubwordEmbedding(nn.Module): + """ + Produces subword embeddings from a Hugging Face tokenizer vocabulary. + No special handling for OOVs or padding — assumes token_ids are valid. + """ + def __init__(self, model_name: str, d_model: int): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Get vocab size from tokenizer + vocab_dict = self.tokenizer.get_vocab() + self.vocab_size = max(vocab_dict.values()) + 1 # +1 for safety + self.d_model = d_model + + # Subword embedding table + init_std = d_model ** -0.5 + self.subword_emb = nn.Embedding(self.vocab_size, d_model) + nn.init.normal_(self.subword_emb.weight, mean=0.0, std=init_std) + + def forward(self, token_ids: torch.LongTensor, subword_mask: torch.tensor = None): + """ + token_ids: (B, T) + subword_mask: (B, T) + Returns: + subword_embeds: (B, T, d_model) + """ + return self.subword_emb(token_ids) + + +class CharAwareSubwordEncoder(nn.Module): + """ + An encoder that creates subword embeddings from character-level embeddings. + + This module replaces a standard subword embedding layer. It breaks down each + subword into its constituent characters, embeds the characters, and then + aggregates these character embeddings (e.g., via mean pooling) to form the + final subword representation. This allows the model to handle rare or out-of-vocabulary + subwords more gracefully. + + Args: + out_size (int): The dimensionality of the output embedding vectors. + pretrained_tokenizer_name (str): The name of the base Hugging Face tokenizer. 
+ vocab_dir (str | None): Directory to save/load the character vocabulary. + backbone_type (str | None): The type of backbone model from Hugging Face (e.g., "t5gemma"). + backbone_model_class (str | None): The class name of the backbone model if not using AutoModel. + backbone_config_class (str | None): The class name of the backbone config. + backbone_config (Config | None): A configuration for the backbone model. + """ + + def __init__( + self, + out_size: int, + pretrained_tokenizer_name: str, + vocab_dir: str | None = None, + backbone_type: str | None = "t5gemma", + backbone_model_class: str | None = None, + backbone_config_class: str | None = None, + backbone_config: Config | None = None, + use_subword_flag_emb: bool = True, + use_bos_eos_emb: bool = True, + use_cumulative_word_emb: bool = False, + ): + super().__init__() + + # 1. Build or load the character vocabulary + self.subword_id_to_char_ids, self.char_vocab, self.subword_padding_idx = build_vocabs( + pretrained_tokenizer_name, vocab_dir, + ) + + self.char_padding_idx = len(self.char_vocab) + self.use_subword_flag_emb = use_subword_flag_emb + self.use_bos_eos_emb = use_bos_eos_emb + + # 2. Initialize the backbone model + if backbone_type: + config = AutoConfig.for_model(backbone_type, **(backbone_config.to_dict() if backbone_config else {})) + self.backbone = AutoModelForTextEncoding.from_config(config) + else: + assert backbone_model_class and backbone_config_class + config_class = getattr(transformers, backbone_config_class) + model_class = getattr(transformers, backbone_model_class) + config = config_class(**(backbone_config.to_dict() if backbone_config else {})) + self.backbone = model_class(config) + + self.hidden_size = self.backbone.get_input_embeddings().weight.size(-1) + + # 3. Delete the original subword embedding layer and replace it with our character embedding layer + find_and_delete_module(self.backbone, self.backbone.get_input_embeddings(), "backbone") + self.embed_tokens = nn.Embedding(len(self.char_vocab) + 1, self.hidden_size, padding_idx=self.char_padding_idx) + self.proj_embedding = nn.Linear(self.hidden_size, out_size, bias=False) + + if self.use_subword_flag_emb: + self.subword_flag_emb = SubwordFlagEmbedding(pretrained_tokenizer_name, self.hidden_size) + + if self.use_bos_eos_emb: + self.bos_eos_emb = BOSEOSEmbedding(pretrained_tokenizer_name, self.hidden_size) + + def prepare_inputs(self, subword_ids: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]: + """ + Converts a batch of subword IDs into a padded batch of character IDs. + + Args: + subword_ids (Tensor): A tensor of subword IDs. Shape: `[batch, seq_len]`. + padding_mask (Tensor): A boolean mask indicating valid (non-padding) subwords. + + Returns: + tuple[Tensor, Tensor]: A tuple containing: + - Padded character IDs. Shape: `[num_valid_subwords, max_char_len]`. + - Lengths of each character sequence. Shape: `[num_valid_subwords]`. 
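+
+        Example (illustrative sketch; the IDs are assumptions):
+            subword_ids  = torch.tensor([[17, 42, 0]])
+            padding_mask = torch.tensor([[True, True, False]])
+            char_ids, char_lengths = encoder.prepare_inputs(subword_ids, padding_mask)
+            # char_ids: [2, max_char_len] padded with `char_padding_idx`, char_lengths: [2]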
+ """ + device = subword_ids.device + # Select only the valid subword IDs + subword_id_list = torch.masked_select(subword_ids, padding_mask).cpu().tolist() + # Map each subword ID to its sequence of character IDs + char_id_list = [list(self.subword_id_to_char_ids.get(x, ())) for x in subword_id_list] + + char_lengths = torch.tensor([len(x) for x in char_id_list], dtype=torch.long, device=device) + batch_size = char_lengths.size(0) + max_len = int(char_lengths.max().item()) if batch_size > 0 else 0 + + # Create a padded tensor for the character IDs + char_ids = torch.full((batch_size, max_len), self.char_padding_idx, dtype=torch.long, device=device) + for i, char_seq in enumerate(char_id_list): + char_ids[i, : len(char_seq)] = torch.tensor(char_seq, dtype=torch.long, device=device) + + return char_ids, char_lengths + + def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Tensor: + """ + Performs the forward pass to get character-aware subword embeddings. + + Args: + subword_ids (Tensor): A tensor of subword IDs. Shape: `[batch, seq_len]`. + subword_mask (Tensor | None): A boolean mask for padding. Defaults to None. + + Returns: + Tensor: The final subword embeddings. Shape: `[batch, seq_len, hidden_size]`. + """ + if subword_mask is None: + subword_mask = torch.ones_like(subword_ids, dtype=torch.bool) + + # 1. Convert subword IDs to character IDs + char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) + + # char_mask = sequence_mask(char_lengths).float() + char_mask = sequence_mask(char_lengths) + + # 2. Get character embeddings and pass them through the backbone + char_embeds = self.embed_tokens(char_ids) + + # The backbone model should be able to accept `inputs_embeds` + char_hidden_states = self.backbone(inputs_embeds=char_embeds, attention_mask=char_mask).last_hidden_state + + # 3. Aggregate character embeddings to form subword embeddings (mean pooling) + # We mask the padding characters before summing to get a correct mean. + masked_sum = (char_hidden_states * char_mask.unsqueeze(-1)).sum(dim=1) + # Avoid division by zero for empty sequences + mean_emb = masked_sum / (char_lengths.unsqueeze(-1).clamp(min=1)) + + # 4. 
Scatter the aggregated embeddings back to the original subword sequence shape + out_emb = self.proj_embedding(mean_emb) + subword_embeds = torch.zeros( + subword_ids.shape + (out_emb.size(-1),), device=subword_ids.device, dtype=out_emb.dtype + ) + subword_embeds[subword_mask] = out_emb + + if self.use_subword_flag_emb: + subword_embeds = self.subword_flag_emb(subword_embeds, subword_ids) + + if self.use_bos_eos_emb: + subword_embeds = self.bos_eos_emb(subword_embeds, subword_ids) + + return subword_embeds + + +class GatedProjectedSumRMSNorm(nn.Module): + def __init__(self, audio_dim, text_dim, hidden_dim, + final_norm=True, num_codebooks=31, init_residual_scale=0.5): + super().__init__() + self.num_codebooks = num_codebooks + + self.audio_proj = nn.Linear(audio_dim, hidden_dim) + self.text_proj = nn.Linear(text_dim, hidden_dim) + + nn.init.normal_(self.audio_proj.weight, mean=0.0, std=0.015) + nn.init.zeros_(self.audio_proj.bias) + nn.init.normal_(self.text_proj.weight, mean=0.0, std=0.015) + nn.init.zeros_(self.text_proj.bias) + + # FP32 gate params + self.gate = nn.Parameter(torch.zeros(hidden_dim, dtype=torch.float32)) + self.residual_scale = nn.Parameter(torch.tensor(init_residual_scale, dtype=torch.float32)) + + self.final_norm = RMSNorm(hidden_dim) if final_norm else nn.Identity() + + def forward(self, audio_emb, text_emb): + audio_emb = audio_emb / self.num_codebooks + + # projections run in model dtype (BF16) + audio_h = self.audio_proj(audio_emb) + text_h = self.text_proj(text_emb) + + dtype = audio_h.dtype + + with fp32_precision(): + gate = torch.sigmoid(self.gate) # FP32 + res = torch.sigmoid(self.residual_scale) # FP32 + + h = gate.to(dtype) * audio_h + (1 - gate).to(dtype) * text_h + h = res.to(dtype) * h + h = self.final_norm(h.float()).to(dtype) + + return h + +class RVQEARTTSModel(PreTrainedModel): + """ + The main RVQEARTTS model, which can be used for both training and inference. + + This model integrates a character-aware text encoder and a MoG head with a + transformer backbone. It can be trained to predict audio codes or used for + autoregressive inference. + + Args: + config (RVQEARTTSConfig | dict[str, Any]): The configuration object for the model. 
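+
+    Example (illustrative training sketch; the config values and tensor shapes are assumptions
+    and may not match the real RVQEARTTSConfig defaults):
+        config = RVQEARTTSConfig(latent_size=128, num_quantizers=31, codebook_size=1024)
+        model = RVQEARTTSModel(config)
+        model.set_rvq_embs(torch.randn(31, 1024, 128))   # [num_quantizers, codebook_size, latent_size]
+        out = model(code, audio_mask=audio_mask, context_hidden_state=context, training=True)
+        out.loss.backward()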
+ """ + + config_class: type[Config] = RVQEARTTSConfig + rvq_embs: Tensor + + def __init__(self, config: RVQEARTTSConfig | dict[str, Any]): + super().__init__(config) + + # Backbone module + if self.config.get("pretrained_text_name", None): + # Load pretrained backbone from huggingface + from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf + llm = load_pretrained_hf(self.config.pretrained_text_name, pretrained_weights=True).train() + self.backbone = llm.model # fetch PretrainedBaseModel from model "ForCausalLM" + else: + if self.config.backbone_type is None: + assert self.config.backbone_model_class is not None and self.config.backbone_config_class is not None + backbone_config = getattr(transformers, self.config.backbone_config_class)( + **(self.config.backbone_config.to_dict() if self.config.backbone_config else {}), + ) + self.backbone = getattr(transformers, self.config.backbone_model_class)(backbone_config) + else: + backbone_config = AutoConfig.for_model( + self.config.backbone_type, + **(self.config.backbone_config.to_dict() if self.config.backbone_config else {}), + ) + self.backbone = AutoModel.from_config(backbone_config) + + self.hidden_size = self.backbone.get_input_embeddings().weight.size(-1) + find_and_delete_module(self.backbone, self.backbone.get_input_embeddings(), "backbone") + + # Embedding and projection layers + self.bos_emb = nn.Parameter(torch.randn(self.hidden_size)) + self.null_emb = nn.Parameter(torch.randn(self.hidden_size)) + if self.config.random_target_masking: + self.embed_target_mask = nn.Embedding(self.config.num_quantizers, self.hidden_size) + + self.embed_code = nn.Linear(self.config.latent_size, self.hidden_size, bias=False) + + self.embed_context = ( + nn.Linear(self.config.context_hidden_size, self.hidden_size, bias=False) + if self.config.context_hidden_size + else None + ) + + self.embed_subword = ( + CharAwareSubwordEncoder(out_size=self.hidden_size, use_subword_flag_emb=self.config.use_subword_flag_emb, use_bos_eos_emb=self.config.use_bos_eos_emb, **self.config.cas_config) + if self.config.cas_config + else None + ) + + if self.config.use_gated_fusion_for_text_audio: + self.gated_fusion_audio_text = GatedProjectedSumRMSNorm(self.hidden_size, self.hidden_size, self.hidden_size, self.config.num_quantizers) + + # Prediction Heads + if not self.config.disable_eos_prediction: + self.lm_head = nn.Linear(self.hidden_size, 2, bias=False) + + self.mog_head = MoGHead( + hidden_size=self.hidden_size, + out_size=self.config.latent_size, + **self.config.mog_head_config, + ) + + def set_rvq_embs(self, rvq_embs: Tensor): + self.register_buffer("rvq_embs", rvq_embs.detach().clone()) + + def depthsum_embedding(self, code: Tensor) -> Tensor: + """ + code: [b, t, d] + rvq_embs: [d, v, h] + + ret: [b, t, h] + """ + b, t, d = code.size() + _, v, h = self.rvq_embs.size() + device = code.device + + ret = torch.zeros((b, t, h), device=device) + embs = F.pad(self.rvq_embs, [0, 0, 0, 1]) + for i in range(d): + emb = embs[i] + ret = ret + F.embedding(code[..., i], emb) + return ret + + def prepare_training_inputs(self, code: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Prepares masked and dropped-out versions of the code for training.""" + b, t, d = code.size() + device = code.device + + src_rate = torch.rand((b, t), device=device) * self.config.max_training_rate + src_masking_rate = get_masking_rate(src_rate, self.config.exponent) + src_num_masking = torch.ceil(src_masking_rate * self.config.num_quantizers).long() + + src_code_mask = 
torch.ones((b, t, d), dtype=torch.bool, device=device) + src_code_mask = get_mask(src_code_mask, src_num_masking) + src_masked_code = code * src_code_mask + (torch.zeros_like(code) + self.config.codebook_size) * ( + ~src_code_mask + ) + + if self.config.random_target_masking: + tgt_rate = src_rate + (1.0 - src_rate) * torch.rand((b, t), device=device) + tgt_masking_rate = get_masking_rate(tgt_rate, self.config.exponent) + tgt_num_masking = torch.floor(tgt_masking_rate * self.config.num_quantizers).long() + + tgt_code_mask = torch.ones((b, t, d), dtype=torch.bool, device=device) + tgt_code_mask = get_mask(tgt_code_mask, tgt_num_masking) + tgt_masked_code = code * tgt_code_mask + (torch.zeros_like(code) + self.config.codebook_size) * ( + ~tgt_code_mask + ) + else: + tgt_code_mask = torch.ones((b, t, d), dtype=torch.bool, device=device) + tgt_masked_code = code + + dropout_mask = torch.where( + torch.rand((b, t, 1), device=device) < self.config.quantizer_dropout, + (torch.randint(0, self.config.num_quantizers + 1, (b, t, 1), device=device)), + self.config.num_quantizers, + ) > torch.arange(d, dtype=torch.long, device=device) + dropped_code = code * dropout_mask + (torch.zeros_like(code) + self.config.codebook_size) * (~dropout_mask) + + return src_masked_code, src_code_mask, tgt_masked_code, tgt_code_mask, dropped_code + + def _prepare_conditioning( + self, + context_hidden_state: Tensor | None, + subword_ids: Tensor | None, + subword_mask: Tensor | None, + uncond_dec_flag: Tensor, + asr_speech_tokens_emb: Tensor | None, + ) -> Tensor: + """Computes the final conditioning tensor by combining all sources.""" + cond = torch.zeros((1, 1, self.hidden_size), device=uncond_dec_flag.device) + + if self.embed_context is not None and context_hidden_state is not None: + cond = cond + self.embed_context(context_hidden_state) + + if self.embed_subword is not None and subword_ids is not None: + # Infer subword mask from context if not provided + if subword_mask is None and context_hidden_state is not None: + subword_mask = torch.any(context_hidden_state != 0, dim=-1) + # at least one value should be true, otherwise we can completly skip it to avoid errors + if subword_mask is not None and subword_mask.any(): + cond = cond + self.embed_subword(subword_ids, subword_mask) + + if asr_speech_tokens_emb is not None: + cond = cond + asr_speech_tokens_emb + + # Replace with null embedding for unconditional generation + cond = torch.where(uncond_dec_flag, self.null_emb, cond) + return cond + + def _compute_losses( + self, + code: Tensor, + lm_logits: Tensor, + mog_logits: Tensor, + mog_mus: Tensor, + mog_mu_res: Tensor, + mog_logs: Tensor, + src_code_mask: Tensor, + tgt_code_mask: Tensor, + audio_mask: Tensor, + ) -> tuple[Tensor, Tensor, Tensor]: + """Helper to compute all losses for the training step.""" + with torch.autocast(code.device.type, enabled=False): + # 1. LM Loss (predicting discrete tokens) + if not self.config.disable_eos_prediction: + eos_mask = (~audio_mask) & F.pad(audio_mask[:, :-1], [1, 0]) + lm_mask = eos_mask | audio_mask + lm_target = torch.where(eos_mask, 1, 0) + lm_loss = ( + F.cross_entropy(lm_logits.transpose(1, 2), lm_target, reduction="none") * lm_mask + ).sum() / lm_mask.sum().clamp_min(1) + else: + lm_loss = 0.0 + + # 2. 
Continuous & KL Losses (for the MoG head) + target_mask = (~src_code_mask & tgt_code_mask) & audio_mask.unsqueeze(-1) + reduced_target_mask = target_mask.any(dim=-1) + + cont_code_target = self.depthsum_embedding( + code * target_mask + (torch.zeros_like(code) + self.config.codebook_size) * (~target_mask) + ) + mog_logits = mog_logits.float() + mog_mus = mog_mus.float() + mog_mu_res = mog_mu_res.float() + mog_logs = mog_logs.float() + + # Log probability of the true code under each Gaussian component + logp_code = (-0.5 * math.log(2 * math.pi) - mog_logs) * self.config.latent_size - 0.5 * self.mog_head.dist( + mog_mus, (cont_code_target - mog_mu_res) * torch.exp(-mog_logs) + ) + + # Compute posterior q(k|c) + q_kc = ( + torch.softmax( + logp_code, + -1, + ) + * (1 - self.config.label_smoothing) + + self.config.label_smoothing / self.mog_head.num_predictions + ).detach() + log_q_kc = torch.log(q_kc + 1e-8).detach() + + # Continuous Loss (negative log-likelihood) + c_loss = (-(q_kc * logp_code).sum(-1) * reduced_target_mask).sum() / target_mask.sum().clamp_min(1) + + # KL Divergence Loss + k_loss = ( + (q_kc * (log_q_kc - F.log_softmax(mog_logits, -1))).sum(-1) * reduced_target_mask + ).sum() / target_mask.sum().clamp_min(1) + + return lm_loss, c_loss, k_loss + + def forward( + self, + code: Tensor, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + context_hidden_state: Tensor | None = None, + subword_ids: Tensor | None = None, + subword_mask: Tensor | None = None, + audio_mask: Tensor | None = None, + non_prompt_mask: Tensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool = False, + training: bool | None = None, + guidance_enabled: bool = False, + generation_config: dict[str, Any] | None = None, + teacher_forcing_inference: bool = False, + ignore_eos_flag_stop: bool = False, + asr_speech_tokens_emb: Tensor | None = None, + ) -> RVQEARTTSOutput: + """ + Performs a forward pass handling training, generation, or single-step inference. + + Args: + code (Tensor): Input audio codes. For training, this is the ground truth. + For generation, this is the previously generated code token. + attention_mask (Tensor | None): Attention mask for the backbone transformer. + position_ids (Tensor | None): Position ids for the backbone transformer. + context_hidden_state (Tensor | None): Conditioning from a language model. + subword_ids (Tensor | None): Subword token IDs for conditioning. + subword_mask (Tensor | None): Mask for subword IDs. + audio_mask (Tensor | None): Mask for valid audio positions (for training and inference initialization). + past_key_values (Cache | None): Cache for past key-values for fast decoding. + use_cache (bool): If True, returns the updated `past_key_values`. + training (bool | None): Explicitly set training mode. If `None`, uses `self.training`. + guidance_enabled (bool): If True, duplicates inputs internally to run both + conditional and unconditional passes + generation_config (dict[str, Any] | None): If provided, triggers an iterative code generation. + + Returns: + RVQEARTTSOutput: A dataclass containing losses (for training) or generated outputs + and the cache (for inference). + """ + + # Determine operating mode. 
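+        # Three cases are handled below:
+        #   1) training (audio_mask given, training=True): masked codes are built and the
+        #      LM / continuous / KL losses are returned;
+        #   2) plain inference (no generation_config): hidden states and the KV cache are returned;
+        #   3) generation (generation_config given): codes are sampled with generate_step(),
+        #      or generate_teacher_forcing() when teacher_forcing_inference is set.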
+ if training is None: + training = self.training + + if audio_mask is not None: + if training: + (src_masked_code, src_code_mask, tgt_masked_code, tgt_code_mask, dropped_code) = ( + self.prepare_training_inputs(code) + ) + uncond_dec_flag = torch.rand(code.size(0), 1, 1, device=code.device) < self.config.p_uncond + else: + dropped_code = code + uncond_dec_flag = torch.zeros(code.size(0), 1, 1, device=code.device, dtype=torch.bool) + + code_embeds = ( + self.embed_code(self.depthsum_embedding(F.pad(dropped_code[:, :-1], [0, 0, 1, 0]))) + + (audio_mask & (~F.pad(audio_mask[:, :-1], [1, 0]))).unsqueeze(-1) * self.bos_emb + ) + + else: # Inference + code_embeds = self.embed_code(self.depthsum_embedding(code)) + uncond_dec_flag = torch.zeros(code.size(0), 1, 1, device=code.device, dtype=torch.bool) + + if guidance_enabled: + assert not training, "Classifier-free guidance can only be used when `training` is False." + code_embeds = torch.cat([code_embeds] * 2, 0) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask] * 2, 0) + if position_ids is not None: + position_ids = torch.cat([position_ids] * 2, 0) + if context_hidden_state is not None: + context_hidden_state = torch.cat([context_hidden_state] * 2, 0) + if subword_ids is not None: + subword_ids = torch.cat([subword_ids] * 2, 0) + if subword_mask is not None: + subword_mask = torch.cat([subword_mask] * 2, 0) + if asr_speech_tokens_emb is not None: + asr_speech_tokens_emb = torch.cat([asr_speech_tokens_emb] * 2, 0) + + uncond_dec_flag = torch.cat([uncond_dec_flag, torch.ones_like(uncond_dec_flag)], 0) + + # Prepare conditioning + cond = self._prepare_conditioning(context_hidden_state, subword_ids, subword_mask, uncond_dec_flag, asr_speech_tokens_emb=asr_speech_tokens_emb) + + if self.config.use_gated_fusion_for_text_audio: + inputs_embeds = self.gated_fusion_audio_text(code_embeds, cond) + else: + inputs_embeds = code_embeds + cond + + # Main backbone pass + backbone_outputs = self.backbone( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + ) + hidden_states = backbone_outputs.last_hidden_state + + if audio_mask is not None and training: + # --- Training-specific loss computation --- + if not self.config.disable_eos_prediction: + lm_logits = self.lm_head(hidden_states) + else: + lm_logits = None + + mog_input_embeds = self.embed_code(self.depthsum_embedding(src_masked_code)) + if self.config.random_target_masking: + mog_input_embeds = mog_input_embeds + self.embed_target_mask((tgt_code_mask.sum(-1) - 1).clamp_min(0)) + mog_input_embeds = mog_input_embeds + hidden_states + mog_logits, mog_mus, mog_mu_res, mog_logs = self.mog_head(mog_input_embeds) + + lm_loss, c_loss, k_loss = self._compute_losses( + code, lm_logits, mog_logits, mog_mus, mog_mu_res, mog_logs, src_code_mask, tgt_code_mask, audio_mask + ) + total_loss = lm_loss + c_loss + k_loss + + return RVQEARTTSOutput(loss=total_loss, lm_loss=lm_loss, c_loss=c_loss, k_loss=k_loss, hidden_states=hidden_states) + else: # Inference + if not generation_config: + return RVQEARTTSOutput( + hidden_states=hidden_states, + past_key_values=backbone_outputs.past_key_values, + ) + else: + if teacher_forcing_inference: + generated_codes, lm_logits, eos_flag = self.generate_teacher_forcing(hidden_states, generation_config) + else: + generated_codes, lm_logits, eos_flag = self.generate_step(hidden_states, ignore_eos_flag_stop=ignore_eos_flag_stop, **generation_config) + return 
RVQEARTTSOutput( + past_key_values=backbone_outputs.past_key_values, + codes=generated_codes, + lm_logits=lm_logits, + eos_flag=eos_flag, + hidden_states=hidden_states, + ) + + @torch.no_grad() + def generate_teacher_forcing(self, hidden_states: Tensor, generation_config: dict): + """ + Teacher-forcing wrapper for generate_step, processing all frames in parallel + using a per-frame loop internally. + + Args: + hidden_states: [B, T, H] hidden states + generation_config: kwargs for self.generate_step() + + Returns: + generated_codes: [B, T, ...] generated codes per frame + lm_logits: [B, T, vocab_size] language model logits + eos_flag: [B, T] boolean tensor indicating EOS + """ + B, T, H = hidden_states.shape + + # Preallocate caches + generated_codes_cache = [] + lm_logits_cache = [] + eos_flag_cache = [] + + # Iterate over time steps (frames) + for t in range(T): + # extract one frame (as the original generate_step expects) + frame_hidden = hidden_states[:, t, :] # [B, H] + + # call original generate_step + generated_codes, lm_logits, eos_flag = self.generate_step( + frame_hidden.unsqueeze(1), # keep batch dim + frame dim + **generation_config + ) + if generated_codes is not None: + # store in cache + generated_codes_cache.append(generated_codes) + lm_logits_cache.append(lm_logits) + eos_flag_cache.append(eos_flag) + + # Stack results along time dimension + generated_codes = torch.stack(generated_codes_cache, dim=1) # [B, T, ...] + if not self.config.disable_eos_prediction: + lm_logits = torch.stack(lm_logits_cache, dim=1) # [B, T, vocab_size] + eos_flag = torch.stack(eos_flag_cache, dim=1) # [B, T] + else: + lm_logits = None + eos_flag = None + + return generated_codes, lm_logits, eos_flag + + @torch.no_grad() + def generate_step( + self, + hidden_states: Tensor, + num_iter: int, + guidance_scale: list[float] | float | None = None, + top_p_or_k: list[float | int] | float | int | None = None, + noise_scale: list[float] | float | None = None, + exponent: float | None = None, + eos_threshold: float | None = None, + ignore_eos_flag_stop: bool = False, + ) -> tuple[Tensor | None, Tensor, Tensor]: + """ + Performs the iterative unmasking process for a single generation step. + + This function takes the hidden state from the backbone transformer and generates + codes through an iterative unmasking process. + + Args: + hidden_states (Tensor): The hidden states from the backbone. If using CFG, + this should be the combined [uncond, cond] tensor. + num_iter (int): The number of unmasking iterations. + guidance_scale (list[float] | float | None): The scale for Classifier-Free Guidance. + top_p_or_k (ist[float | int] | float | int | None): The value for top-p or top-k sampling. + noise_scale (list[float] | float | None): The scale of noise to add during MoG sampling. + exponent (float | None): The exponent for the masking schedule. + eos_threshold (float | None): The threshold for EOS prediction. + + Returns: + tuple[Tensor | None, Tensor, Tensor]: A tuple containing: + - the generated codes. + - The logits from `lm_head`. + - The EOS flag. + """ + # 1. 
Preparation + if guidance_scale is not None: + if not isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * (1 + num_iter) # includes one step for `lm_head` + assert len(guidance_scale) == 1 + num_iter + if top_p_or_k is not None: + if not isinstance(top_p_or_k, list): + top_p_or_k = [top_p_or_k] * (1 + num_iter) # includes one step for `lm_head` + assert len(top_p_or_k) == 1 + num_iter + if noise_scale is not None: + if not isinstance(noise_scale, list): + noise_scale = [noise_scale] * num_iter + assert len(noise_scale) == num_iter + if exponent is None: + exponent = self.config.exponent + + if guidance_scale is not None: + # The effective batch size is halved + hidden_states, uncond_hidden_states = hidden_states.chunk(2, dim=0) + else: + uncond_hidden_states = hidden_states[:0, :0, :0] + + b, t, _ = hidden_states.size() + d = self.config.num_quantizers + device = hidden_states.device + + # 2. Predict the discrete part of the code + if not self.config.disable_eos_prediction: + if guidance_scale is not None: + lm_logits = self.lm_head(hidden_states + guidance_scale[0] * (hidden_states - uncond_hidden_states)) + else: + lm_logits = self.lm_head(hidden_states) + if top_p_or_k is not None: + lm_logits = ( + TopPLogitsWarper(top_p_or_k[0])( + None, + lm_logits.view(-1, lm_logits.size(-1)), + ).view_as(lm_logits) + if isinstance(top_p_or_k[0], float) + else TopKLogitsWarper(top_p_or_k[0])( + None, + lm_logits.view(-1, lm_logits.size(-1)), + ).view_as(lm_logits) + ) + lm_logits = F.log_softmax(lm_logits, -1) + if eos_threshold is not None: + eos_flag = lm_logits[..., -1] > eos_threshold + else: + eos_flag = lm_logits.argmax(-1) == 1 + + if torch.all(eos_flag) and not ignore_eos_flag_stop: + return None, lm_logits, eos_flag + else: + lm_logits = None + eos_flag = None + + # Initialize the full code tensor + code = torch.zeros((b, t, d), dtype=torch.long, device=device) + self.config.codebook_size + + # 3. Set up the iterative denoising schedule for the continuous part + rates = torch.linspace(0.0, 1.0, num_iter + 1, device=device)[:-1].unsqueeze(-1) + masking_rates = get_masking_rate(rates, exponent=exponent) + num_maskings = torch.ceil(masking_rates * self.config.num_quantizers).long() + + ks = num_maskings - F.pad(num_maskings[1:], [0, 0, 0, 1]) + + # 4. 
Iteratively unmask the continuous part of the code + cnt = 0 + for i, k in enumerate(ks): + if torch.all(k == 0): + continue + + # Prepare input for the MoG head + guidance_scale_i = guidance_scale[i] if guidance_scale is not None else 0.0 + top_p_or_k_i = top_p_or_k[i] if top_p_or_k is not None else 1.0 + noise_scale_i = noise_scale[i] if noise_scale is not None else 1.0 + + mog_input_embeds = self.embed_code(self.depthsum_embedding(code)) + if self.config.random_target_masking: + mog_input_embeds += self.embed_target_mask(cnt + k - 1) + if guidance_scale_i > 0.0: + mog_input_embeds = torch.cat( + [mog_input_embeds + hidden_states, mog_input_embeds + uncond_hidden_states], 0 + ) + else: + mog_input_embeds += hidden_states + + mog_mu, mog_logs = self.mog_head.infer( + mog_input_embeds, + guidance_scale=guidance_scale_i, + top_p_or_k=top_p_or_k_i, + ) + z = mog_mu + torch.exp(mog_logs) * torch.randn_like(mog_mu) * noise_scale_i + code = depthsum_encoding_step(self.rvq_embs, z, code, cnt, k[0].item()) + cnt += k[0].item() + return code, lm_logits, eos_flag + + def load_state_dict(self, state_dict, strict: bool = True): + try: + super().load_state_dict(state_dict, strict=strict) + except RuntimeError as e: + logging.info(f"Error loading model state_dict !! Retrying with partial initialization!") + model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) + super().load_state_dict(model_dict, strict=False) diff --git a/nemo/collections/speechlm2/parts/metrics/__init__.py b/nemo/collections/speechlm2/parts/metrics/__init__.py index aa628a2f76b3..5952c1fb0b30 100644 --- a/nemo/collections/speechlm2/parts/metrics/__init__.py +++ b/nemo/collections/speechlm2/parts/metrics/__init__.py @@ -13,10 +13,11 @@ # limitations under the License. from .asr_bleu import ASRBLEU from .bleu import BLEU -from .wer import WER - +from .token_accuracy import TokenAccuracy +from .results_logger import ResultsLogger __all__ = [ 'ASRBLEU', 'BLEU', - 'WER', + 'TokenAccuracy', + 'ResultsLogger', ] diff --git a/nemo/collections/speechlm2/parts/metrics/asr_bleu.py b/nemo/collections/speechlm2/parts/metrics/asr_bleu.py index d7124b49ff01..3ffd489312dc 100644 --- a/nemo/collections/speechlm2/parts/metrics/asr_bleu.py +++ b/nemo/collections/speechlm2/parts/metrics/asr_bleu.py @@ -57,7 +57,7 @@ def reset(self): def update( self, name: str, refs: list[str], pred_audio: torch.Tensor, pred_audio_lens: torch.Tensor = None - ) -> None: + ) -> list[str]: if self.asr is None: self.reset() @@ -70,7 +70,7 @@ def update( batch_size=pred_audio.shape[0], verbose=False, ) - + asr_hyps_texts = [] for ref, asr_hyp in zip(refs, asr_hyps): asr_hyp = asr_hyp.text self._refs[name].append(self.normalizer(ref)) @@ -78,6 +78,9 @@ def update( if self.verbose: asrb = sacrebleu.sentence_bleu(asr_hyp, [ref]).score logging.info(f"[REF]\t{ref}\n[ASR]\t{asr_hyp} [{asrb:.2f}]") + asr_hyps_texts.append(asr_hyp) + + return asr_hyps_texts def compute(self) -> dict[str, torch.Tensor]: """Computes the final score and deallocates ASR and partial results.""" diff --git a/nemo/collections/speechlm2/parts/metrics/bleu.py b/nemo/collections/speechlm2/parts/metrics/bleu.py index 34f6da088a4d..8479f3ec6346 100644 --- a/nemo/collections/speechlm2/parts/metrics/bleu.py +++ b/nemo/collections/speechlm2/parts/metrics/bleu.py @@ -40,8 +40,6 @@ def __init__(self, normalize: bool = True, normalizer=None, verbose: bool = True self._hyps = defaultdict(list) def reset(self): - self._refs.clear() - self._hyps.clear() return self def update(self, name: str, 
refs: list[str], hyps: list[str]) -> None: @@ -58,7 +56,8 @@ def compute(self) -> dict[str, torch.Tensor]: metric = torch.tensor(sacrebleu.corpus_bleu(self._hyps[name], [self._refs[name]]).score) corpus_metric[f"txt_bleu_{name}"] = metric corpus_metric["txt_bleu"] = torch.stack(list(corpus_metric.values())).mean() - self.reset() + self._refs.clear() + self._hyps.clear() return corpus_metric diff --git a/nemo/collections/speechlm2/parts/metrics/intelligibility.py b/nemo/collections/speechlm2/parts/metrics/intelligibility.py new file mode 100644 index 000000000000..94e78b6f97c3 --- /dev/null +++ b/nemo/collections/speechlm2/parts/metrics/intelligibility.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict + +import torch +from whisper_normalizer.english import EnglishTextNormalizer + +from nemo.collections.asr.models import ASRModel +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.speechlm2.parts.pretrained import load_pretrained_nemo +from nemo.utils import logging + +from nemo.collections.asr.metrics.wer import word_error_rate + + +class Intelligibility: + """ + Computes CER on ASR predictions on generated audio with pretrained NeMo ASR. + By default, uses Whisper's EnglishTextNormalizer on hypotheses and references. + """ + + def __init__(self, pretrained_asr: str, normalize: bool = True, normalizer=None, verbose: bool = True, reuse_asr_hyps: bool = False) -> None: + self.asr = None # load into memory on reset() + self.pretrained_asr_name = pretrained_asr + self.verbose = verbose + self.reuse_asr_hyps = reuse_asr_hyps + if normalize: + if normalizer is None: + self.normalizer = EnglishTextNormalizer() + else: + self.normalizer = normalizer + else: + self.normalizer = _identity + + self._refs = defaultdict(list) + self._hyps = defaultdict(list) + + def reset(self): + # Cleaning up GPU memory before we load ASRModel, because it may already + # be quite fragmented and close to the limit after observing many + # dynamic shapes during the training epoch. + torch.cuda.memory.empty_cache() + with fp32_precision(): # Some NeMo ASR models weren't trained with bfloat16. 
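+            # When `reuse_asr_hyps` is True, transcripts produced elsewhere (e.g. returned by
+            # ASRBLEU.update) are passed into update(), so no ASR model is loaded here.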
+ if not self.reuse_asr_hyps: + self.asr = load_pretrained_nemo(ASRModel, self.pretrained_asr_name).eval() + WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self.asr, attribute_path="decoding.decoding") + return self + + def update( + self, name: str, refs: list[str], pred_audio: torch.Tensor, pred_audio_lens: torch.Tensor = None, asr_hyps: list[str] = None + ) -> list[str]: + if self.asr is None and not self.reuse_asr_hyps: + self.reset() + + if pred_audio_lens is None: + pred_audio_lens = [pred_audio.shape[1]] * pred_audio.shape[0] + + with fp32_precision(): + if not self.reuse_asr_hyps: + asr_hyps = self.asr.transcribe( + [audio[:alen] for audio, alen in zip(pred_audio, pred_audio_lens)], + batch_size=pred_audio.shape[0], + verbose=False, + ) + asr_hyps = [asr_hyp.text for asr_hyp in asr_hyps] + + asr_hyps_texts = [] + for ref, asr_hyp in zip(refs, asr_hyps): + self._refs[name].append(self.normalizer(ref)) + self._hyps[name].append(self.normalizer(asr_hyp)) + asr_hyps_texts.append(asr_hyp) + + return asr_hyps_texts + + def compute(self) -> dict[str, torch.Tensor]: + """Computes the final score and deallocates ASR and partial results.""" + corpus_metric = {} + corpus_metric["cer"] = [] + corpus_metric["wer"] = [] + for name in self._refs.keys(): + cer = torch.tensor(word_error_rate(self._hyps[name], self._refs[name], use_cer=True)) + wer = torch.tensor(word_error_rate(self._hyps[name], self._refs[name], use_cer=False)) + corpus_metric[f"cer_{name}"] = cer + corpus_metric[f"wer_{name}"] = wer + corpus_metric["cer"].append(cer) + corpus_metric["wer"].append(wer) + + corpus_metric["cer"] = torch.stack(corpus_metric["cer"]).mean() + corpus_metric["wer"] = torch.stack(corpus_metric["wer"]).mean() + self._refs.clear() + self._hyps.clear() + self.asr = None # free up GPU memory + torch.cuda.memory.empty_cache() + return corpus_metric + + +def _identity(x): + return x diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py new file mode 100644 index 000000000000..bef245555eb3 --- /dev/null +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -0,0 +1,170 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import shutil +from collections import defaultdict + +import torch +import torchaudio + +from nemo.utils import logging + + +def safe_remove_path(path): + try: + shutil.rmtree(path) + except: + pass # File was already deleted by another thread + + +class ResultsLogger: + """ + Saves audios and a json file with the model outputs. 
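+
+    Expected on-disk layout (illustrative; file names depend on the dataset name and sample ids):
+        save_path/
+            pred_wavs/<name>_<sample_id>.wav   # predicted audio, merged with the user channel when available
+            metadatas/<name>.json              # one JSON record appended per evaluated sample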
+ """ + + def __init__(self, save_path): + self.save_path = save_path + self.audio_save_path = os.path.join(save_path, "pred_wavs") + os.makedirs(self.audio_save_path, exist_ok=True) + self.matadata_save_path = os.path.join(save_path, "metadatas") + os.makedirs(self.matadata_save_path, exist_ok=True) + + def reset(self): + # ensures that we are cleaning the metadata files + metadata_files = os.listdir(self.matadata_save_path) + for f in metadata_files: + open(os.path.join(self.matadata_save_path, f), 'w').close() + + # clean out any existing .wav predictions safely + try: + audio_files = os.listdir(self.audio_save_path) + for f in audio_files: + if f.lower().endswith(".wav"): + try: + os.remove(os.path.join(self.audio_save_path, f)) + except FileNotFoundError: + pass # already gone + except Exception: + logging.warning(f"Failed to remove audio file {f} during reset.", stack_info=False) + except FileNotFoundError: + # directory somehow missing: recreate it + os.makedirs(self.audio_save_path, exist_ok=True) + + return self + + @staticmethod + def merge_and_save_audio( + out_audio_path: str, pred_audio: torch.Tensor, pred_audio_sr: int, user_audio: torch.Tensor, user_audio_sr: int + ) -> None: + # if user_audio is None ignore it + if user_audio is not None: + user_audio = torchaudio.functional.resample(user_audio.float(), user_audio_sr, pred_audio_sr) + T1, T2 = pred_audio.shape[0], user_audio.shape[0] + max_len = max(T1, T2) + pred_audio_padded = torch.nn.functional.pad(pred_audio, (0, max_len - T1), mode='constant', value=0) + user_audio_padded = torch.nn.functional.pad(user_audio, (0, max_len - T2), mode='constant', value=0) + + # combine audio in a multichannel audio + combined_wav = torch.cat( + [ + user_audio_padded.squeeze().unsqueeze(0).detach().cpu(), + pred_audio_padded.squeeze().unsqueeze(0).detach().cpu(), + ], + dim=0, + ).squeeze() + + else: + combined_wav = pred_audio.unsqueeze(0).detach().cpu() + + # save audio + torchaudio.save(out_audio_path, combined_wav, pred_audio_sr) + logging.info(f"Audio saved at: {out_audio_path}") + + def update( + self, + name: str, + refs: list[str], + hyps: list[str], + asr_hyps: list[str], + samples_id: list[str], + pred_audio: torch.Tensor, + pred_audio_sr: int, + user_audio: torch.Tensor, + user_audio_sr: int, + target_audio: torch.Tensor = None, + pred_audio_tf: torch.Tensor = None, + pre_audio_trimmed: torch.Tensor = None, + eou_pred: torch.Tensor = None, + fps: float = None, + results=None, + tokenizer=None, + reference_audio: torch.Tensor = None, + ) -> None: + + out_json_path = os.path.join(self.matadata_save_path, f"{name}.json") + out_dicts = [] + for i in range(len(refs)): + # save audio + sample_id = samples_id[i][:150] # make sure that sample id is not too big + out_audio_path = os.path.join(self.audio_save_path, f"{name}_{sample_id}.wav") + self.merge_and_save_audio(out_audio_path, pred_audio[i], pred_audio_sr, user_audio[i] if user_audio is not None else None, user_audio_sr) + + if pred_audio_tf is not None: + out_audio_path_tf = out_audio_path.replace(".wav", "_tf.wav") + self.merge_and_save_audio(out_audio_path_tf, pred_audio_tf[i], pred_audio_sr, user_audio[i] if user_audio is not None else None, user_audio_sr) + + if target_audio is not None: + out_audio_path_gt = out_audio_path.replace(".wav", "_GT.wav") + self.merge_and_save_audio(out_audio_path_gt, target_audio[i], pred_audio_sr, user_audio[i] if user_audio is not None else None, user_audio_sr) + + # create a wav with eou prediction for debug purposes + if eou_pred is not 
None: + out_audio_path_eou = os.path.join(self.audio_save_path, f"{name}_{sample_id}_eou.wav") + repeat_factor = int(pred_audio_sr / fps) + eou_pred_wav = ( + eou_pred[i].unsqueeze(0).unsqueeze(-1).repeat(1, 1, repeat_factor) + ) # (B, T, repeat_factor) + eou_pred_wav = eou_pred_wav.view(1, -1) # (B, T * repeat_factor) + eou_pred_wav = eou_pred_wav.float() * 0.8 # make 1 audible and keep 0 as total silence + torchaudio.save(out_audio_path_eou, eou_pred_wav.squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr) + + if pre_audio_trimmed is not None: + out_audio_path_trimmed = os.path.join(self.audio_save_path, f"{name}_{sample_id}_pred_trimmed.wav") + torchaudio.save(out_audio_path_trimmed, pre_audio_trimmed[i].squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr) + + if reference_audio is not None: + out_audio_path_ref = os.path.join(self.audio_save_path, f"{name}_{sample_id}_spk_reference.wav") + torchaudio.save(out_audio_path_ref, reference_audio[i].squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr) + + # cache metadata + out_dict = { + "target_text": refs[i], + "pred_text": hyps[i], + "speech_pred_transcribed": asr_hyps[i], + "audio_path": os.path.relpath(out_audio_path, self.save_path), + } + if results is not None: + if tokenizer is not None: + out_dict['tokens_text'] = " ".join(tokenizer.ids_to_tokens(results['tokens_text'][i])) + else: + out_dict['tokens_text'] = results['tokens_text'][i].tolist() + out_dicts.append(out_dict) + # uses append here to avoid needs to cache + with open(out_json_path, 'a+', encoding='utf-8') as fout: + for out_dict in out_dicts: + fout.write(json.dumps(out_dict, ensure_ascii=False, indent=4) + '\n') + # json.dump(out_dict, fout) + + logging.info(f"Metadata file for {name} dataset updated at: {out_json_path}") diff --git a/nemo/collections/speechlm2/parts/metrics/secs.py b/nemo/collections/speechlm2/parts/metrics/secs.py new file mode 100644 index 000000000000..a1a611b16fcd --- /dev/null +++ b/nemo/collections/speechlm2/parts/metrics/secs.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict + +import sacrebleu +import torch +from whisper_normalizer.english import EnglishTextNormalizer + +from nemo.collections.asr.models import ASRModel +from nemo.collections.asr.models import EncDecSpeakerLabelModel +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.speechlm2.parts.pretrained import load_pretrained_nemo +from nemo.utils import logging + + +class SECS: + """ + Computes Speacker encoder cossine similarity (SECS) on generated audio with pretrained speaker encoder model. 
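+
+    Example (illustrative sketch; the pretrained model name is an assumption):
+        secs = SECS(pretrained_se_name="titanet_large")
+        secs.update("val_set", target_audio, target_audio_lens, pred_audio, pred_audio_lens)
+        scores = secs.compute()   # {"secs_val_set": ..., "secs": ...}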
+ """ + + def __init__(self, pretrained_se_name: str) -> None: + self.speaker_encoder = None # load into memory on reset() + self.pretrained_se_name = pretrained_se_name + self._secs = defaultdict(list) + + def reset(self): + # Cleaning up GPU memory before we load ASRModel, because it may already + # be quite fragmented and close to the limit after observing many + # dynamic shapes during the training epoch. + torch.cuda.memory.empty_cache() + with fp32_precision(): + self.speaker_encoder = EncDecSpeakerLabelModel.from_pretrained(model_name=self.pretrained_se_name).eval() + + return self + + def update( + self, name: str, target_audio: torch.Tensor, target_audio_lens: torch.Tensor, pred_audio: torch.Tensor, pred_audio_lens: torch.Tensor + ) -> None: + if self.speaker_encoder is None: + self.reset() + + with fp32_precision(): + with torch.no_grad(): + _, t_g = self.speaker_encoder(input_signal=target_audio, input_signal_length=target_audio_lens.long()) + _, s_g = self.speaker_encoder(input_signal=pred_audio, input_signal_length=pred_audio_lens.long()) + secs = torch.nn.functional.cosine_similarity(t_g, s_g, dim=-1).mean() + + + self._secs[name].append(secs) + + def compute(self) -> dict[str, torch.Tensor]: + """Computes the final score and deallocates ASR and partial results.""" + corpus_metric = {} + avg_secs = [] + for name in self._secs.keys(): + secs = torch.stack(self._secs[name]).mean() + corpus_metric[f"secs_{name}"] = secs + avg_secs.append(secs) + + corpus_metric["secs"] = torch.stack(avg_secs).mean() + self._secs.clear() + self.speaker_encoder = None # free up GPU memory + torch.cuda.memory.empty_cache() + return corpus_metric + diff --git a/nemo/collections/speechlm2/parts/metrics/token_accuracy.py b/nemo/collections/speechlm2/parts/metrics/token_accuracy.py new file mode 100644 index 000000000000..231015d145d8 --- /dev/null +++ b/nemo/collections/speechlm2/parts/metrics/token_accuracy.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +import torch +from nemo.utils import logging + + +def compute_token_accuracy_with_tolerance(target, pred, token, tolerance=1): + """ + Computes the accuracy of `token` in `target` vs `pred` within a ±`tolerance` window. 
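A small worked example of the ±tolerance matching rule described above (tensor values chosen purely for illustration):

    import torch

    target = torch.tensor([[5, 0, 0, 5, 0, 0]])  # token 5 occurs at positions 0 and 3
    pred   = torch.tensor([[0, 5, 0, 0, 0, 5]])  # token 5 predicted at positions 1 and 5

    # target position 0 -> nearest prediction at distance 1 (<= tolerance) -> counted correct
    # target position 3 -> nearest prediction at distance 2               -> counted incorrect
    print(compute_token_accuracy_with_tolerance(target, pred, token=5, tolerance=1))  # 0.5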
+ + Args: + target (torch.Tensor): Batch of target sequences (batch_size, seq_len) + pred (torch.Tensor): Batch of predicted sequences (batch_size, seq_len) + token (int): The token to compute accuracy for + tolerance (int): Allowed index difference (window) for correct predictions + + Returns: + float: Accuracy as correct / total occurrences of `token` in target + """ + batch_size, seq_len = target.shape + + # Mask of positions where token appears + target_mask = target == token + pred_mask = pred == token + + correct = 0 + total = 0 + + # For each sequence in the batch + for b in range(batch_size): + # Get indices of token in target and pred + target_indices = target_mask[b].nonzero(as_tuple=True)[0] + pred_indices = pred_mask[b].nonzero(as_tuple=True)[0] + + total += len(target_indices) + if len(pred_indices) == 0: + continue # No token in pred, none can be correct + + # Compute pairwise distances + distances = torch.abs(target_indices[:, None] - pred_indices[None, :]) # shape (n_target, n_pred) + min_distances, _ = torch.min(distances, dim=1) # min distance for each target occurrence + + # Count how many target tokens have a pred within tolerance + correct += torch.sum(min_distances <= tolerance).item() + + accuracy = correct / total if total > 0 else 0.0 + return accuracy + + +class TokenAccuracy: + """ + Computes Token Accuracy scores. + """ + + def __init__(self, token_name: str, token_id: int, tolerance: int = 1, verbose: bool = True): + self.token_name = token_name + self.token_id = token_id + self.tolerance = tolerance + self.verbose = verbose + self.scores = defaultdict(list) + + def reset(self): + return self + + def update(self, name: str, refs: torch.Tensor, hyps: torch.Tensor) -> None: + token_acc = compute_token_accuracy_with_tolerance(refs, hyps, token=self.token_id, tolerance=self.tolerance) + self.scores[name].append(token_acc) + + def compute(self) -> dict[str, torch.Tensor]: + corpus_metric = {} + for name in self.scores.keys(): + metric = torch.tensor(self.scores[name]).mean() + corpus_metric[f"token_acc_{self.token_name}_{name}"] = metric + self.scores.clear() + return corpus_metric + diff --git a/nemo/collections/speechlm2/parts/metrics/wer.py b/nemo/collections/speechlm2/parts/metrics/wer.py deleted file mode 100644 index 93fe1da6a34c..000000000000 --- a/nemo/collections/speechlm2/parts/metrics/wer.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections import defaultdict - -import torch -from whisper_normalizer.english import EnglishTextNormalizer - -from nemo.collections.asr.metrics.wer import word_error_rate -from nemo.utils import logging - - -class WER: - """ - Computes WER on text predictions. - By default, uses Whisper's EnglishTextNormalizer on hypotheses and references. 
- """ - - def __init__(self, normalize: bool = True, normalizer=None, verbose: bool = True): - self.verbose = verbose - if normalize: - if normalizer is None: - self.normalizer = EnglishTextNormalizer() - else: - self.normalizer = normalizer - else: - self.normalizer = _identity - - self._refs = defaultdict(list) - self._hyps = defaultdict(list) - - def reset(self): - self._refs.clear() - self._hyps.clear() - return self - - def update(self, name: str, refs: list[str], hyps: list[str]) -> None: - for ref, hyp in zip(refs, hyps): - self._refs[name].append(self.normalizer(ref)) - self._hyps[name].append(self.normalizer(hyp)) - if self.verbose and refs and hyps: - logging.info(f"[REF]\t{refs[0]}\n[HYP]\t{hyps[0]}") - - def compute(self) -> dict[str, torch.Tensor]: - corpus_metric = {} - for name in self._refs.keys(): - metric = torch.tensor(word_error_rate(self._hyps[name], self._refs[name])) - corpus_metric[f"wer_{name}"] = metric - corpus_metric["wer"] = torch.stack(list(corpus_metric.values())).mean() - self.reset() - return corpus_metric - - -def _identity(x): - return x From e3de87216e3fd278d5566516fedaa665a288e1c7 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 11 Nov 2025 10:42:01 -0800 Subject: [PATCH 002/102] Add EARTTS codec and extra missing modules Signed-off-by: Edresson Casanova --- examples/speechlm2/duplex_eartts_train.py | 2 +- .../speechlm2/duplex_eartts_train_infer.py | 2 +- nemo/collections/common/data/lhotse/cutset.py | 60 +- nemo/collections/speechlm2/__init__.py | 3 +- nemo/collections/speechlm2/data/__init__.py | 2 + .../speechlm2/data/duplex_ear_tts_dataset.py | 13 +- .../speechlm2/models/duplex_ear_tts.py | 1107 +---------------- .../speechlm2/modules/rvq_ear_tts_vae.py | 1085 ++++++++++++++++ .../collections/speechlm2/parts/pretrained.py | 15 +- 9 files changed, 1197 insertions(+), 1092 deletions(-) create mode 100644 nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py diff --git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py index ea9b4571d34f..771e4f24240e 100644 --- a/examples/speechlm2/duplex_eartts_train.py +++ b/examples/speechlm2/duplex_eartts_train.py @@ -47,7 +47,7 @@ def train(cfg): target_sample_rate=cfg.data.target_sample_rate, input_roles=cfg.data.input_roles, output_roles=cfg.data.output_roles, - add_text_bos_and_eos_in_each_turn=cfg.data..get("add_text_bos_and_eos_in_each_turn", True), + add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), add_audio_prompt_after_description=cfg.data.add_audio_prompt_after_description, audio_prompt_duration=cfg.data.audio_prompt_duration, num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2) diff --git a/examples/speechlm2/duplex_eartts_train_infer.py b/examples/speechlm2/duplex_eartts_train_infer.py index 15ec634a6fbf..04fdd72ce8bd 100644 --- a/examples/speechlm2/duplex_eartts_train_infer.py +++ b/examples/speechlm2/duplex_eartts_train_infer.py @@ -48,7 +48,7 @@ def inference(cfg): target_sample_rate=cfg.data.target_sample_rate, input_roles=cfg.data.input_roles, output_roles=cfg.data.output_roles, - add_text_bos_and_eos_in_each_turn=cfg.data..get("add_text_bos_and_eos_in_each_turn", True), + add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), add_audio_prompt_after_description=cfg.data.add_audio_prompt_after_description, audio_prompt_duration=cfg.data.audio_prompt_duration, num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2) diff --git 
a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 76825e612ba7..09dd1a2f05da 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -21,7 +21,7 @@ from typing import KeysView, Mapping, Sequence, Tuple, Union import omegaconf -from lhotse import CutSet, Features, Recording +from lhotse import CutSet, Features, Recording, MonoCut, SupervisionSegment from lhotse.array import Array, TemporalArray from lhotse.cut import Cut, MixedCut, PaddingCut from lhotse.serialization import load_yaml @@ -535,6 +535,64 @@ def cut_to_conversation( ) + +@data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) +def read_s2s_duplex_overlap_as_s2s_duplex(config) -> tuple[CutSet, bool]: + def filter_cuts_starting_with_agent(cuts: CutSet, agent_roles=("agent", "assistant", "Assistant")) -> CutSet: + def filter_cut_fn(cut): + # sort supervisions by start + cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) + if len(cut.supervisions): + return cut.supervisions[0].speaker not in agent_roles + else: + return False # filter emptly supervisions + + return cuts.filter(filter_cut_fn) + + def convert_overlap_cut(cut): + agent_segments = [] + for seg in cut.agent_segments: + ss = SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=seg["start"] - move_agent_text_back_by, + duration=seg["end"]-seg["start"] + move_agent_text_back_by, + text=seg["text"], + speaker="agent", + ) + agent_segments.append(ss) + + user_segments = [] + for seg in cut.user_segments: + ss = SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=seg["start"], + duration=seg["end"]-seg["start"], + text=seg["text"], + speaker="user", + ) + user_segments.append(ss) + + cut.supervisions = sorted(agent_segments + user_segments, key=lambda s: s.start) + cut.formatter = "s2s_duplex_overlap_as_s2s_duplex" + return cut + + # load lhotse cuts + cuts, is_tarred = read_cutset_from_config(config) + move_agent_text_back_by = config.get("move_agent_text_back_by", 0) + filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) + agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) + + # convert cuts + cuts = cuts.map(convert_overlap_cut) + + # Filter cuts where the first supervision is agent + if filter_samples_starting_with_agent: + cuts = filter_cuts_starting_with_agent(cuts, agent_roles) + + return cuts, is_tarred + @data_type_parser(["lhotse_magpietts_data_as_continuation"]) def read_lhotse_magpietts_data_as_continuation(config) -> tuple[CutSet, bool]: def convert_lhotse_magpietts_data_as_cont(cut): diff --git a/nemo/collections/speechlm2/__init__.py b/nemo/collections/speechlm2/__init__.py index b638dd13f08d..bd9267c08fe3 100644 --- a/nemo/collections/speechlm2/__init__.py +++ b/nemo/collections/speechlm2/__init__.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
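As a rough illustration of the overlap-to-duplex conversion added to cutset.py above, the same bookkeeping can be shown with plain dicts in place of Lhotse cuts and supervisions (all timings and texts below are made up, and the 0.32 s shift is only an example value for move_agent_text_back_by):

    move_agent_text_back_by = 0.32  # seconds

    agent_segments = [{"start": 2.00, "end": 3.50, "text": "Sure, I can help."}]
    user_segments  = [{"start": 0.00, "end": 1.80, "text": "Can you help me?"}]

    supervisions = []
    for seg in agent_segments:
        supervisions.append({
            "speaker": "agent",
            "start": seg["start"] - move_agent_text_back_by,             # agent text moved earlier
            "duration": seg["end"] - seg["start"] + move_agent_text_back_by,
            "text": seg["text"],
        })
    for seg in user_segments:
        supervisions.append({
            "speaker": "user",
            "start": seg["start"],
            "duration": seg["end"] - seg["start"],
            "text": seg["text"],
        })

    supervisions.sort(key=lambda s: s["start"])
    print(supervisions[0]["speaker"])  # "user"

Because the earliest supervision here belongs to the user, a cut shaped like this would be kept when filter_samples_starting_with_agent is enabled.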
-from .data import DataModule, DuplexS2SDataset, SALMDataset +from .data import DataModule, DuplexEARTTSDataset, DuplexS2SDataset, SALMDataset from .models import SALM, DuplexS2SModel, DuplexS2SSpeechDecoderModel __all__ = [ 'DataModule', 'DuplexS2SDataset', + 'DuplexEARTTSDataset', 'SALMDataset', 'DuplexS2SModel', 'DuplexS2SSpeechDecoderModel', diff --git a/nemo/collections/speechlm2/data/__init__.py b/nemo/collections/speechlm2/data/__init__.py index f698b6bcf12a..0d51497a6d8f 100644 --- a/nemo/collections/speechlm2/data/__init__.py +++ b/nemo/collections/speechlm2/data/__init__.py @@ -13,10 +13,12 @@ # limitations under the License. from .datamodule import DataModule from .s2s_dataset import DuplexS2SDataset +from .duplex_ear_tts_dataset import DuplexEARTTSDataset from .salm_dataset import SALMDataset __all__ = [ 'DataModule', 'DuplexS2SDataset', + 'DuplexEARTTSDataset', 'SALMDataset', ] diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index cf511539bd9d..dacc9852babe 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -366,13 +366,13 @@ def __getitem__(self, cuts: CutSet) -> dict: target_audio = collate_vectors(target_audio_, padding_value=0) # recreate audio mask - audio_mask = get_mask_from_lengths(target_token_lens) + non_desc_mask = get_mask_from_lengths(target_token_lens) # ignore desc len in audio mask for i, frame in enumerate(desc_lens): - audio_mask[i, :frame] = 0.0 + non_desc_mask[i, :frame] = 0.0 # desc mask is totally the oposite of audio mask - desc_mask = ~ audio_mask + desc_mask = ~ non_desc_mask # create non_prompt_mask that should mask desc plus audio prompt if used non_prompt_mask = get_mask_from_lengths(target_token_lens) @@ -380,11 +380,9 @@ def __getitem__(self, cuts: CutSet) -> dict: non_prompt_mask[i, :frame-1] = 0.0 else: # create a mask for audio using target tokens that suppose to have the same size of the tokenized audio - audio_mask = get_mask_from_lengths(target_token_lens) + non_prompt_mask = get_mask_from_lengths(target_token_lens) # create a full zero desc mask - desc_mask = torch.zeros_like(audio_mask) - # keep text mask as audio_mask - non_prompt_mask = audio_mask + desc_mask = torch.zeros_like(non_prompt_mask) batch_size = len(target_token_lens) max_len = max(target_token_lens) @@ -412,7 +410,6 @@ def __getitem__(self, cuts: CutSet) -> dict: return { "sample_id": [str(cut.id) for cut in cuts], - "audio_mask": audio_mask.bool(), "non_prompt_mask": non_prompt_mask.bool(), "desc_mask": desc_mask.bool(), "desc_lens": desc_lens, diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 872068d487d9..e8a2fc6217b6 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -46,7 +46,7 @@ from nemo.collections.audio.parts.utils.resampling import resample from nemo.core.classes.module import NeuralModule from nemo.collections.common.tokenizers import AutoTokenizer -from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.common.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.speechlm2.data.utils import get_pad_id from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin @@ -68,16 +68,12 @@ from nemo.utils import 
logging from nemo.collections.tts.modules import transformer_2501 -from nemo.collections.tts.modules.mimi_codec_modules import ReshapeTransformerEncoder from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER - -from nemo.collections.speechlm2.modules.cfm import MatchaTTSCFM from types import SimpleNamespace -from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSModel, RVQEARTTSConfig, build_vocabs, SubwordFlagEmbedding, RMSNorm +from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSModel, RVQEARTTSConfig from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEModel -from nemo.collections.speechlm2.data.duplex_ear_tts_dataset import normalize_text_fn import torch import torch.nn as nn @@ -258,13 +254,10 @@ def dtype_counter_hook(module, inputs, outputs): "safety_factor": safety_factor, } - # print("Num. BF16/FP16 activations:", num_bf16_fp16) - # print("Num. FP32 activations:", num_fp32) print("Num. BF16/FP16 candidate layers:", len(bf16_layers)) print("Num. FP32 layers (sensitive + propagated):", len(fp32_layers)) return model_patched, summary - def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int = 0, eos_token_id: int = 1): @@ -337,80 +330,6 @@ def get_mask_from_lengths( mask = ids < lengths.unsqueeze(1) return mask - -from transformers import MimiModel, AutoFeatureExtractor -class MimiCodec(NeuralModule): - def __init__(self, model_path_or_name="kyutai/mimi", num_codebooks=12): - super().__init__() - self.codec = MimiModel.from_pretrained(model_path_or_name) - self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path_or_name) - self.num_codebooks = num_codebooks - - @property - def device(self): - return next(self.codec.parameters()).device - - @property - def _codebook_size(self): - return self.codec.config.codebook_size - - @property - def _num_codebooks(self): - return self.num_codebooks - - @property - def samples_per_frame(self): - return int(self.feature_extractor.sampling_rate // self.codec.config.frame_rate) - - def encode(self, audio, audio_len): - audio = audio.squeeze(1) - with fp32_precision(): - # make the audio divisible by frame rate and also by self.frame_stacking_factor with extra frames of 1 to avoid issues because we are removing a audio frame to shift target and input for TF - audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.samples_per_frame, extra_frames=0) - # explicitly encode then decode the audio inputs - encoder_outputs = self.codec.encode(audio.unsqueeze(1).to(self.device), num_quantizers=self.num_codebooks) - codes = encoder_outputs.audio_codes - tokens_len = audio_len // self.samples_per_frame - return codes.transpose(1, 2), tokens_len - - def decode(self, tokens, tokens_len): - with fp32_precision(): - tokens = tokens.transpose(1, 2) - # tokens: B, T', C' - audio = self.codec.decode(tokens).audio_values.squeeze(1) - audio_len = tokens_len * self.samples_per_frame - return audio, audio_len - - def forward(self, audio, audio_len): - tokens, tokens_len = self.encode(audio, audio_len) - audio, audio_len = self.decode(tokens, tokens_len) - return audio, audio_len - - - def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, extra_frames: int = 0): - """ - Zero pad the end of the audio so that we do not have a partial end frame. - The output will be zero-padded to have an integer number of frames of - length `samples_per_frame * frame_stacking_factor`. 
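The padding rule described here reduces to rounding each length up to the next multiple of samples_per_frame; a short numeric sketch, where the 1920-sample frame size is an assumed example and not taken from this patch:

    import torch
    import torch.nn.functional as F

    samples_per_frame = 1920                   # assumed frame size for illustration
    audio = torch.randn(2, 50_000)             # (B, T)
    audio_len = torch.tensor([50_000, 43_210])

    padded_len = samples_per_frame * torch.ceil(audio_len / samples_per_frame).int()
    max_len = int(padded_len.max())
    padded_audio = F.pad(audio, (0, max_len - audio.shape[1]))

    print(padded_len.tolist())   # [51840, 44160] -> both multiples of 1920
    print(padded_audio.shape)    # torch.Size([2, 51840])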
- - Args: - audio: input time-domain signal (B, T) - audio_len: valid length for each example in the batch (B,) - samples_per_frame: number of samples per frame - - Returns: - padded_audio: Padded time-domain signal (B, T') - padded_len: Adjusted valid lengths (B,) - """ - with fp32_precision(): - padded_len = (samples_per_frame * torch.ceil(audio_len / samples_per_frame).int()) + (extra_frames * samples_per_frame) - max_len = padded_len.max().int().item() - num_padding = (max_len - audio.shape[1]) - padded_audio = F.pad(audio, (0, num_padding)) - return padded_audio, padded_len - - - def setup_rvq_audio_codec(model): """ Sets up an ``AudioCodecModel``, initializing it from pretrained weights. @@ -434,799 +353,6 @@ def setup_audio_codec(self): self.target_fps = self.target_sample_rate / self.audio_codec.config.wav_to_token_ratio self.target_samples_per_frame = self.audio_codec.config.wav_to_token_ratio -def subwords_to_chars(subword_ids: torch.Tensor, - subword_id_to_char_ids: dict[int, tuple[int, ...]], - bos_id: int, - eos_id: int, - pad_id: int): - """ - Fully vectorized subword->char expansion across all BOS..EOS spans: - - Handles multiple spans per batch - - Preserves BOS/EOS - - Truncates expansions to fit each span - - Very fast on GPU - """ - device = subword_ids.device - B, T = subword_ids.shape - - # Build LUT - max_subword_id = int(subword_ids.max().item()) - max_chars = max(len(v) for v in subword_id_to_char_ids.values()) if subword_id_to_char_ids else 0 - if max_chars == 0: - return subword_ids.clone() - - char_expansion = torch.full((max_subword_id + 1, max_chars), - fill_value=pad_id, device=device, dtype=subword_ids.dtype) - expansion_len = torch.zeros(max_subword_id + 1, dtype=torch.long, device=device) - for k, v in subword_id_to_char_ids.items(): - if k <= max_subword_id: - v_t = torch.tensor(v, device=device, dtype=subword_ids.dtype) - char_expansion[k, :len(v_t)] = v_t - expansion_len[k] = len(v_t) - - # Output initialized with PAD - output = torch.full_like(subword_ids, fill_value=pad_id) - special_mask = (subword_ids == bos_id) | (subword_ids == eos_id) - output[special_mask] = subword_ids[special_mask] - - # Find next EOS for each position - pos = torch.arange(T, device=device) - bos_mask = (subword_ids == bos_id) - eos_mask = (subword_ids == eos_id) - eos_pos_tensor = torch.where(eos_mask, pos.unsqueeze(0).expand(B, T), - torch.full((B, T), T, device=device)) - next_eos_idx = torch.flip(torch.cummin(torch.flip(eos_pos_tensor, [1]), dim=1).values, [1]) - - # Collect all BOS coordinates - bos_coords = torch.nonzero(bos_mask, as_tuple=False) - if bos_coords.numel() == 0: - return output - - batch_ids = bos_coords[:, 0] - span_starts = bos_coords[:, 1] + 1 - span_ends = next_eos_idx[batch_ids, bos_coords[:, 1]] - span_lens = (span_ends - span_starts).clamp(min=0) - S = span_lens.numel() - if S == 0: - return output - - # Max span length - max_span_len = int(span_lens.max().item()) - - # Gather subwords for all spans [S, max_span_len] - rel = torch.arange(max_span_len, device=device).unsqueeze(0).expand(S, -1) - span_idx = span_starts.unsqueeze(1) + rel - span_idx_clamped = span_idx.clamp(0, T-1) - batch_idx_expand = batch_ids.unsqueeze(1).expand(-1, max_span_len) - sub_span = subword_ids[batch_idx_expand, span_idx_clamped] - - # Mask positions beyond actual span length - valid_pos_mask = rel < span_lens.unsqueeze(1) - sub_span = torch.where(valid_pos_mask, sub_span, torch.full_like(sub_span, pad_id)) - - # Expand subwords -> chars - expanded = char_expansion[sub_span] 
# [S, max_span_len, max_chars] - S_len = max_span_len * max_chars - expanded_flat = expanded.view(S, S_len) - valid_char_mask = expanded_flat != pad_id - valid_cumsum = torch.cumsum(valid_char_mask.long(), dim=1) - span_lens_exp = span_lens.unsqueeze(1).expand(-1, S_len) - keep_mask = valid_char_mask & (valid_cumsum <= span_lens_exp) - - # Compute flattened indices to scatter - rank_flat = (valid_cumsum - 1).clamp(min=0).view(-1) - values_flat = expanded_flat.view(-1) - keep_flat = keep_mask.view(-1) - kept_values = values_flat[keep_flat] - - target_positions = (span_starts.unsqueeze(1).repeat(1, S_len).view(-1))[keep_flat] + rank_flat[keep_flat] - target_batches = batch_ids.unsqueeze(1).repeat(1, S_len).view(-1)[keep_flat] - - # Safety clamp - within_T = target_positions < T - kept_values = kept_values[within_T] - target_positions = target_positions[within_T] - target_batches = target_batches[within_T] - - # Scatter in one shot - output[target_batches, target_positions] = kept_values - - return output - - -def subwords_to_chars_batched(subword_ids: torch.Tensor, - subword_id_to_char_ids: dict[int, tuple[int, ...]], - bos_id: int, - eos_id: int, - pad_id: int, - silence_id: int = 0): - """ - Batched subword->char expansion per BOS..EOS span. - - Multiple spans per batch - - Fully vectorized (no Python loop over spans) - - BOS/EOS exact - - Silences between spans - """ - B, T = subword_ids.shape - device = subword_ids.device - - # Build LUT - max_subword_id = int(subword_ids.max().item()) - max_chars = max(len(v) for v in subword_id_to_char_ids.values()) if subword_id_to_char_ids else 0 - if max_chars == 0: - return subword_ids.clone() - - char_expansion = torch.full((max_subword_id + 1, max_chars), - fill_value=pad_id, device=device, dtype=subword_ids.dtype) - expansion_len = torch.zeros(max_subword_id + 1, dtype=torch.long, device=device) - for k, v in subword_id_to_char_ids.items(): - if k <= max_subword_id: - v_t = torch.tensor(v, device=device, dtype=subword_ids.dtype) - char_expansion[k, :len(v_t)] = v_t - expansion_len[k] = len(v_t) - - # Output initialized with PAD - output = torch.full_like(subword_ids, fill_value=pad_id) - special_mask = (subword_ids == bos_id) | (subword_ids == eos_id) - output[special_mask] = subword_ids[special_mask] - - # Masks - bos_mask = (subword_ids == bos_id) - eos_mask = (subword_ids == eos_id) - - # Compute next EOS per position - pos = torch.arange(T, device=device) - eos_pos_tensor = torch.where(eos_mask, pos.unsqueeze(0).expand(B, T), - torch.full((B, T), T, device=device)) - next_eos_idx = torch.flip(torch.cummin(torch.flip(eos_pos_tensor, [1]), dim=1).values, [1]) - - # Collect all spans - bos_coords = torch.nonzero(bos_mask, as_tuple=False) - if bos_coords.numel() == 0: - return output - - batch_ids = bos_coords[:, 0] - span_starts = bos_coords[:, 1] + 1 - span_ends = next_eos_idx[batch_ids, bos_coords[:, 1]] - span_lens = (span_ends - span_starts).clamp(min=0) - S = span_lens.numel() - if S == 0: - return output - - # Gather subwords for all spans - max_span_len = int(span_lens.max().item()) - rel = torch.arange(max_span_len, device=device).unsqueeze(0).expand(S, -1) - span_idx = span_starts.unsqueeze(1) + rel - span_idx_clamped = span_idx.clamp(0, T - 1) - batch_idx_expand = batch_ids.unsqueeze(1).expand(-1, max_span_len) - sub_span = subword_ids[batch_idx_expand, span_idx_clamped] - - # Mask positions beyond actual span length - valid_pos_mask = rel < span_lens.unsqueeze(1) - sub_span = torch.where(valid_pos_mask, sub_span, 
torch.full_like(sub_span, pad_id)) - - # Expand subwords -> chars - expanded = char_expansion[sub_span] # [S, max_span_len, max_chars] - S_len = max_span_len * max_chars - expanded_flat = expanded.view(S, S_len) - valid_char_mask = expanded_flat != pad_id - valid_cumsum = torch.cumsum(valid_char_mask.long(), dim=1) - span_lens_exp = span_lens.unsqueeze(1).expand(-1, S_len) - keep_mask = valid_char_mask & (valid_cumsum <= span_lens_exp) - - # Compute target positions - rank_flat = (valid_cumsum - 1).clamp(min=0).view(-1) - values_flat = expanded_flat.view(-1) - keep_flat = keep_mask.view(-1) - kept_values = values_flat[keep_flat] - - target_positions = (span_starts.unsqueeze(1).repeat(1, S_len).view(-1))[keep_flat] + rank_flat[keep_flat] - target_batches = batch_ids.unsqueeze(1).repeat(1, S_len).view(-1)[keep_flat] - - # Safety clamp - within_T = target_positions < T - kept_values = kept_values[within_T] - target_positions = target_positions[within_T] - target_batches = target_batches[within_T] - - # Scatter in one shot - output[target_batches, target_positions] = kept_values - - return output - - -def build_char_expansion_lut(subword_id_to_char_ids: dict[int, tuple[int, ...]], - pad_id: int, - device: str = "cuda"): - """ - Prebuild the LUT once for training. - Returns: - char_expansion: [max_subword_id+1, max_chars] - expansion_len: number of chars per subword - """ - if not subword_id_to_char_ids: - return None, None - - max_subword_id = max(subword_id_to_char_ids.keys()) - max_chars = max(len(v) for v in subword_id_to_char_ids.values()) - char_expansion = torch.full((max_subword_id + 1, max_chars), - fill_value=pad_id, device=device, dtype=torch.long) - expansion_len = torch.zeros(max_subword_id + 1, device=device, dtype=torch.long) - - for k, v in subword_id_to_char_ids.items(): - if k <= max_subword_id: - v_t = torch.tensor(v, device=device, dtype=torch.long) - char_expansion[k, :len(v_t)] = v_t - expansion_len[k] = len(v_t) - - return char_expansion, expansion_len - - -def subwords_to_chars_batched_fast(subword_ids: torch.Tensor, - char_expansion: torch.Tensor, - expansion_len: torch.Tensor, - bos_id: int, - eos_id: int, - pad_id: int): - """ - Fast batched subword->char expansion using prebuilt LUT. - Fully vectorized, multiple spans per batch, no autograd overhead. 
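The lookup-table trick used by these helpers can be shown in isolation: every subword id maps to a fixed-width, pad-filled row of character ids, so the expansion becomes a single tensor indexing operation (toy vocabulary below; all ids are arbitrary placeholders):

    import torch

    pad_id = -1
    subword_id_to_char_ids = {10: (3, 4), 11: (5,), 12: (6, 7, 8)}  # toy mapping

    max_chars = max(len(v) for v in subword_id_to_char_ids.values())
    lut = torch.full((13, max_chars), pad_id)
    for sw, chars in subword_id_to_char_ids.items():
        lut[sw, : len(chars)] = torch.tensor(chars)

    subword_ids = torch.tensor([[10, 12, 11]])
    expanded = lut[subword_ids]             # (B, T, max_chars), one indexing op
    chars = expanded[expanded != pad_id]    # drop padding, keep row-major order
    print(chars.tolist())                   # [3, 4, 6, 7, 8, 5]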
- """ - with torch.no_grad(): - if char_expansion is None: - return subword_ids.clone() - - B, T = subword_ids.shape - device = subword_ids.device - - # Initialize output - output = torch.full_like(subword_ids, pad_id) - special_mask = (subword_ids == bos_id) | (subword_ids == eos_id) - output[special_mask] = subword_ids[special_mask] - - # Masks - bos_mask = (subword_ids == bos_id) - eos_mask = (subword_ids == eos_id) - - # Next EOS per position - pos = torch.arange(T, device=device) - eos_pos_tensor = torch.where(eos_mask, pos.unsqueeze(0).expand(B, T), - torch.full((B, T), T, device=device)) - next_eos_idx = torch.flip(torch.cummin(torch.flip(eos_pos_tensor, [1]), dim=1).values, [1]) - - # Collect all spans - bos_coords = torch.nonzero(bos_mask, as_tuple=False) - if bos_coords.numel() == 0: - return output - - batch_ids = bos_coords[:, 0] - span_starts = bos_coords[:, 1] + 1 - span_ends = next_eos_idx[batch_ids, bos_coords[:, 1]] - span_lens = (span_ends - span_starts).clamp(min=0) - S = span_lens.numel() - if S == 0: - return output - - # Gather subwords - max_span_len = int(span_lens.max().item()) - rel = torch.arange(max_span_len, device=device).unsqueeze(0).expand(S, -1) - span_idx = span_starts.unsqueeze(1) + rel - span_idx_clamped = span_idx.clamp(0, T-1) - batch_idx_expand = batch_ids.unsqueeze(1).expand(-1, max_span_len) - sub_span = subword_ids[batch_idx_expand, span_idx_clamped] - - valid_pos_mask = rel < span_lens.unsqueeze(1) - sub_span = torch.where(valid_pos_mask, sub_span, torch.full_like(sub_span, pad_id)) - - # Expand using prebuilt LUT - expanded = char_expansion[sub_span] # [S, max_span_len, max_chars] - S_len = max_span_len * char_expansion.shape[1] - expanded_flat = expanded.view(S, S_len) - - valid_char_mask = expanded_flat != pad_id - valid_cumsum = torch.cumsum(valid_char_mask.long(), dim=1) - span_lens_exp = span_lens.unsqueeze(1).expand(-1, S_len) - keep_mask = valid_char_mask & (valid_cumsum <= span_lens_exp) - - rank_flat = (valid_cumsum - 1).clamp(min=0).view(-1) - values_flat = expanded_flat.view(-1) - keep_flat = keep_mask.view(-1) - kept_values = values_flat[keep_flat] - - target_positions = (span_starts.unsqueeze(1).repeat(1, S_len).view(-1))[keep_flat] + rank_flat[keep_flat] - target_batches = batch_ids.unsqueeze(1).repeat(1, S_len).view(-1)[keep_flat] - - within_T = target_positions < T - kept_values = kept_values[within_T] - target_positions = target_positions[within_T] - target_batches = target_batches[within_T] - - output[target_batches, target_positions] = kept_values - - return output - - -class WordSepTokenizer(AutoTokenizer): - """ - Tokenizer wrapper that inserts a special word-separator token before each token - that starts a new word. This is useful for Speech-LLM and TTS pipelines - that require explicit word boundaries in the token sequence. - - Supported models: - - LLaMA-3.1-family - - NVIDIA Nemotron Nano-9B-v2 - - Attributes: - word_sep_token (str): The special token used to mark word boundaries. - word_boundary_prefix (str): The token prefix indicating a word boundary. - word_sep_id (int): The token ID corresponding to `word_sep_token`. - """ - - def __init__(self, model_name: str, *args, **kwargs): - """ - Initializes the WordSepTokenizer. - - Args: - model_name (str): Name of the model to load. Determines the special - word-separator token and word boundary prefix. - *args: Additional positional arguments passed to the base `AutoTokenizer`. - **kwargs: Additional keyword arguments passed to the base `AutoTokenizer`. 
- - Raises: - ValueError: If `model_name` is not supported. - """ - super().__init__(model_name, *args, **kwargs) - - model_name_lower = model_name.lower() - if "llama-3.1" in model_name_lower: - self.word_sep_token = "<|reserved_special_token_0|>" - self.word_boundary_prefix = "Ġ" - elif "qwen2.5" in model_name_lower: - self.word_sep_token = "<|box_start|>" - self.word_boundary_prefix = "Ġ" - elif "nvidia-nemotron-nano-9b-v2" in model_name_lower: - self.word_sep_token = "" - self.word_boundary_prefix = "Ġ" - else: - raise ValueError( - f"WordSepTokenizer does not support model '{model_name}'. " - "Supported: LLaMA-3.1-family, NVIDIA Nemotron Nano-9B-v2." - ) - - self.word_sep_id = self.tokenizer.convert_tokens_to_ids(self.word_sep_token) - - def text_to_ids(self, text: str): - """ - Converts input text into token IDs, inserting the word-separator ID - before tokens that start a new word. - - Args: - text (str): Input string to tokenize. - - Returns: - List[int]: Token IDs with word-separator IDs inserted. - - Notes: - - If `text` is empty or tokenization returns no tokens, returns an empty list. - - The first token separator (if any) is removed to avoid leading separators. - """ - if not text: - return [] - - # ensures that first word has a space to avoid different tokens for the first word - if text[0] != " ": - text = " " + text - - # Original token IDs - ids = super().text_to_ids(text) - if not ids: - return [] - - # Convert IDs to tokens safely (must be CPU Python list, no separator IDs yet) - tokens = self.tokenizer.convert_ids_to_tokens(list(ids)) - - # Mask for tokens starting with word boundary - mask = [t.startswith(self.word_boundary_prefix) for t in tokens] - - # Prepare result - result = [] - for tid, m in zip(ids, mask): - if m: - result.append(self.word_sep_id) - result.append(tid) - - # Remove leading separator if present - if result and result[0] == self.word_sep_id: - result = result[1:] - - return result - - def ids_to_text(self, ids): - """ - Converts token IDs back to text, replacing word-separator tokens with spaces. - - Args: - ids (List[int]): List of token IDs. - - Returns: - str: Decoded text with word separators converted to spaces. 
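The insertion rule implemented by text_to_ids can be sketched with a toy vocabulary; the ids, tokens, and separator id below are placeholders, not real tokenizer output:

    WORD_SEP_ID = 999                               # placeholder word-separator id
    toy_ids    = [101, 102, 103, 104]
    toy_tokens = ["Ġhello", "Ġwor", "ld", "Ġ!"]     # "ld" continues the previous word

    result = []
    for tid, tok in zip(toy_ids, toy_tokens):
        if tok.startswith("Ġ"):                     # word-initial subword -> insert separator first
            result.append(WORD_SEP_ID)
        result.append(tid)

    if result and result[0] == WORD_SEP_ID:         # drop a leading separator
        result = result[1:]

    print(result)  # [101, 999, 102, 103, 999, 104]

ids_to_text then simply maps the separator token back to a space when decoding.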
- """ - text = super().ids_to_text(ids) - return text.replace(self.word_sep_token, " ") - - -class NeMoGroupedCodec(NeuralModule): - def __init__(self, codec, frame_stacking_factor=1): - super().__init__() - self.codec = codec - self.frame_stacking_factor = frame_stacking_factor - - @property - def device(self): - return self.codec.device - - @property - def _codebook_size(self): - return self.codec.vector_quantizer.codebook_size_per_group - - @property - def _num_codebooks(self): - return self.codec.vector_quantizer.num_groups * self.frame_stacking_factor - - @property - def samples_per_frame(self): - return self.codec.samples_per_frame * self.frame_stacking_factor - - def encode(self, audio, audio_len): - audio = audio.squeeze(1) - with fp32_precision(): - # make the audio divisible by frame rate and also by self.frame_stacking_factor with extra frames of 1 to avoid issues because we are removing a audio frame to shift target and input for TF - audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.samples_per_frame, extra_frames=0) - # encodes audio using the codec - tokens, tokens_len = self.codec.encode(audio=audio, audio_len=audio_len) # B, C, T - tokens = tokens.transpose(1, 2) # → B, T, C - B, T, C = tokens.shape - assert T % self.frame_stacking_factor == 0 - grouped = tokens.reshape(B, T // self.frame_stacking_factor, C * self.frame_stacking_factor) - tokens_len = tokens_len // self.frame_stacking_factor - # grouped = grouped.transpose(1, 2) - - return grouped, tokens_len - - def decode(self, tokens, tokens_len): - with fp32_precision(): - # tokens = tokens.transpose(1, 2) - # tokens: B, T', C' - B, T, Cg = tokens.shape - assert Cg % self.frame_stacking_factor == 0 - C = Cg // self.frame_stacking_factor - ungrouped = tokens.reshape(B, T * self.frame_stacking_factor, C) # → [B, T, C] - ungrouped = ungrouped.transpose(1, 2) # → [B, C, T] for decode - tokens_len = torch.ceil(tokens_len * self.frame_stacking_factor).to(tokens_len.dtype) - audio, audio_len = self.codec.decode(tokens=ungrouped, tokens_len=tokens_len) - return audio, audio_len - - def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor): - """Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation. - - Args: - inputs: encoded signal - input_len: valid length for each example in the batch - - Returns: - Decoded output `audio` in the time domain and its length in number of samples `audio_len`. - Note that `audio_len` will be a multiple of `self.samples_per_frame`. - """ - with fp32_precision(): - if self.frame_stacking_factor > 1: - inputs = inputs.transpose(1, 2) - B, T, Cg = inputs.shape - C = Cg // self.frame_stacking_factor - inputs = inputs.reshape(B, T * self.frame_stacking_factor, C) # → [B, T, C] - input_len = torch.ceil(input_len * self.frame_stacking_factor).to(input_len.dtype) - inputs = inputs.transpose(1, 2) - - audio, audio_len = self.codec.audio_decoder(inputs=inputs, input_len=input_len) - return audio, audio_len - - def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor: - """Convert the discrete tokens into a continuous encoded representation. - - Args: - tokens: discrete tokens for each codebook for each time frame - tokens_len: valid length of each example in the batch - - Returns: - Continuous encoded representation of the discrete input representation. 
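The grouping used throughout this wrapper is a plain reshape that stacks frame_stacking_factor consecutive codec frames into one wider frame; a minimal sketch with illustrative shapes:

    import torch

    B, T, C = 2, 12, 8                  # 12 codec frames, 8 codebooks per frame (illustrative)
    frame_stacking_factor = 4

    tokens = torch.randint(0, 1024, (B, T, C))
    grouped = tokens.reshape(B, T // frame_stacking_factor, C * frame_stacking_factor)
    print(grouped.shape)                # torch.Size([2, 3, 32]) -> 3 stacked frames of 32 codes

    # the inverse reshape, applied before decoding or dequantizing
    ungrouped = grouped.reshape(B, T, C)
    print(torch.equal(ungrouped, tokens))  # True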
- """ - with fp32_precision(): - # reshape to dequantize - if self.frame_stacking_factor > 1: - tokens = tokens.transpose(1, 2) - # tokens: B, T', C' - B, T, Cg = tokens.shape - assert Cg % self.frame_stacking_factor == 0 - C = Cg // self.frame_stacking_factor - tokens = tokens.reshape(B, T * self.frame_stacking_factor, C) # → [B, T, C] - tokens = tokens.transpose(1, 2) # → [B, C, T] for decode - tokens_len = torch.ceil(tokens_len * self.frame_stacking_factor).to(tokens_len.dtype) - dequantized = self.codec.dequantize(tokens=tokens, tokens_len=tokens_len) - # reshape back to the compress form if needed - if self.frame_stacking_factor > 1: - dequantized = dequantized.transpose(1, 2) # → B, T, C - B, T, C = dequantized.shape - assert T % self.frame_stacking_factor == 0 - dequantized = dequantized.reshape(B, T // self.frame_stacking_factor, C * self.frame_stacking_factor) - dequantized = dequantized.transpose(1, 2) # → B, C, T - - return dequantized - - def forward(self, audio, audio_len): - tokens, tokens_len = self.encode(audio, audio_len) - audio, audio_len = self.decode(tokens, tokens_len) - return audio, audio_len - - def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, extra_frames: int = 0): - """ - Zero pad the end of the audio so that we do not have a partial end frame. - The output will be zero-padded to have an integer number of frames of - length `samples_per_frame * frame_stacking_factor`. - - Args: - audio: input time-domain signal (B, T) - audio_len: valid length for each example in the batch (B,) - samples_per_frame: number of samples per frame - - Returns: - padded_audio: Padded time-domain signal (B, T') - padded_len: Adjusted valid lengths (B,) - """ - with fp32_precision(): - padded_len = (samples_per_frame * torch.ceil(audio_len / samples_per_frame).int()) + (extra_frames * samples_per_frame) - max_len = padded_len.max().int().item() - num_padding = (max_len - audio.shape[1]) - padded_audio = F.pad(audio, (0, num_padding)) - return padded_audio, padded_len - -import math - -def compare_dicts(dict_a, dict_b): - all_keys = set(dict_a.keys()).union(dict_b.keys()) - equal = True - differing_keys = [] - - for key in sorted(all_keys): - a_val = dict_a.get(key, None) - b_val = dict_b.get(key, None) - - # Skip if value is None in either dict - if a_val is None or b_val is None: - continue - - # Handle both being NaN (float) - if (isinstance(a_val, float) and math.isnan(a_val)) and \ - (isinstance(b_val, float) and math.isnan(b_val)): - continue - - # Handle both being tensors - if isinstance(a_val, torch.Tensor) and isinstance(b_val, torch.Tensor): - # Shape mismatch - if a_val.shape != b_val.shape: - print(f"❌ Shape mismatch at key '{key}': {a_val.shape} vs {b_val.shape}") - equal = False - differing_keys.append(key) - continue - - # Compare tensors elementwise (treating NaNs as equal) - diff_mask = ~(torch.isclose(a_val, b_val, equal_nan=True)) - if diff_mask.any(): - equal = False - differing_keys.append(key) - idx = torch.nonzero(diff_mask, as_tuple=False) - print(f"❌ Tensor mismatch at key '{key}': {idx.shape[0]} differing positions, shape: ", a_val.shape, b_val.shape) - # Print up to first 10 differences - for i, pos in enumerate(idx[:10]): - pos_tuple = tuple(pos.tolist()) - a_item = a_val[pos_tuple].item() - b_item = b_val[pos_tuple].item() - print(f" Position {pos_tuple}: {a_item} vs {b_item}") - if idx.shape[0] > 10: - print(f" ... 
and {idx.shape[0] - 10} more differences") - continue - - # Fallback: direct comparison - if a_val != b_val: - print(f"❌ Value mismatch at key '{key}': {a_val} vs {b_val}") - equal = False - differing_keys.append(key) - - if equal: - print("✅ All comparable keys and values match!") - else: - print("⚠️ Some keys/values differ (see above).") - - return equal, differing_keys - -import copy -def extract_first_tensor(x): - """Recursively find the first tensor in nested structures.""" - if isinstance(x, torch.Tensor): - return x - if isinstance(x, (list, tuple)): - for v in x: - t = extract_first_tensor(v) - if t is not None: - return t - if isinstance(x, dict): - for v in x.values(): - t = extract_first_tensor(v) - if t is not None: - return t - return None - -def compare_tts_model_fp32_bf16_old(tts_model, inputs, atol=1e-3, topk=15): - model_fp32 = copy.deepcopy(tts_model).eval().to(torch.float32) - model_bf16 = copy.deepcopy(tts_model).eval().to(torch.bfloat16) - - diffs = {} - - def make_hook(name, tag): - def hook_fn(module, inp, out): - tensor = extract_first_tensor(out) - if tensor is not None: - tensor = tensor.detach().float().cpu() - if name not in diffs: - diffs[name] = {} - diffs[name][tag] = tensor - return hook_fn - - # Register hooks independently - hooks_fp32 = [] - for name, module in model_fp32.named_modules(): - hooks_fp32.append(module.register_forward_hook(make_hook(name, "fp32"))) - - hooks_bf16 = [] - for name, module in model_bf16.named_modules(): - hooks_bf16.append(module.register_forward_hook(make_hook(name, "bf16"))) - - def maybe_to(x, dtype): - if x is None: - return None - if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.to(dtype) - return x - - with torch.no_grad(): - # BF16 forward - with torch.autocast('cuda', dtype=torch.bfloat16): - _ = model_bf16( - code=maybe_to(inputs["code"], torch.bfloat16), - audio_mask=inputs["audio_mask"], - attention_mask=inputs["attention_mask"], - position_ids=inputs["position_ids"], - context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.bfloat16), - subword_ids=inputs["subword_ids"], - subword_mask=inputs["subword_mask"], - non_prompt_mask=inputs["non_prompt_mask"] - ) - - # FP32 forward - _ = model_fp32( - code=maybe_to(inputs["code"], torch.float32), - audio_mask=inputs["audio_mask"], - attention_mask=inputs["attention_mask"], - position_ids=inputs["position_ids"], - context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), - subword_ids=inputs["subword_ids"], - subword_mask=inputs["subword_mask"], - non_prompt_mask=inputs["non_prompt_mask"] - ) - - # Compute diffs for matching layers - diff_list = [] - for name, val in diffs.items(): - if "fp32" in val and "bf16" in val: - delta = (val["fp32"] - val["bf16"]).abs().mean().item() - diff_list.append((name, delta)) - - diff_list.sort(key=lambda x: x[1], reverse=True) - - print(f"\nTop {topk} layers with largest FP32 vs BF16 diff:") - if not diff_list: - print("⚠️ No matching tensor outputs found. Try increasing atol or check nested outputs.") - else: - for name, delta in diff_list[:topk]: - print(f"{name:<60} mean abs diff = {delta:.6f}") - - for h in hooks_fp32 + hooks_bf16: - h.remove() - - return diff_list - -def compare_tts_model_fp32_bf16_mixed(tts_model, inputs, topk=15): - """ - Compare FP32 vs BF16-safe (with fp32_precision layers) outputs. - tts_model can have patched FP32 layers; these will run in FP32. 
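The comparison strategy itself is generic: register forward hooks on an FP32 copy and a BF16 copy of the same network, run identical inputs through both, and rank layers by the mean absolute difference of their outputs. A minimal sketch with a stand-in module (the toy nn.Sequential is an assumption, not the TTS model; bf16 on CPU assumes a recent PyTorch build):

    import copy
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 4)).eval()
    model_fp32 = copy.deepcopy(model).float()
    model_bf16 = copy.deepcopy(model).bfloat16()

    diffs = {}

    def make_hook(name, tag):
        def hook(module, inp, out):
            diffs.setdefault(name, {})[tag] = out.detach().float()
        return hook

    for n, m in model_fp32.named_modules():
        m.register_forward_hook(make_hook(n, "fp32"))
    for n, m in model_bf16.named_modules():
        m.register_forward_hook(make_hook(n, "bf16"))

    x = torch.randn(2, 16)
    with torch.no_grad():
        model_fp32(x)
        model_bf16(x.bfloat16())

    for name, v in diffs.items():
        if "fp32" in v and "bf16" in v:
            print(name or "<root>", (v["fp32"] - v["bf16"]).abs().mean().item())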
- """ - import copy - diffs = {} - - def extract_first_tensor(x): - if isinstance(x, (tuple, list)): - for y in x: - if torch.is_tensor(y): - return y - return None - if torch.is_tensor(x): - return x - return None - - def make_hook(name, tag): - def hook_fn(module, inp, out): - tensor = extract_first_tensor(out) - if tensor is not None: - tensor = tensor.detach().float().cpu() - if name not in diffs: - diffs[name] = {} - diffs[name][tag] = tensor - return hook_fn - - # FP32 reference model - model_fp32 = copy.deepcopy(tts_model).eval().to(torch.float32) - - hooks_fp32 = [m.register_forward_hook(make_hook(n, "fp32")) for n, m in model_fp32.named_modules()] - hooks_bf16 = [m.register_forward_hook(make_hook(n, "bf16")) for n, m in tts_model.named_modules()] - - def maybe_to(x, dtype): - if x is None: - return None - if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.to(dtype) - return x - - with torch.no_grad(): - # BF16-safe forward (patched FP32 layers run in FP32) - with torch.autocast("cuda", dtype=torch.bfloat16): - _ = tts_model( - code=maybe_to(inputs["code"], torch.bfloat16), - audio_mask=inputs["audio_mask"], - attention_mask=inputs["attention_mask"], - position_ids=inputs["position_ids"], - context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.bfloat16), - subword_ids=inputs["subword_ids"], - subword_mask=inputs["subword_mask"], - non_prompt_mask=inputs["non_prompt_mask"] - ) - - # FP32 forward - _ = model_fp32( - code=maybe_to(inputs["code"], torch.float32), - audio_mask=inputs["audio_mask"], - attention_mask=inputs["attention_mask"], - position_ids=inputs["position_ids"], - context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), - subword_ids=inputs["subword_ids"], - subword_mask=inputs["subword_mask"], - non_prompt_mask=inputs["non_prompt_mask"] - ) - - # Compute diffs - diff_list = [] - for name, val in diffs.items(): - if "fp32" in val and "bf16" in val: - delta = (val["fp32"] - val["bf16"]).abs().mean().item() - diff_list.append((name, delta)) - - diff_list.sort(key=lambda x: x[1], reverse=True) - print(f"\nTop {topk} layers with largest FP32 vs BF16 diff:") - for name, delta in diff_list[:topk]: - print(f"{name:<60} mean abs diff = {delta:.6f}") - - for h in hooks_fp32 + hooks_bf16: - h.remove() - - return diff_list - def rescale_state_dict( state_dict, target_std=0.02, @@ -1332,8 +458,6 @@ def __init__(self, cfg: dict) -> None: self.embed_tokens = self._load_embed_tokens(self.cfg) # delete llm because we use it only to get the embbeding tokens del self.language_model - if self.cfg.tts_config.get("use_subword_flag_emb", False): - self.subword_flag_emb = SubwordFlagEmbedding(self.cfg.pretrained_lm_name, self.cfg.tts_config.context_hidden_size) # instanciate eartts model and codec self._load_tts_model(self.cfg) @@ -1572,117 +696,6 @@ def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, downsampling_ def prepare_inputs(self, batch: dict): """ - """ - """ - import hashlib - import torch - - def hash_texts(text_list): - hashes = [] - for t in text_list: - norm = t.strip().lower() - h = hashlib.sha256(norm.encode("utf-8")).hexdigest() - hashes.append(h) - return hashes - - # --- Safe batch filtering function --- - def filter_batch_by_indices(batch, keep_indices): - if not keep_indices: - # No common samples: empty all fields - new_batch = {} - for k, v in batch.items(): - if isinstance(v, list): - new_batch[k] = [] - elif hasattr(v, "__getitem__") and not isinstance(v, str): - try: - new_batch[k] = v[0:0] # 
empty tensor - except Exception: - new_batch[k] = v - else: - new_batch[k] = v - return new_batch - - new_batch = {} - for k, v in batch.items(): - try: - if isinstance(v, list): - new_batch[k] = [v[i] for i in keep_indices if i < len(v)] - elif hasattr(v, "__getitem__") and not isinstance(v, str): - slices = [i for i in keep_indices if i < v.shape[0]] - if slices: - new_batch[k] = v[slices] - else: - new_batch[k] = v[0:0] # empty tensor - else: - new_batch[k] = v # keep metadata as-is - except Exception: - new_batch[k] = v # fallback if indexing fails - return new_batch - - # --- Compute sample IDs --- - target_texts = batch["target_texts"] - batch["sample_id"] = target_texts # using text itself as unique ID - print("Sample ids:", batch["sample_id"]) - - if self.training: - # --- Track sample IDs and store full batch --- - if not hasattr(self, "train_sample_ids"): - self.train_sample_ids = set(batch["sample_id"]) - self.train_batches_by_hash = dict() - else: - self.train_sample_ids.update(batch["sample_id"]) - - # Save the full batch per sample - for i, sid in enumerate(batch["sample_id"]): - self.train_batches_by_hash[sid] = { - k: (v[i] if isinstance(v, list) else v[i:i+1]) - for k, v in batch.items() - } - - else: - # --- Validation: keep only common samples --- - if not hasattr(self, "eval_common_ids"): - self.eval_common_ids = set() - - # Only consider validation samples that exist in training - keep_indices = [i for i, sid in enumerate(batch["sample_id"]) - if sid in self.train_batches_by_hash] - - # Safe filtering - batch = filter_batch_by_indices(batch, keep_indices) - - if keep_indices: - print(f"Keeping only {len(keep_indices)} common samples from validation!") - # Update eval_common_ids - self.eval_common_ids.update(batch["sample_id"]) - print( - f"total_common={len(self.eval_common_ids)}, " - f"train_total={len(self.train_sample_ids)}" - ) - - # --- Compare the first common sample --- - first_sid = batch["sample_id"][0] - train_sample = self.train_batches_by_hash[first_sid] - val_sample = {k: (v[0] if isinstance(v, list) else v[0:1]) - for k, v in batch.items()} - - # --- Slice tensors to minimal overlapping shape --- - for k in val_sample.keys(): - t_val = val_sample[k] - t_train = train_sample.get(k, t_val) - if isinstance(t_val, torch.Tensor) and isinstance(t_train, torch.Tensor): - min_shape = tuple(min(s1, s2) for s1, s2 in zip(t_val.shape, t_train.shape)) - if all(s > 0 for s in min_shape): - slices = tuple(slice(0, s) for s in min_shape) - val_sample[k] = t_val[slices] - train_sample[k] = t_train[slices] - - print(f"Comparing first common sample (sid={first_sid})") - compare_dicts(train_sample, val_sample) - exit() - else: - print("No common samples found in this validation batch!") - """ # check if audios has the same batch size assert batch["source_audio"].size(0) == batch["target_audio"].size(0) @@ -1691,7 +704,6 @@ def filter_batch_by_indices(batch, keep_indices): target_audio = batch["target_audio"] target_audio_lens = batch["target_audio_lens"] input_text_tokens = batch["input_text_tokens"] - audio_mask = batch["audio_mask"] desc_mask = batch["desc_mask"] non_prompt_mask = batch["non_prompt_mask"] aligned_attention_mask = batch["aligned_attention_mask"] @@ -1735,7 +747,6 @@ def pad_or_truncate(x, pad_value=0): return x # leave others for now input_text_tokens = pad_or_truncate(input_text_tokens, pad_value=self.text_pad_id) - audio_mask = pad_or_truncate(audio_mask, pad_value=0) desc_mask = pad_or_truncate(desc_mask, pad_value=0) non_prompt_mask = 
pad_or_truncate(non_prompt_mask, pad_value=0) aligned_position_ids = pad_or_truncate(aligned_position_ids, pad_value=0) @@ -1750,27 +761,19 @@ def pad_or_truncate(x, pad_value=0): elif L1 > new_len or L2 > new_len: aligned_attention_mask = aligned_attention_mask[:, :, :new_len, :new_len] - if self.cfg.get("disable_speech_pad", False): - target_codes_aligned = target_codes - else: - # ToDo: desc_mask is one for the end of the sequence, this is what cause the artifact issue in the end, fix it. - # set the pad token when there is desc as in https://gitlab-master.nvidia.com/jaehyeonk/easy-ar-tts/-/blame/simple-bq/scripts/train_tts_with_rvqvae.py#L69 - target_codes_aligned = torch.where( - desc_mask.unsqueeze(-1), # (B, T, 1) for broadcasting - torch.full_like(target_codes, self.speech_pad_id), # fill with pad id - target_codes - ) - - if self.cfg.get("ignore_audio_prompt_on_loss", False): - # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation - audio_mask = non_prompt_mask + # ToDo: desc_mask is one for the end of the sequence, this is what cause the artifact issue in the end, fix it. + # set the pad token when there is desc + target_codes_aligned = torch.where( + desc_mask.unsqueeze(-1), # (B, T, 1) for broadcasting + torch.full_like(target_codes, self.speech_pad_id), # fill with pad id + target_codes + ) - if self.cfg.get("add_pad_speech_token_in_last_prompt_frame", False) and not self.cfg.get("disable_speech_pad", False): - # set special token in the last audio prompt (it will works as a BOS token) - pos = non_prompt_mask.float().argmax(dim=1) # shape: [B] - row_idx = torch.arange(B, device=self.device) - # set the extra self.speech_pad_id at first 1 position in non_prompt_mask - target_codes_aligned[row_idx, pos] = self.speech_pad_id + # set special token in the last audio prompt (it will works as a BOS token) + pos = non_prompt_mask.float().argmax(dim=1) # shape: [B] + row_idx = torch.arange(B, device=self.device) + # set the extra self.speech_pad_id at first 1 position in non_prompt_mask + target_codes_aligned[row_idx, pos] = self.speech_pad_id B, T = input_text_tokens.shape @@ -1779,12 +782,9 @@ def pad_or_truncate(x, pad_value=0): # note that we are using a text mask where we are ignoring the desc + audio prompt but we are keeping 1 until the audio ends to support duplex subword_mask = F.pad(non_prompt_mask[:, 1:], [0, 1]) - # ToDo: implement context from the llm # detach embedding as in eartts if self.cfg.tts_config.context_hidden_size is not None: context_hidden_state = self.embed_tokens(input_text_tokens).detach() - if self.cfg.tts_config.get("use_subword_flag_emb", False): - context_hidden_state = self.subword_flag_emb(context_hidden_state, input_text_tokens) else: context_hidden_state = None @@ -1794,14 +794,13 @@ def pad_or_truncate(x, pad_value=0): input_text_tokens = input_text_tokens[:, :-remainder] target_codes_aligned = target_codes_aligned[:, :-remainder] target_codes_aligned = target_codes_aligned[:, :-remainder] - audio_mask = audio_mask[:, :-remainder] desc_mask = desc_mask[:, :-remainder] subword_ids = subword_ids[:, :-remainder] subword_mask = subword_mask[:, :-remainder] return { "code": target_codes_aligned, - "audio_mask": audio_mask, + "audio_mask": non_prompt_mask, # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation "attention_mask": aligned_attention_mask, "position_ids": aligned_position_ids, "subword_ids": subword_ids, @@ -1960,6 +959,7 @@ def _get_generation_config(self, guidance_enabled: bool = 
False): } def offline_inference_with_custom_sentences(self, test_sentences: torch.Tensor, inference_speaker_reference: torch.Tensor, speech_text_ratio: float = 3.5): + # ToDo: split it in multiples batches to support long list of sentences B = len(test_sentences) # load and get speaker reference speaker_audio, sr = torchaudio.load(inference_speaker_reference) @@ -1969,16 +969,10 @@ def offline_inference_with_custom_sentences(self, test_sentences: torch.Tensor, speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) # Tokenize sentences - if self.normalize_text: - tokenized = [ - torch.as_tensor([self.tokenizer.bos] + self.tokenizer.text_to_ids(normalize_text_fn(text)), dtype=torch.long, device=self.device) - for text in test_sentences - ] - else: - tokenized = [ - torch.as_tensor([self.tokenizer.bos] + self.tokenizer.text_to_ids(text), dtype=torch.long, device=self.device) - for text in test_sentences - ] + tokenized = [ + torch.as_tensor([self.tokenizer.bos] + self.tokenizer.text_to_ids(text), dtype=torch.long, device=self.device) + for text in test_sentences + ] # Get max length and target length max_len = max(len(t) for t in tokenized) @@ -2009,13 +1003,11 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals # exit() # first evaluation, make the model bf16 safe if not self.model_16_precision_safe and self.cfg.get("ensures_16_safe", False) and str(self.trainer_config.precision) != str(32): + # ToDo: move it to a method self.tts_model, summary = make_tts_model_mixed_precision_definite(self.tts_model, inputs, safety_factor=1.0, mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16) # self.tts_model, summary = make_tts_model_mixed_precision_safe(self.tts_model, inputs, safety_factor=1.0) self.model_16_precision_safe = True - print("Current FP32 layers:", summary["fp32_layers"]) - # compare_tts_model_fp32_bf16_mixed(self.tts_model, inputs) - # exit() results["audio_tf"], results["audio_tf_len"] = self.get_teacher_force_inference_audio(dataset_batch) if use_dataloader_init: @@ -2036,32 +1028,6 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) ]) - # drop items without description to avoid issues - """ - lens = dataset_batch["desc_plus_audio_prompt_lens"] # list of lengths - - # Example condition: keep only those with the maximum length - max_len = max(lens) - keep_indices = [i for i, l in enumerate(lens) if l == max_len] - - # Convert indices to tensor for indexing torch tensors - keep_indices = torch.tensor(keep_indices, dtype=torch.long) - - # Now filter every key in dataset_batch - for k, v in dataset_batch.items(): - if isinstance(v, torch.Tensor): - dataset_batch[k] = v[keep_indices] - elif isinstance(v, list): - dataset_batch[k] = [v[i] for i in keep_indices] - - # Do the same for inputs - for k, v in inputs.items(): - if isinstance(v, torch.Tensor): - inputs[k] = v[keep_indices] - elif isinstance(v, list): - inputs[k] = [v[i] for i in keep_indices] - """ - # remove the prompt from the input_text_tokens to emulate S2S connected inference next_subword_ids = torch.stack([ inputs["subword_ids"][i, l:] # slice each element @@ -2411,8 +1377,6 @@ def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, # get context hidden if self.cfg.tts_config.context_hidden_size is not None: context_hidden_state = self.embed_tokens(input_text_tokens) - if 
self.cfg.tts_config.get("use_subword_flag_emb", False): - context_hidden_state = self.subword_flag_emb(context_hidden_state, input_text_tokens) else: context_hidden_state = None @@ -2422,9 +1386,6 @@ def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, non_prompt_mask[:, -2:] = 1 # set last valid prompt frame as 1 to allow the addition of BOS in the right place subword_mask = torch.zeros_like(input_text_tokens) # subword_mask is almost all zeros because on the warmup there is only the prompt subword_mask[:, -3:] = 1 # -3 because of the it start right after the first valid prompt token and it is shifted by 1 - # audio mask is all ones except for description - audio_mask = torch.ones_like(input_text_tokens) - audio_mask[:, :desc_tokens_ids.size(-1)] = 0 # desc mask is all zeros except the description desc_mask = torch.zeros_like(input_text_tokens) desc_mask[:, :desc_tokens_ids.size(-1)] = 1 @@ -2439,23 +1400,17 @@ def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, ) # shift subword_ids - # subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1], value=current_subword_id) subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1], value=0.0) - if self.cfg.get("ignore_audio_prompt_on_loss", False): - # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation - audio_mask = non_prompt_mask - - if self.cfg.get("add_pad_speech_token_in_last_prompt_frame", False) and not self.cfg.get("disable_speech_pad", False): - # set special token in the last audio prompt (it will works as a BOS token) - pos = non_prompt_mask.float().argmax(dim=1) # shape: [B] - row_idx = torch.arange(B, device=self.device) - # set the extra self.speech_pad_id at first 1 position in non_prompt_mask - code[row_idx, pos] = self.speech_pad_id + # set special token in the last audio prompt (it will works as a BOS token) + pos = non_prompt_mask.float().argmax(dim=1) # shape: [B] + row_idx = torch.arange(B, device=self.device) + # set the extra self.speech_pad_id at first 1 position in non_prompt_mask + code[row_idx, pos] = self.speech_pad_id init_inputs = { "code": code[:, :-1], - "audio_mask": audio_mask.bool()[:, :-1], + "audio_mask": non_prompt_mask.bool()[:, :-1], # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation "context_hidden_state": context_hidden_state[:, :-1] if context_hidden_state is not None else None, "subword_ids": subword_ids[:, :-1], "subword_mask": subword_mask.bool()[:, :-1], @@ -2551,10 +1506,6 @@ def offline_inference( # get first context subword_id, that is the last subword_ids from the warmup first_context_subword_id = init_inputs["subword_ids"][:, -1].unsqueeze(-1) - # reset cache of cumulative_word_emb - if self.cfg.tts_config.get("use_cumulative_word_emb", False): - self.tts_model.embed_subword.cumulative_word_emb.reset(B) - for i in range(max_steps-1): step_start = time.time() # current subword id is always seem @@ -2569,8 +1520,6 @@ def offline_inference( context_subword_id = next_subword_ids[:, i-1].unsqueeze(-1) context_hidden_state = self.embed_tokens(context_subword_id) - if self.cfg.tts_config.get("use_subword_flag_emb", False): - context_hidden_state = self.subword_flag_emb(context_hidden_state, context_subword_id) else: context_hidden_state = None @@ -2608,7 +1557,7 @@ def offline_inference( cur_asr_speech_tokens = logits.argmax(dim=-1)[:, -1].unsqueeze(-1) # force silence as next token - if self.cfg.get('inference_force_speech_silence_on_eos', None): + if 
self.cfg.get('inference_force_speech_silence_on_eos', True): silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(code.shape) code = torch.where( current_subword_id.unsqueeze(-1) == self.text_eos_id, diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py new file mode 100644 index 000000000000..a9d67ec5ba8e --- /dev/null +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -0,0 +1,1085 @@ +# Standard library +import functools +import math +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Concatenate, Literal, overload + +# Third-party +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from torchaudio import functional as ta_F + +# Project +from nemo.collections.speechlm2.modules.ear_tts_commons import ( + Config, + PreTrainedModel +) + +from contextlib import contextmanager + +@contextmanager +def disable_tf32(): + prev = torch.backends.cudnn.allow_tf32 + torch.backends.cudnn.allow_tf32 = False + try: + yield + finally: + torch.backends.cudnn.allow_tf32 = prev + +@dataclass +class RVQVAEConfig(Config): + model_type: str = "rvqvae" + + # model specific configs + latent_size: int = 512 + wav_to_token_ratio: int = field(init=False) + + n_fft: int = 32 + hop_length: int = 8 + base_hidden_size: int = 384 + channel_mult: tuple[int, ...] = (1, 2, 4) + rates: tuple[int, ...] = (8, 8, 8) + num_blocks: int = 3 + kernel_size: int = 7 + groups: int = 1 + + # quantization + codebook_size: int = 1024 + num_quantizers: int = 72 + quantizer_dropout: float = 0.5 + + def __post_init__(self): + self.wav_to_token_ratio = self.hop_length * math.prod(self.rates) + + +# ============================================================================== +# Utility Functions +# ============================================================================== + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zeros out the parameters of a PyTorch module in-place. + + This is a utility function that iterates through all parameters of a given + `nn.Module` and sets their values to zero. This is often used for specific + initialization strategies, for example in diffusion models where some layers + are initialized to zero. + + From: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L68 + + Args: + module (nn.Module): The PyTorch module to be zeroed. + + Returns: + nn.Module: The same module with its parameters zeroed. + """ + for p in module.parameters(): + # p.detach().zero_() performs the operation in-place without tracking it in autograd + p.detach().zero_() + return module + + +def sequence_mask(lengths: Tensor, max_length: int | None = None) -> Tensor: + """ + Creates a boolean mask from a 1D tensor of sequence lengths. + + This function is useful for masking out padding in sequences. Given a tensor + of lengths, it produces a 2D boolean tensor where `mask[i, j]` is `True` if + `j < lengths[i]` and `False` otherwise. + + Example: + >>> lengths = torch.tensor([1, 3, 2]) + >>> sequence_mask(lengths) + tensor([[ True, False, False], + [ True, True, True], + [ True, True, False]]) + + Args: + lengths (Tensor): A 1D tensor of integer lengths. Shape: `[batch_size]`. + max_length (int | None, optional): The maximum length of the mask. If None, + it is inferred from the maximum value + in `lengths`. Defaults to None. + + Returns: + Tensor: The boolean mask. 
Shape: `[batch_size, max_length]`. + """ + if max_length is None: + # If max_length is not provided, use the longest sequence length in the batch. + max_length = int(lengths.max().item()) + + # Create a range tensor from 0 to max_length - 1 + x = torch.arange(max_length, dtype=lengths.dtype, device=lengths.device) + + # Compare each length with the range tensor to create the mask. + # `x.unsqueeze(0)` is `[1, max_length]` + # `lengths.unsqueeze(1)` is `[batch_size, 1]` + # Broadcasting takes care of the comparison. + return x.unsqueeze(0) < lengths.unsqueeze(1) + + +# ============================================================================== +# Signal Processing Functions +# ============================================================================== + + +def spectrogram( + wav: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + window_fn: Callable[Concatenate[int, ...], Tensor] = torch.hann_window, +) -> Tensor: + """ + Computes the Short-Time Fourier Transform (STFT) of a waveform with manual padding. + + This implementation manually applies zero padding before computing the STFT. + This is done to center the analysis window at the beginning of the signal + without using the `center=True` argument in `torch.stft`, giving more control. + + Args: + wav (Tensor): The input audio waveform. + Shape: [batch_size?, time_steps], where batch_size? is an + optional batch dimension. + n_fft (int): The size of the FFT. + hop_length (int): The number of samples between adjacent STFT columns. + win_length (int): The size of the window function. + window_fn (function, optional): The window function to apply. + Defaults to torch.hann_window. + + Returns: + Tensor: The complex-valued spectrogram. + Shape: [batch_size?, n_fft // 2 + 1, num_frames] + """ + # Calculate the padding required on the left and right sides to center the frames. + pad_size_l = (n_fft - hop_length) // 2 + pad_size_r = (n_fft - hop_length) - pad_size_l + + # Use a torch.autocast context to perform STFT in float32 for precision. + with torch.autocast(device_type=wav.device.type, enabled=False): + # Apply reflection padding to the waveform. + wav = F.pad(wav.float(), (pad_size_l, pad_size_r)) + + # Create the window tensor on the same device as the waveform. + window = window_fn(win_length, dtype=torch.float, device=wav.device) + + # Compute the STFT. + # `center=False` because we have already manually padded the signal. + spec = torch.stft( + wav, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=False, + normalized=False, + onesided=True, + return_complex=True, + ) + return spec + + +def spec_to_wav( + spec: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + window_fn: Callable[Concatenate[int, ...], Tensor] = torch.hann_window, + constrain_value_range: bool = False, +) -> Tensor: + """ + Converts a spectrogram back into a waveform using the overlap-add method. + This function is an approximate inverse of the `spectrogram` function. + + Args: + spec (Tensor): The input complex-valued spectrogram. + Shape: [batch_size?, dim, time_steps], where batch_size? + is an optional batch dimension. + n_fft (int): The size of the FFT used to create the spectrogram. + hop_length (int): The number of samples between frames in the original signal. + win_length (int): The size of the window function used in the original signal. + window_fn (function, optional): The window function used. Currently only + `torch.hann_window` is supported. 
+ constrain_value_range (bool, optional): If True, constrains the IFFT values + to be within the range of the window. + This ensures that the output values + remain within the range of -1.0 to 1.0. + Defaults to False. + + Returns: + Tensor: The reconstructed waveform. + Shape: [batch_size?, time_steps] + + Raises: + ValueError: If a window function other than `torch.hann_window` is provided. + """ + with torch.autocast(device_type=spec.device.type, enabled=False): + if window_fn != torch.hann_window: + raise ValueError(f"`window_fn` should be 'torch.hann_window', but got '{window_fn}'.") + + # Calculate padding and number of frames + pad = (win_length - hop_length) // 2 + T = spec.size(-1) + window = window_fn(win_length, device=spec.device) + + # 1. Inverse FFT + # Convert from frequency domain back to time domain for each frame. + ifft = torch.fft.irfft(spec, n=n_fft, dim=-2, norm="backward") + window_unsqz = window.unsqueeze(-1) + + # 2. Optionally constrain values + if constrain_value_range: + ifft = torch.where( + ifft >= 0, + torch.minimum(ifft, window_unsqz), + torch.maximum(ifft, -window_unsqz), + ) + + # 3. Apply window to the IFFT result + ifft = ifft * window_unsqz + + # 4. Overlap and Add + # Use `torch.nn.functional.fold` to perform the overlap-add operation efficiently. + # This reconstructs the continuous signal from the windowed frames. + output_size = (T - 1) * hop_length + win_length + wav = F.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, win_length), + stride=(1, hop_length), + )[..., 0, 0, pad:-pad] + + # 5. Calculate the window envelope for normalization + # This is necessary to correct for the energy added by overlapping windows. + window_sq = window.square().expand(T, -1).transpose(0, 1) + window_envelope = F.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, win_length), + stride=(1, hop_length), + ).squeeze()[pad:-pad] + + # 6. Normalize the waveform + # Divide by the window envelope to get the final reconstructed signal. + assert (window_envelope > 1e-11).all(), "Window envelope has zero values, cannot normalize." + wav = wav / window_envelope + + return wav + + +def spectrogram_mag( + wav: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + window_fn: Callable[Concatenate[int, ...], Tensor] = torch.hann_window, + power: float = 1.0, +) -> Tensor: + """ + Computes the magnitude spectrogram from an audio waveform. + + This function first calculates the complex-valued spectrogram using the + Short-Time Fourier Transform (STFT), then computes the magnitude of the + resulting complex numbers. An optional power can be applied to the + magnitude spectrogram. + + Args: + wav (Tensor): The input audio waveform. + Shape: [batch_size?, time_steps], where batch_size? is + an optional batch dimension. + n_fft (int): The size of the Fast Fourier Transform (FFT) to use. + hop_length (int): The number of audio samples between adjacent STFT columns. + win_length (int): The size of the window function for each frame. + window_fn (function, optional): The windowing function to apply to each + frame. Defaults to torch.hann_window. + power (float, optional): The exponent to apply to the magnitude spectrogram. + A value of 2.0 yields a power spectrogram. + Defaults to 1.0 (magnitude). + + Returns: + Tensor: The resulting magnitude spectrogram. 
+ Shape: [batch_size?, n_fft // 2 + 1, num_frames] + """ + # Calculate the complex spectrogram + spec = spectrogram(wav, n_fft, hop_length, win_length, window_fn) + + # Compute the magnitude by taking the absolute value + spec = spec.abs() + + # Apply power if it's not the default value of 1.0 + if power != 1.0: + spec = spec.pow(power) + + return spec + + +@functools.cache +def get_fbanks( + sample_rate: int, + n_fft: int, + n_mels: int, + f_min: float, + f_max: float, + norm: str = "slaney", + mel_scale: str = "slaney", +) -> Tensor: + """ + Creates and caches Mel filterbanks. + + This function generates a set of triangular filters on the Mel scale. + The `@functools.cache` decorator memoizes the result, so the filterbanks + are only computed once for a given set of parameters, improving efficiency + when the function is called multiple times with the same arguments. + + Note: This implementation only supports Mel filterbanks via torchaudio. + + Args: + sample_rate (int): The sample rate of the audio. + n_fft (int): The size of the FFT used to compute the spectrogram. + n_mels (int): The number of Mel bands to generate. + f_min (float): The lowest frequency (in Hz) for the filterbanks. + f_max (float): The highest frequency (in Hz) for the filterbanks. + norm (str, optional): The normalization method to use for the triangles. + 'slaney' normalizes to unit area. None applies no norm. + Defaults to "slaney". + mel_scale (str, optional): The Mel scale to use, "htk" or "slaney". + Defaults to "slaney". + + Returns: + Tensor: The Mel filterbank matrix. + Shape: [n_mels, n_fft // 2 + 1] + """ + # Generate Mel filterbanks using torchaudio's functional API + fb = ta_F.melscale_fbanks( + n_freqs=n_fft // 2 + 1, + f_min=f_min, + f_max=f_max, + n_mels=n_mels, + sample_rate=sample_rate, + norm=norm, + mel_scale=mel_scale, + ).T # Transpose to get the shape [n_mels, n_freqs] + return fb + + +def mel_spectrogram( + wav: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + sample_rate: int, + n_mels: int, + f_min: float, + f_max: float | None = None, + window_fn: Callable[Concatenate[int, ...], Tensor] = torch.hann_window, + power: float = 1.0, + log_scale: str | None = "natural", +) -> Tensor: + """ + Computes a Mel-scaled spectrogram from an audio waveform. + + This function transforms a standard spectrogram into a Mel spectrogram by + applying Mel-scaled filterbanks. It can optionally return the result on a + logarithmic scale. + + Args: + wav (Tensor): The input audio waveform. + Shape: [batch_size?, time_steps], where batch_size? is an + optional batch dimension. + n_fft (int): The size of the FFT. + hop_length (int): The number of samples between adjacent frames. + win_length (int): The size of the window function. + sample_rate (int): The sample rate of the audio. + n_mels (int): The number of Mel bands to generate. + f_min (float): The lowest frequency (in Hz) for the Mel scale. + f_max (float | None, optional): The highest frequency (in Hz). If None, + it defaults to sample_rate / 2 (Nyquist). + window_fn (function, optional): The windowing function. Defaults to torch.hann_window. + power (float, optional): The exponent for the magnitude spectrogram before + Mel conversion. Defaults to 1.0. + log_scale (str | None, optional): The type of logarithmic scaling to apply. + Can be "natural" (for `log`), "log10", or `None` + to return the linear-amplitude Mel spectrogram. + Defaults to "natural". + + Returns: + Tensor: The resulting Mel spectrogram. 
+ Shape: [batch_size?, n_mels, num_frames] + + Raises: + ValueError: If an unsupported string is provided for `log_scale`. + """ + # If f_max is not provided, use the Nyquist frequency. + f_max = f_max or sample_rate / 2 + + # 1. Compute the magnitude spectrogram. + spec = spectrogram_mag(wav, n_fft, hop_length, win_length, window_fn=window_fn, power=power) + + # Use a torch.autocast context to ensure the following operations + # are performed in float32 precision for numerical stability, especially + # when the input `spec` might be in a lower precision format like float16. + with torch.autocast(device_type=spec.device.type, enabled=False): + # 2. Get the Mel filterbanks (cached for efficiency). + fb = ( + get_fbanks( + sample_rate, + n_fft, + n_mels, + f_min, + f_max, + ) + .float() + .to(device=spec.device) + ) # Ensure filterbank is float32 and on the correct device. + + # 3. Apply the filterbanks to the spectrogram via matrix multiplication. + # This maps the linear frequency scale to the Mel scale. + # (n_mels, n_freqs) @ (..., n_freqs, time) -> (..., n_mels, time) + mel = torch.matmul(fb, spec.float()) + + # 4. Optionally, apply a logarithmic function. + # A small value (epsilon) is added to prevent taking the log of zero. + if log_scale == "natural": + mel = torch.log(torch.clamp(mel, min=1e-6)) + elif log_scale == "log10": + mel = torch.log10(torch.clamp(mel, min=1e-6)) + elif log_scale is not None: + raise ValueError(f"Unsupported log_scale: '{log_scale}'. Choose from 'natural', 'log10', or None.") + + return mel + + +# ============================================================================== +# Basic Modules +# ============================================================================== + + +class CausalConv1dCache: + """ + A cache for managing states in causal 1D convolutions. + + This class is used during autoregressive inference to store and update the + tail of the input to a causal convolution, which is used as padding for the + next time step. This avoids re-computing the entire sequence at each step. + """ + + def __init__(self) -> None: + self.cache: dict[int | str, Tensor] = {} + + def __getitem__(self, layer_id: int | str) -> Tensor: + """Retrieves the cached tensor for a given layer.""" + return self.cache[layer_id] + + def update( + self, + states: Tensor, + layer_id: int | str, + padding: int, + padding_value: int = 0, + flush: bool = False, + ) -> Tensor: + """ + Updates the cache for a specific layer and returns the padded input. + + Args: + states (Tensor): The new input tensor for the current time step. + layer_id (int | str): An identifier for the convolutional layer. + padding (int): The amount of left padding required by the convolution. + padding_value (int, optional): The value to use for initial padding. Defaults to 0. + flush (bool, optional): If True, the cache for this layer is deleted + after use. Defaults to False. + + Returns: + Tensor: The input states concatenated with the cached padding. 
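+
+        Example:
+            Illustrative usage only; the layer id, channel count and chunk length are
+            made up for the sketch, and `padding` is what a causal conv with
+            kernel_size=7 would request (kernel_size - 1 = 6):
+
+            >>> cache = CausalConv1dCache()
+            >>> chunk = torch.randn(1, 8, 4)  # [batch, channels, time]
+            >>> cache.update(chunk, layer_id=0, padding=6).shape  # 6 zero frames + 4 new frames
+            torch.Size([1, 8, 10])
+            >>> # the next chunk is left-padded with the cached tail of the previous one
+            >>> cache.update(torch.randn(1, 8, 4), layer_id=0, padding=6).shape
+            torch.Size([1, 8, 10])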
+ """ + device = states.device + dtype = states.dtype + b, c, t = states.size() + + if layer_id not in self.cache: + # Initialize cache with zero padding if it's the first time step + padding_tensor = torch.zeros((b, c, padding), dtype=dtype, device=device) + padding_value + else: + padding_tensor = self.cache[layer_id] + assert padding_tensor.size(2) == padding + + # Concatenate the cached padding with the new states + padded_states = torch.cat([padding_tensor, states], dim=2) + # Update the cache with the tail of the new padded states + self.cache[layer_id] = padded_states[:, :, -padding:] + + if flush: + del self.cache[layer_id] + + return padded_states + + +class LayerNormNd(nn.Module): + """ + A LayerNorm module that works for N-dimensional inputs. + + This implementation normalizes over the channel dimension (dim=1), which is + a common setup for convolutional networks. + + Args: + channels (int): The number of channels of the input tensor. + eps (float, optional): A value added to the denominator for numerical + stability. Defaults to 1e-6. + elementwise_affine (bool, optional): If True, this module has learnable + affine parameters (weight and bias). + Defaults to True. + bias (bool, optional): If True, this module has a learnable bias. + Defaults to True. + """ + + def __init__(self, channels: int, eps=1e-6, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + self.channels = channels + self.eps = eps + + self.weight = nn.Parameter(torch.ones((channels,)), requires_grad=elementwise_affine) + self.bias = nn.Parameter(torch.zeros((channels,)), requires_grad=elementwise_affine and bias) + + def forward(self, x: Tensor) -> Tensor: + # Calculate mean and reciprocal standard deviation over the channel dimension + mean = x.mean(1, keepdim=True) + x_shift = x - mean + # Using rsqrt for potentially better performance + x_rstd = torch.rsqrt(x_shift.pow(2).mean(1, keepdim=True) + self.eps) + + # Reshape weight and bias to be broadcastable with the input tensor + shape = [-1 if i == 1 else 1 for i in range(x.ndim)] + + # Apply normalization and affine transformation + return (x_shift * x_rstd) * self.weight.view(shape) + self.bias.view(shape) + + +class ConvNeXt1d(nn.Module): + """ + A 1D ConvNeXt block adapted for causal convolutions on audio signals. + + This block is a core component of modern convolutional architectures, featuring + a depthwise convolution, layer normalization, and pointwise convolutions to + expand and contract the channel dimension, similar to an inverted bottleneck. + + Implementation adapted from: https://github.com/charactr-platform/vocos + + Args: + dim (int): Number of input and output channels. + intermediate_dim (int): Dimensionality of the intermediate (expanded) layer. + kernel_size (int): The kernel size for the causal depthwise convolution. + identity_init (bool, optional): If True, the final pointwise convolution + is initialized to zero, making the block + an identity function at the start of training. + Defaults to False. + layer_idx (int, optional): An index for this layer, used for caching. Defaults to 0. 
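+
+    Example:
+        A minimal sketch (hyperparameters chosen only for illustration) of the property
+        the causal left padding is meant to guarantee: chunked inference with a
+        `CausalConv1dCache` matches processing the whole sequence at once.
+
+        >>> block = ConvNeXt1d(dim=8, intermediate_dim=32, kernel_size=7)
+        >>> x = torch.randn(1, 8, 16)
+        >>> full = block(x)
+        >>> cache = CausalConv1dCache()
+        >>> chunks = [block(c, cache=cache) for c in x.split(4, dim=-1)]
+        >>> torch.allclose(full, torch.cat(chunks, dim=-1), atol=1e-5)
+        True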
+ """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + kernel_size: int, + identity_init: bool = False, + layer_idx: int = 0, + ): + super().__init__() + self.layer_idx = layer_idx + self.kernel_size = kernel_size + + # Depthwise convolution + self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, groups=dim) + + self.norm = LayerNormNd(dim) + self.pwconv1 = nn.Conv1d(dim, intermediate_dim, 1) # Pointwise/1x1 conv for expansion + self.act = nn.GELU() + + # Pointwise/1x1 conv for projection + if identity_init: + self.pwconv2 = zero_module(nn.Conv1d(intermediate_dim, dim, 1)) + else: + self.pwconv2 = nn.Conv1d(intermediate_dim, dim, 1) + + def forward(self, x: Tensor, cache: CausalConv1dCache | None = None, flush: bool = False) -> Tensor: + residual = x + + # Apply causal padding, either through a cache or manually + if cache is not None: + x = cache.update(x, self.layer_idx, self.kernel_size - 1, flush=flush) + else: + x = F.pad(x, [self.kernel_size - 1, 0]) # Left padding for causality + + # Main ConvNeXt path + x = self.dwconv(x) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + # Add residual connection + x = residual + x + return x + + +class PreTrainedEMAVariance(nn.Module): + """ + Exponential Moving Average of Variance + """ + + def __init__(self, initial_value: float = 1.0): + super().__init__() + self.variance = nn.Parameter( + torch.tensor(initial_value), + requires_grad=False, + ) + + def forward(self) -> Tensor: + return self.variance + + +class PreTrainedProbabilisticVQ(nn.Module): + def __init__( + self, + channels: int, + num_mixtures: int, + depth: int = 1, + ): + super().__init__() + self.channels = channels + self.num_mixtures = num_mixtures + self.depth = depth + + self.mus_list = nn.ParameterList( + [ + nn.Parameter( + F.normalize(torch.randn(num_mixtures, channels), p=2.0, dim=1) * ((depth - i) / depth), + requires_grad=False, + ) + for i in range(depth) + ] + ) + self._variance_list = nn.ModuleList([PreTrainedEMAVariance() for _ in range(depth)]) + + @property + def log_std(self) -> Tensor: + return torch.log(self._variance_list[-1]()) * 0.5 + + @overload + def encode(self, z: Tensor, return_z_q: Literal[False]) -> list[Tensor]: ... + + @overload + def encode(self, z: Tensor, return_z_q: Literal[True]) -> tuple[list[Tensor], Tensor]: ... + + def encode(self, z: Tensor, return_z_q: bool = False) -> list[Tensor] | tuple[list[Tensor], Tensor]: + r = z + ids_sel = [] + for i in range(self.depth): + mus = self.mus_list[i] + idx_sel = self._dist_sq(r, mus).argmin(-1) # [b, ?, h], [v, h] -> [b, ?] + r = r - F.embedding(idx_sel, mus) + ids_sel.append(idx_sel) + if return_z_q: + return ids_sel, z - r + return ids_sel + + def decode(self, ids_sel: list[Tensor]) -> Tensor: + z = torch.zeros((*ids_sel[0].size(), self.channels), device=ids_sel[0].device) + for i in range(len(ids_sel)): + mus = self.mus_list[i] + z = z + F.embedding(ids_sel[i], mus) + return z # [b, ?, h] + + def _dist_sq(self, z: Tensor, mus: Tensor) -> Tensor: + """ + z: [b, ?, d?, h] + mus: [d?, v, h] + """ + return ( + z.pow(2).sum(-1, keepdim=True) # [b, ?, d?, 1] + + mus.pow(2).sum(-1) # [d?, v] + - 2 * (z.unsqueeze(-2) @ mus.transpose(-1, -2)).squeeze(-2) # [b, ?, d?, h] , [d?, h, v] -> [b, ?, d?, v] + ) + + +class Wav2Latent(nn.Module): + """ + An encoder model that transforms a raw waveform into a latent representation. 
+ + This model first converts the waveform to a spectrogram, then processes it + through a series of ConvNeXt blocks and downsampling convolutional layers + to produce a compressed latent tensor. + + Args: + latent_size (int): The number of channels in the final latent representation. + n_fft (int): The FFT size for the initial spectrogram transformation. + hop_length (int): The hop length for the STFT. + base_hidden_size (int): The base number of channels for the hidden layers. + channel_mult (tuple[int, ...]): A tuple of multipliers for the hidden + size at each stage of downsampling. + rates (tuple[int, ...]): A tuple of downsampling factors (strides) for + the convolutional layers. + num_blocks (int): The number of ConvNeXt blocks per stage. + kernel_size (int): The kernel size for the ConvNeXt blocks. + groups (int): The number of groups for the downsampling convolutions. + """ + + def __init__( + self, + latent_size: int = 1024, + n_fft: int = 32, + hop_length: int = 8, + base_hidden_size: int = 384, + channel_mult: tuple[int, ...] = (1, 2, 4), + rates: tuple[int, ...] = (8, 8, 8), + num_blocks: int = 3, + kernel_size: int = 7, + groups: int = 1, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + + # Initial projection from spectrogram to hidden size + layers: list[nn.Module] = [nn.Conv1d(n_fft + 2, base_hidden_size * channel_mult[0], 1, bias=False)] + + # Downsampling stages + for i in range(len(channel_mult)): + ch_mult, rate = channel_mult[i], rates[i] + hidden_size = base_hidden_size * ch_mult + # Add ConvNeXt blocks for this stage + for j in range(num_blocks): + layers.append( + ConvNeXt1d(hidden_size, hidden_size * 4, kernel_size, True, layer_idx=i * num_blocks + j) + ) + # Add downsampling convolution + next_hidden_size = base_hidden_size * channel_mult[i + 1] if i < len(channel_mult) - 1 else latent_size + layers.append( + nn.Conv1d(hidden_size, next_hidden_size, kernel_size=rate, stride=rate, bias=False, groups=groups) + ) + + self.layers = nn.ModuleList(layers) + + def forward(self, x: Tensor, cache=None, flush: bool = False) -> Tensor: + if cache is not None: + raise NotImplementedError("Caching is not implemented for the encoder.") + + # Convert waveform to spectrogram (magnitude and phase) + with torch.autocast(device_type=x.device.type, enabled=False): + spec = spectrogram(x.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.n_fft) + # Split complex spectrogram into real and imaginary, then treat as magnitude and phase + mag, ph = torch.view_as_real(spec).chunk(2, dim=-1) + x = torch.cat([mag, ph], 1).squeeze(-1) + + # Pass through the network + for layer in self.layers: + if isinstance(layer, ConvNeXt1d): + x = layer(x, cache=cache, flush=flush) + else: + x = layer(x) + + # Transpose to [batch, time, channels] for compatibility with transformers + x = x.transpose(-1, -2) + return x + + +class Latent2Wav(nn.Module): + """ + A decoder (vocoder) model that transforms a latent representation back into a raw waveform. + + This model processes a latent tensor through a series of ConvNeXt blocks and + upsampling transposed convolutional layers to produce a spectrogram, which is + then converted back to a waveform using an inverse STFT. + + Args: + latent_size (int): The number of channels in the input latent representation. + n_fft (int): The FFT size for the final spectrogram reconstruction. + hop_length (int): The hop length for the ISTFT. + base_hidden_size (int): The base number of channels for the hidden layers. 
+ channel_mult (tuple[int, ...]): A tuple of multipliers for the hidden + size at each stage of upsampling. + rates (tuple[int, ...]): A tuple of upsampling factors (strides) for + the transposed convolutional layers. + num_blocks (int): The number of ConvNeXt blocks per stage. + kernel_size (int): The kernel size for the ConvNeXt blocks. + groups (int): The number of groups for the upsampling convolutions. + """ + + def __init__( + self, + latent_size: int = 1024, + n_fft: int = 32, + hop_length: int = 8, + base_hidden_size: int = 384, + channel_mult: tuple[int, ...] = (4, 2, 1), + rates: tuple[int, ...] = (8, 8, 8), + num_blocks: int = 3, + kernel_size: int = 7, + groups=1, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.spec_cache_idx = (len(channel_mult)) * num_blocks + + layers: list[nn.Module] = [] + + # Upsampling stages + for i in range(len(channel_mult)): + ch_mult, rate = channel_mult[i], rates[i] + hidden_size = base_hidden_size * ch_mult + # Add upsampling transposed convolution + in_size = base_hidden_size * channel_mult[i - 1] if i != 0 else latent_size + layers.append( + nn.ConvTranspose1d(in_size, hidden_size, kernel_size=rate, stride=rate, bias=False, groups=groups) + ) + # Add ConvNeXt blocks for this stage + for j in range(num_blocks): + layers.append( + ConvNeXt1d(hidden_size, hidden_size * 4, kernel_size, True, layer_idx=i * num_blocks + j) + ) + + # Final projection to spectrogram dimensions (magnitude + phase) + layers.append(nn.Conv1d(hidden_size, n_fft + 2, 1, bias=False)) + self.layers = nn.ModuleList(layers) + + def forward(self, x: Tensor, cache=None, flush: bool = False, constrain_value_range: bool = True) -> Tensor: + # Transpose input from [batch, time, channels] to [batch, channels, time] + x = x.transpose(-1, -2) + + # Pass through the network + for layer in self.layers: + if isinstance(layer, ConvNeXt1d): + x = layer(x, cache=cache, flush=flush) + else: + x = layer(x) + + # Convert network output to a complex spectrogram and then to a waveform + with torch.autocast(device_type=x.device.type, enabled=False): + max_mag = 100.0 + # Split output into magnitude and phase components + mag, ph = x.float().chunk(2, dim=1) + # Safeguard to prevent excessively large magnitudes + mag = max_mag * torch.exp(-F.softplus(-mag + math.log(max_mag))) + + # Reconstruct the complex spectrogram from magnitude and phase + # The DC and Nyquist components are real, so their phase is applied via cosine. 
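+            # `mag` and `ph` each hold n_fft // 2 + 1 frequency bins. For a one-sided
+            # spectrum of a real signal, the first (DC) and last (Nyquist) bins are
+            # real-valued, so only their cosine term is kept, while the interior bins
+            # are recombined into complex values from cos (real) and sin (imaginary) parts.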
+ mag_dc, mag_mid, mag_nyquist = mag.split([1, mag.size(1) - 2, 1], dim=1) + ph_dc, ph_mid, ph_nyquist = torch.cos(ph).split([1, ph.size(1) - 2, 1], dim=1) + ph_imag = torch.sin(ph[:, 1:-1, :]) + + spec_real = mag_mid * ph_mid + spec_imag = mag_mid * ph_imag + + spec = torch.cat([mag_dc * ph_dc, spec_real + 1j * spec_imag, mag_nyquist * ph_nyquist], 1) + + # Handle caching for autoregressive generation of the spectrogram + if cache is not None: + half_spec_padding = math.ceil(((self.n_fft - self.hop_length) // 2) / self.hop_length) + spec = cache.update(spec, self.spec_cache_idx, padding=half_spec_padding * 2, flush=flush) + if flush: + spec = F.pad(spec, [0, half_spec_padding]) + + # Convert spectrogram to waveform + x = spec_to_wav( + spec, self.n_fft, self.hop_length, self.n_fft, constrain_value_range=constrain_value_range + ).unsqueeze(1) + + if cache is not None: + # Trim the output to remove the padded region from the start + half_wav_padding = half_spec_padding * self.hop_length + x = x[:, :, half_wav_padding:-half_wav_padding] + + return x + +class RVQVAEModel(PreTrainedModel): + """ + Residual Vector-Quantized Variational Autoencoder (RVQ-VAE) model. + + This model learns a discrete representation of audio by encoding a waveform + into a latent space and then quantizing the latents into discrete codes. + It consists of an encoder, a quantizer, and a decoder. + + Args: + config (RVQVAEConfig | dict[str, Any]): A configuration object with model hyperparameters. + """ + + config_class: type[Config] = RVQVAEConfig + + def __init__(self, config: RVQVAEConfig | dict[str, Any]): + super().__init__(config) + + self.encoder = Wav2Latent( + latent_size=self.config.latent_size, + n_fft=self.config.n_fft, + hop_length=self.config.hop_length, + base_hidden_size=self.config.base_hidden_size, + channel_mult=self.config.channel_mult, + rates=self.config.rates, + num_blocks=self.config.num_blocks, + kernel_size=self.config.kernel_size, + groups=self.config.groups, + ) + + # Layers for quantization + self.prvq = PreTrainedProbabilisticVQ( + channels=self.config.latent_size, + num_mixtures=self.config.codebook_size, + depth=self.config.num_quantizers, + ) + + self.decoder = Latent2Wav( + latent_size=self.config.latent_size, + n_fft=self.config.n_fft, + hop_length=self.config.hop_length, + base_hidden_size=self.config.base_hidden_size, + channel_mult=tuple(reversed(self.config.channel_mult)), + rates=tuple(reversed(self.config.rates)), + num_blocks=self.config.num_blocks, + kernel_size=self.config.kernel_size, + groups=self.config.groups, + ) + + for p in self.parameters(): + p.requires_grad = False + + def ae_encode(self, x: Tensor, cache: CausalConv1dCache | None = None, flush: bool = False) -> Tensor: + """ + Runs the encoder part of the autoencoder. + + Args: + x (Tensor): Input waveform. Shape: `[batch, 1, time]`. + cache (CausalConv1dCache | None): Not implemented for the encoder. + flush (bool): Not implemented for the encoder. + + Returns: + Tensor: The continuous latent representation. Shape: `[batch, time', channels]`. + """ + assert x.size(1) == 1 and x.dim() == 3, "Input must be a batch of mono audio." + assert x.size(2) % self.config.wav_to_token_ratio == 0, ( + f"Input audio length ({x.size(2)}) must be divisible by the model's " + f"wav_to_token_ratio ({self.config.wav_to_token_ratio}). " + f"Please pad the input to a compatible length." 
+ ) + + if cache is not None: + raise NotImplementedError("Caching is not supported for the encoder.") + + return self.encoder(x, cache=cache, flush=flush) + + def ae_decode( + self, + x: Tensor, + constrain_value_range: bool = True, + cache: CausalConv1dCache | None = None, + flush: bool = False, + ) -> Tensor: + """ + Runs the decoder part of the autoencoder. + + Args: + x (Tensor): The (de-quantized) latent representation. Shape: `[batch, time', channels]`. + constrain_value_range (bool): If True, constrains the output of the ISTFT. + cache (CausalConv1dCache | None): Cache for autoregressive generation. + flush (bool): If True, flushes the cache. + + Returns: + Tensor: The reconstructed waveform. Shape: `[batch, 1, time]`. + """ + return self.decoder(x, constrain_value_range=constrain_value_range, cache=cache, flush=flush) + + def encode(self, x: Tensor, x_len: Tensor) -> tuple[Tensor, Tensor]: + """ + Encodes a waveform into discrete codes. + + Args: + x (Tensor): Input waveform. Shape: `[batch, 1, time]`. + x_len (Tensor): The original lengths of the waveforms in the batch. + + Returns: + tuple[Tensor, Tensor]: A tuple containing: + - The discrete codes. Shape: `[batch, time', n_quantizers]`. + - The lengths of the code sequences. + """ + with disable_tf32(): + z_e = self.ae_encode(x) + code_len = x_len // self.config.wav_to_token_ratio + return self.quantize(z_e), code_len + + def decode( + self, + code: Tensor, + code_len: Tensor | None = None, + constrain_value_range: bool = True, + cache: CausalConv1dCache | None = None, + flush: bool = False, + ) -> tuple[Tensor, Tensor | None]: + """ + Decodes discrete codes back into a waveform. + + Args: + code (Tensor): The discrete codes. Shape: `[batch, time', n_quantizers]`. + code_len (Tensor | None): The lengths of the code sequences. + constrain_value_range (bool): If True, constrains the output of the ISTFT. + cache (CausalConv1dCache | None): Cache for autoregressive generation. + flush (bool): If True, flushes the cache. + + Returns: + tuple[Tensor, Tensor | None]: A tuple containing: + - The reconstructed waveform. Shape: `[batch, 1, time]`. + - The lengths of the reconstructed waveforms. + """ + with disable_tf32(): + z_q = self.dequantize(code) + x_hat = self.ae_decode(z_q, constrain_value_range=constrain_value_range, cache=cache, flush=flush) + wav_len = code_len * self.config.wav_to_token_ratio if code_len is not None else None + return x_hat, wav_len + + def quantize(self, z: Tensor) -> Tensor: + """ + Quantizes a continuous latent tensor into discrete codes. + + Args: + z (Tensor): The continuous latent tensor from the encoder. + Shape: `[batch, time, channels]`. + + Returns: + Tensor: The quantized codes. Shape: `[batch, time, n_quantizers]`. + """ + with disable_tf32(): + ids_sel = self.prvq.encode(z, return_z_q=False) + return torch.stack(ids_sel, -1) + + def dequantize(self, code: Tensor) -> Tensor: + """ + De-quantizes discrete codes back into a continuous latent tensor. + + Args: + code (Tensor): The quantized codes. Shape: `[batch, time, n_quantizers]`. + + Returns: + Tensor: The de-quantized continuous latent tensor. + Shape: `[batch, time, latent_size]`. + """ + ids_sel = [x.squeeze(-1) for x in torch.split(code, 1, -1)] + return self.prvq.decode(ids_sel) + + def forward(self, x: Tensor, constrain_value_range: bool = False) -> Tensor: + """ + Performs a full autoencoding pass: encode, quantize, dequantize, and decode. + + Args: + x (Tensor): The input waveform. Shape: `[batch, 1, time]`. 
+ constrain_value_range (bool): If True, constrains the output of the ISTFT. + + Returns: + Tensor: The reconstructed waveform. Shape: `[batch, 1, time]`. + """ + + with torch.no_grad(): + z_e = self.ae_encode(x) + code = self.quantize(z_e) + z_d = self.dequantize(code) + x_hat = self.ae_decode(z_d, constrain_value_range=constrain_value_range) + return x_hat diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 7c5696457c65..b56a2b455754 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -24,7 +24,7 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.models import AudioCodecModel - +from nemo.utils import logging def load_pretrained_nemo(cls, model_path_or_name: str): """ @@ -99,3 +99,16 @@ def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True model.perception.load_state_dict(asr.state_dict(), strict=False) else: model.perception = AudioPerceptionModule(model.cfg.perception).train() + +def set_model_dict_for_partial_init(pretrained_dict, model_dict): + # 1. filter out different size layers + for k, v in list(pretrained_dict.items()): + if k in model_dict and hasattr(model_dict[k], "numel") and v.numel() != model_dict[k].numel(): + del pretrained_dict[k] + logging.info(" | > Layer with shape mismatach in the model definition: {}".format(k)) + # 2. filter out unnecessary keys + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + # 3. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + logging.info(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + return model_dict \ No newline at end of file From 3742e35e1ac2d68bbdef69aec15f5d0ec72a417b Mon Sep 17 00:00:00 2001 From: Edresson Date: Tue, 11 Nov 2025 20:01:41 +0000 Subject: [PATCH 003/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- examples/speechlm2/duplex_eartts_train.py | 2 +- .../speechlm2/duplex_eartts_train_infer.py | 3 +- nemo/collections/common/data/lhotse/cutset.py | 34 +- nemo/collections/speechlm2/data/__init__.py | 2 +- .../speechlm2/data/duplex_ear_tts_dataset.py | 134 ++++-- .../speechlm2/models/duplex_ear_tts.py | 421 +++++++++++------- .../speechlm2/modules/ear_tts_commons.py | 16 +- .../speechlm2/modules/rvq_ear_tts_model.py | 140 +++--- .../speechlm2/modules/rvq_ear_tts_vae.py | 11 +- .../speechlm2/parts/metrics/__init__.py | 3 +- .../parts/metrics/intelligibility.py | 21 +- .../speechlm2/parts/metrics/results_logger.py | 34 +- .../speechlm2/parts/metrics/secs.py | 12 +- .../speechlm2/parts/metrics/token_accuracy.py | 5 +- .../collections/speechlm2/parts/pretrained.py | 6 +- 15 files changed, 520 insertions(+), 324 deletions(-) diff --git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py index 771e4f24240e..c3ddcc072589 100644 --- a/examples/speechlm2/duplex_eartts_train.py +++ b/examples/speechlm2/duplex_eartts_train.py @@ -50,7 +50,7 @@ def train(cfg): add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), add_audio_prompt_after_description=cfg.data.add_audio_prompt_after_description, audio_prompt_duration=cfg.data.audio_prompt_duration, - num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2) + num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), ) datamodule = 
DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset) diff --git a/examples/speechlm2/duplex_eartts_train_infer.py b/examples/speechlm2/duplex_eartts_train_infer.py index 04fdd72ce8bd..8118e086950c 100644 --- a/examples/speechlm2/duplex_eartts_train_infer.py +++ b/examples/speechlm2/duplex_eartts_train_infer.py @@ -26,6 +26,7 @@ torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + @hydra_runner(config_path="conf", config_name="s2s_duplex_speech_decoder") def inference(cfg): OmegaConf.resolve(cfg) @@ -51,7 +52,7 @@ def inference(cfg): add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), add_audio_prompt_after_description=cfg.data.add_audio_prompt_after_description, audio_prompt_duration=cfg.data.audio_prompt_duration, - num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2) + num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), ) datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 09dd1a2f05da..e1a207568a5f 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -21,7 +21,7 @@ from typing import KeysView, Mapping, Sequence, Tuple, Union import omegaconf -from lhotse import CutSet, Features, Recording, MonoCut, SupervisionSegment +from lhotse import CutSet, Features, MonoCut, Recording, SupervisionSegment from lhotse.array import Array, TemporalArray from lhotse.cut import Cut, MixedCut, PaddingCut from lhotse.serialization import load_yaml @@ -535,7 +535,6 @@ def cut_to_conversation( ) - @data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) def read_s2s_duplex_overlap_as_s2s_duplex(config) -> tuple[CutSet, bool]: def filter_cuts_starting_with_agent(cuts: CutSet, agent_roles=("agent", "assistant", "Assistant")) -> CutSet: @@ -545,7 +544,7 @@ def filter_cut_fn(cut): if len(cut.supervisions): return cut.supervisions[0].speaker not in agent_roles else: - return False # filter emptly supervisions + return False # filter emptly supervisions return cuts.filter(filter_cut_fn) @@ -556,7 +555,7 @@ def convert_overlap_cut(cut): id=cut.id, recording_id=cut.id, start=seg["start"] - move_agent_text_back_by, - duration=seg["end"]-seg["start"] + move_agent_text_back_by, + duration=seg["end"] - seg["start"] + move_agent_text_back_by, text=seg["text"], speaker="agent", ) @@ -568,7 +567,7 @@ def convert_overlap_cut(cut): id=cut.id, recording_id=cut.id, start=seg["start"], - duration=seg["end"]-seg["start"], + duration=seg["end"] - seg["start"], text=seg["text"], speaker="user", ) @@ -593,6 +592,7 @@ def convert_overlap_cut(cut): return cuts, is_tarred + @data_type_parser(["lhotse_magpietts_data_as_continuation"]) def read_lhotse_magpietts_data_as_continuation(config) -> tuple[CutSet, bool]: def convert_lhotse_magpietts_data_as_cont(cut): @@ -617,7 +617,7 @@ def convert_lhotse_magpietts_data_as_cont(cut): supervisions=[], ) - # create silence audio + # create silence audio num_samples = int(total_duration * sample_rate) zero_audio = np.zeros((1, num_samples), dtype=np.float32) source_recording = create_recording_from_array( @@ -643,7 +643,7 @@ def convert_lhotse_magpietts_data_as_cont(cut): user_sup = fastcopy( orig_agent_sup, start=0.0, - duration=0.08, # keep only on frame to the user + duration=0.08, # keep only on frame to the user speaker="user", text="dummy text", ) @@ -655,7 +655,7 @@ def 
convert_lhotse_magpietts_data_as_cont(cut): speaker="agent", ) - # Add extra sil in the end of the audio to force the model to produce silence if it receives zeros and the was all processed + # Add extra sil in the end of the audio to force the model to produce silence if it receives zeros and the was all processed if ADD_EXTRA_END_SIL: sil_duration = random.uniform(*SILENCE_RANGE) # pad audios @@ -664,8 +664,10 @@ def convert_lhotse_magpietts_data_as_cont(cut): # Save both to memory cut_source = cut_source.to_mono().move_to_memory(audio_format='wav') cut_target = cut_target.to_mono().move_to_memory(audio_format='wav') - agent_sup.duration = agent_sup.duration + sil_duration + 1.0 # added here 1.0 seconds to not have text EOS for this dataset to avoid conflicts with S2S, text EOS is the interruption token on duplex - user_sup.duration = user_sup.duration + sil_duration + agent_sup.duration = ( + agent_sup.duration + sil_duration + 1.0 + ) # added here 1.0 seconds to not have text EOS for this dataset to avoid conflicts with S2S, text EOS is the interruption token on duplex + user_sup.duration = user_sup.duration + sil_duration # Assemble final cut cut_source.supervisions = [user_sup, agent_sup] @@ -683,13 +685,21 @@ def filter_cer(example): return True def filter_val_flag(example): - if isinstance(example, Cut) and example.has_custom("validation_status") and example.validation_status != KEEP_FLAG: + if ( + isinstance(example, Cut) + and example.has_custom("validation_status") + and example.validation_status != KEEP_FLAG + ): return False else: return True def filter_secs(example): - if isinstance(example, Cut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("context_speaker_similarity"): + if ( + isinstance(example, Cut) + and len(example.supervisions) > 0 + and example.supervisions[0].has_custom("context_speaker_similarity") + ): return example.supervisions[0].context_speaker_similarity >= MIN_SECS else: return True diff --git a/nemo/collections/speechlm2/data/__init__.py b/nemo/collections/speechlm2/data/__init__.py index 0d51497a6d8f..802c199462d7 100644 --- a/nemo/collections/speechlm2/data/__init__.py +++ b/nemo/collections/speechlm2/data/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from .datamodule import DataModule -from .s2s_dataset import DuplexS2SDataset from .duplex_ear_tts_dataset import DuplexEARTTSDataset +from .s2s_dataset import DuplexS2SDataset from .salm_dataset import SALMDataset __all__ = [ diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index dacc9852babe..0ba55afc7a88 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
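+# Note on the frame/sample bookkeeping used by the collation in this file: with a frame
+# length of `frame_length` seconds, the samples-per-frame value reduces to
+# int(sample_rate * frame_length), and the speaker-audio prompt is truncated down to a
+# whole number of frames. Worked example with illustrative values only
+# (frame_length=0.08 s, target_sample_rate=22050 Hz, audio_prompt_duration=3.0 s;
+# not necessarily this repo's defaults):
+#     target_samples_per_frame   = int(22050 * 0.08)                   # 1764 samples per frame
+#     prompt_audio_size          = (int(3.0 * 22050) // 1764) * 1764   # 66150 -> 37 frames -> 65268 samples
+#     prompt_audio_text_pad_size = prompt_audio_size // 1764           # 37 text pad tokens to align channels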
+import random import re import torch import torch.nn.functional as F import torch.utils.data import torchaudio -import random - from lhotse import CutSet, MonoCut, Recording, Seconds, SupervisionSegment, compute_num_frames from lhotse.cut import Cut from lhotse.dataset.collation import collate_audio, collate_vectors @@ -26,8 +25,8 @@ from nemo.collections.common.tokenizers import TokenizerSpec from nemo.collections.speechlm2.data.utils import get_pad_id -from nemo.utils import logging from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER +from nemo.utils import logging def sample_audio_segments_repeat( @@ -69,7 +68,7 @@ def sample_audio_segments_repeat( else: # Deterministic: take from start start = 0 - out[b] = prompt_audio[b, start:start + n_sample] + out[b] = prompt_audio[b, start : start + n_sample] else: # Audio shorter than target → repeat @@ -180,7 +179,7 @@ def __init__( add_text_bos_and_eos_in_each_turn: bool = False, add_audio_prompt_after_description: bool = False, audio_prompt_duration: float = 3.0, - num_delay_speech_tokens: int = 0 + num_delay_speech_tokens: int = 0, ): self.tokenizer = tokenizer self.frame_length = frame_length @@ -239,10 +238,18 @@ def __getitem__(self, cuts: CutSet) -> dict: cuts.resample(self.target_sample_rate), recording_field="target_audio" ) input_text_tokens, target_token_lens = collate_token_channel( - cuts, self.tokenizer, self.frame_length, roles=self.output_roles, add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + cuts, + self.tokenizer, + self.frame_length, + roles=self.output_roles, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, ) source_tokens, source_token_lens = collate_token_channel( - cuts, self.tokenizer, self.frame_length, roles=self.input_roles, add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn + cuts, + self.tokenizer, + self.frame_length, + roles=self.input_roles, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, ) # if context audio is available use it, otherwise use a random agent turn as speaker reference @@ -255,9 +262,7 @@ def __getitem__(self, cuts: CutSet) -> dict: speaker_reference_audio.append(ref_audio.squeeze(0)) speaker_reference_audio_lens.append(ref_audio_len) - speaker_reference_audio = collate_vectors( - speaker_reference_audio, padding_value=0 - ).float() + speaker_reference_audio = collate_vectors(speaker_reference_audio, padding_value=0).float() speaker_reference_audio_lens = torch.tensor(speaker_reference_audio_lens).long() else: # extract target speaker reference from a random audio audio @@ -266,16 +271,12 @@ def __getitem__(self, cuts: CutSet) -> dict: ) # ensures that input_text_tokens is not longer than its duration - input_text_tokens = input_text_tokens[:, :target_token_lens.max()] + input_text_tokens = input_text_tokens[:, : target_token_lens.max()] - source_fps = self.source_sample_rate / ( - self.source_sample_rate * self.frame_length - ) - source_samples_per_frame = int(self.source_sample_rate//source_fps) - target_fps = self.target_sample_rate / ( - self.target_sample_rate * self.frame_length - ) - target_samples_per_frame = int(self.target_sample_rate//target_fps) + source_fps = self.source_sample_rate / (self.source_sample_rate * self.frame_length) + source_samples_per_frame = int(self.source_sample_rate // source_fps) + target_fps = self.target_sample_rate / (self.target_sample_rate * self.frame_length) + target_samples_per_frame = int(self.target_sample_rate // 
target_fps) # one is default and we add BOS on speech channel to ensures it, inside of the model class, so if we want bigger than that we can add padding in the audio here if self.num_delay_speech_tokens: @@ -302,27 +303,50 @@ def __getitem__(self, cuts: CutSet) -> dict: for i in range(input_text_tokens.size(0)): desc_tokens_ids = self.generate_prompt_description(device=input_text_tokens[i].device).squeeze(0) if self.add_audio_prompt_after_description: - prompt_audio_size = int(((self.audio_prompt_duration * self.target_sample_rate) // target_samples_per_frame) * target_samples_per_frame) - prompt_audio = sample_audio_segments_repeat(speaker_reference_audio, speaker_reference_audio_lens, prompt_audio_size, sample=True) + prompt_audio_size = int( + ((self.audio_prompt_duration * self.target_sample_rate) // target_samples_per_frame) + * target_samples_per_frame + ) + prompt_audio = sample_audio_segments_repeat( + speaker_reference_audio, speaker_reference_audio_lens, prompt_audio_size, sample=True + ) # add a silence in the end to smooth the transition between prompt and audio tokens, keep one extra pad token due shift on subword_ids - prompt_audio[:, -int(target_samples_per_frame * 2):] = 0 + prompt_audio[:, -int(target_samples_per_frame * 2) :] = 0 # create tensor to pad text channels with the same amount of frames added in audio channel (audio prompt) - prompt_audio_text_pad_size = (prompt_audio_size // target_samples_per_frame) - prompt_audio_text_pad = torch.ones(prompt_audio_text_pad_size, device=input_text_tokens.device, dtype=input_text_tokens.dtype) * text_pad_id + prompt_audio_text_pad_size = prompt_audio_size // target_samples_per_frame + prompt_audio_text_pad = ( + torch.ones( + prompt_audio_text_pad_size, device=input_text_tokens.device, dtype=input_text_tokens.dtype + ) + * text_pad_id + ) # set last prompt frame with eos in text channel prompt_audio_text_pad[-1] = self.tokenizer.eos # Add eos to simulate the end of a turn as in EAR-TTS inference - desc_tokens_ids = torch.cat([desc_tokens_ids, torch.tensor([self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device)]) + desc_tokens_ids = torch.cat( + [ + desc_tokens_ids, + torch.tensor( + [self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device + ), + ] + ) # Add padding equivalent to the audio prompt size in number of tokens - new_input_text_tokens = torch.cat([desc_tokens_ids.to(input_text_tokens.dtype), prompt_audio_text_pad.to(input_text_tokens.dtype), input_text_tokens[i]]) + new_input_text_tokens = torch.cat( + [ + desc_tokens_ids.to(input_text_tokens.dtype), + prompt_audio_text_pad.to(input_text_tokens.dtype), + input_text_tokens[i], + ] + ) # append to list and update lens input_text_tokens_.append(new_input_text_tokens) target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size # add description to source text tokens - source_tokens_.append(torch.cat([desc_tokens_ids, prompt_audio_text_pad, source_tokens[i]])) + source_tokens_.append(torch.cat([desc_tokens_ids, prompt_audio_text_pad, source_tokens[i]])) source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size # add silence in the source audio while the prompt is being processed pad_size = (len(desc_tokens_ids) * source_samples_per_frame) + prompt_audio.size(1) @@ -336,7 +360,9 @@ def __getitem__(self, cuts: CutSet) -> dict: target_audio_lens[i] = target_audio_lens[i] + pad_size + prompt_audio.size(1) # desc duration 
desc_lens.append(len(desc_tokens_ids)) - desc_plus_audio_prompt_lens.append(len(desc_tokens_ids) + prompt_audio_text_pad_size - 1) # -1 due the shift done in subword_ids + desc_plus_audio_prompt_lens.append( + len(desc_tokens_ids) + prompt_audio_text_pad_size - 1 + ) # -1 due the shift done in subword_ids else: # add description to target text tokens input_text_tokens_.append(torch.cat([desc_tokens_ids, input_text_tokens[i]])) @@ -355,7 +381,7 @@ def __getitem__(self, cuts: CutSet) -> dict: target_audio_.append(torch.cat([pad_audio, target_audio[i]])) target_audio_lens[i] = target_audio_lens[i] + pad_size - # des duration + # des duration desc_lens.append(len(desc_tokens_ids)) desc_plus_audio_prompt_lens.append(len(desc_tokens_ids)) @@ -372,12 +398,12 @@ def __getitem__(self, cuts: CutSet) -> dict: non_desc_mask[i, :frame] = 0.0 # desc mask is totally the oposite of audio mask - desc_mask = ~ non_desc_mask + desc_mask = ~non_desc_mask # create non_prompt_mask that should mask desc plus audio prompt if used non_prompt_mask = get_mask_from_lengths(target_token_lens) for i, frame in enumerate(desc_plus_audio_prompt_lens): - non_prompt_mask[i, :frame-1] = 0.0 + non_prompt_mask[i, : frame - 1] = 0.0 else: # create a mask for audio using target tokens that suppose to have the same size of the tokenized audio non_prompt_mask = get_mask_from_lengths(target_token_lens) @@ -388,25 +414,33 @@ def __getitem__(self, cuts: CutSet) -> dict: max_len = max(target_token_lens) # Segment IDs per sequence (padded) - aligned_segment_ids = torch.stack([ - torch.nn.functional.pad(torch.full((l,), i), (0, max_len - l), value=-1) # -1 for padding - for i, l in enumerate(target_token_lens) - ], dim=0) # [B, max_len] + aligned_segment_ids = torch.stack( + [ + torch.nn.functional.pad(torch.full((l,), i), (0, max_len - l), value=-1) # -1 for padding + for i, l in enumerate(target_token_lens) + ], + dim=0, + ) # [B, max_len] # Attention mask: same-segment & causal aligned_attention_mask = ( - (aligned_segment_ids.unsqueeze(-2) == aligned_segment_ids.unsqueeze(-1)) # [B, max_len, max_len] - & (torch.arange(max_len).unsqueeze(0).unsqueeze(1) - <= torch.arange(max_len).unsqueeze(0).unsqueeze(-1)) # causal tril - ) + aligned_segment_ids.unsqueeze(-2) == aligned_segment_ids.unsqueeze(-1) + ) & ( # [B, max_len, max_len] + torch.arange(max_len).unsqueeze(0).unsqueeze(1) <= torch.arange(max_len).unsqueeze(0).unsqueeze(-1) + ) # causal tril aligned_attention_mask = aligned_attention_mask.unsqueeze(1) # [B, 1, max_len, max_len] # create pos ids from the aligned lenght - aligned_position_ids = torch.stack([ - torch.nn.functional.pad(torch.arange(l), (0, max(target_token_lens) - l), value=0) # value=0 is safe for padding - for l in target_token_lens - ], dim=0) + aligned_position_ids = torch.stack( + [ + torch.nn.functional.pad( + torch.arange(l), (0, max(target_token_lens) - l), value=0 + ) # value=0 is safe for padding + for l in target_token_lens + ], + dim=0, + ) return { "sample_id": [str(cut.id) for cut in cuts], @@ -449,8 +483,7 @@ def collate_random_turn_audio( # Truncate audio according to supervision truncated_audio = cut.truncate( - offset=max(0, selected_supervision.start), - duration=selected_supervision.duration + offset=max(0, selected_supervision.start), duration=selected_supervision.duration ).load_custom(recording_field) selected_turn_audios.append(truncated_audio.squeeze(0)) @@ -464,11 +497,18 @@ def collate_token_channel( tokenizer: TokenizerSpec, frame_length: Seconds, roles: set[str], - 
add_text_bos_and_eos_in_each_turn: bool = True + add_text_bos_and_eos_in_each_turn: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: pad_id = get_pad_id(tokenizer) tokens = [ - build_token_channel(c, tokenizer=tokenizer, frame_length=frame_length, roles=roles, pad_id=pad_id, add_text_bos_and_eos_in_each_turn=add_text_bos_and_eos_in_each_turn) + build_token_channel( + c, + tokenizer=tokenizer, + frame_length=frame_length, + roles=roles, + pad_id=pad_id, + add_text_bos_and_eos_in_each_turn=add_text_bos_and_eos_in_each_turn, + ) for c in cuts ] token_lens = torch.tensor([len(tt) for tt in tokens]) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index e8a2fc6217b6..d39f73fb3f45 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -11,15 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy +import glob +import math import os import random import tempfile -import numpy as np import time +from types import SimpleNamespace -import glob +import numpy as np import torch import torch.distributed as dist +import torch.nn as nn import torch.nn.functional as F import torchaudio from lightning import LightningModule @@ -36,19 +40,17 @@ loss_parallel, parallelize_module, ) -from transformers import DynamicCache -import math +from transformers import AutoModelForCausalLM, DynamicCache from nemo.collections.asr.models import EncDecSpeakerLabelModel - -from transformers import AutoModelForCausalLM - from nemo.collections.audio.parts.utils.resampling import resample -from nemo.core.classes.module import NeuralModule -from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.common.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.speechlm2.data.utils import get_pad_id from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str +from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER +from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSConfig, RVQEARTTSModel +from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEModel from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin from nemo.collections.speechlm2.parts.lora import maybe_install_lora from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU @@ -64,20 +66,11 @@ set_model_dict_for_partial_init, setup_speech_encoder, ) +from nemo.collections.tts.modules import transformer_2501 +from nemo.core.classes.module import NeuralModule from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging -from nemo.collections.tts.modules import transformer_2501 -from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER -from types import SimpleNamespace - - -from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSModel, RVQEARTTSConfig -from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEModel - -import torch -import torch.nn as nn -import copy def maybe_to(x, dtype): if x is None: @@ -86,10 +79,13 @@ def maybe_to(x, dtype): return x.to(dtype) return x + from collections import Counter from contextlib import contextmanager + import torch + @contextmanager def 
ensures_16_precision(mixed_dtype): """ @@ -106,10 +102,9 @@ def ensures_16_precision(mixed_dtype): torch.set_default_dtype(default_dtype) -def make_tts_model_mixed_precision_definite(model, inputs, - mixed_dtype=torch.bfloat16, - bf16_min=1e-2, bf16_max=1e2, - safety_factor=1.0): +def make_tts_model_mixed_precision_definite( + model, inputs, mixed_dtype=torch.bfloat16, bf16_min=1e-2, bf16_max=1e2, safety_factor=1.0 +): safe_min = bf16_min * safety_factor safe_max = bf16_max * safety_factor @@ -124,6 +119,7 @@ def hook(_, __, out): out = out[0] if torch.is_tensor(out): stats[name] = {"min": float(out.detach().min()), "max": float(out.detach().max())} + return hook for name, module in model_fp32.named_modules(): @@ -139,7 +135,7 @@ def hook(_, __, out): context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), subword_ids=inputs["subword_ids"], subword_mask=maybe_to(inputs["subword_mask"], torch.float32), - non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32) + non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32), ) for h in hooks: @@ -159,8 +155,7 @@ def hook(_, __, out): if name not in stats: continue mn, mx = stats[name]["min"], stats[name]["max"] - safe = (abs(mn) < safe_max and abs(mx) < safe_max - and not (abs(mn) < safe_min and abs(mx) < safe_min)) + safe = abs(mn) < safe_max and abs(mx) < safe_max and not (abs(mn) < safe_min and abs(mx) < safe_min) is_sensitive = False if isinstance(module, (nn.LayerNorm, nn.Embedding)): @@ -195,9 +190,13 @@ def new_forward(*args, **kwargs): with fp32_precision(): return module._original_forward(*args, **kwargs) else: - new_args = tuple(a.to(mixed_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a for a in args) - new_kwargs = {k: v.to(mixed_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v - for k, v in kwargs.items()} + new_args = tuple( + a.to(mixed_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a for a in args + ) + new_kwargs = { + k: v.to(mixed_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in kwargs.items() + } # with torch.cuda.amp.autocast(enabled=True, dtype=mixed_dtype): with ensures_16_precision(mixed_dtype): return module._original_forward(*new_args, **new_kwargs) @@ -234,7 +233,7 @@ def dtype_counter_hook(module, inputs, outputs): context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), subword_ids=inputs["subword_ids"], subword_mask=maybe_to(inputs["subword_mask"], torch.float32), - non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32) + non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32), ) for h in hook_handles: @@ -285,7 +284,9 @@ def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int return speaking_mask.long() -def replace_control_speech_codes(speech_codes: torch.Tensor, control_codes: torch.Tensor, silence_tokens: torch.Tensor = None) -> torch.Tensor: +def replace_control_speech_codes( + speech_codes: torch.Tensor, control_codes: torch.Tensor, silence_tokens: torch.Tensor = None +) -> torch.Tensor: """ Replaces control codes (speech BOS, EOS, etc) in `speech_codes` with the first frame which is assumed to consist of 'valid' codes representing silence. 
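# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the toy snippet
# below mirrors the torch.isin / torch.where pattern described in the
# docstring above, assuming codes shaped [batch, time, num_quantizers]; the
# concrete control ids (2048-2050) and silence-frame values are assumptions
# used only for illustration.
# ---------------------------------------------------------------------------
import torch

codes = torch.tensor([[[5, 7], [2049, 2049], [9, 3]]])  # [1, 3, 2]; the middle frame holds control ids
control_codes = torch.tensor([2048, 2049, 2050])        # e.g. speech BOS/EOS/PAD ids (assumed values)
silence_frame = torch.tensor([11, 13])                  # one valid "silence" code per quantizer (assumed)

silence_expanded = silence_frame.view(1, 1, -1).expand_as(codes)
cleaned = torch.where(torch.isin(codes, control_codes), silence_expanded, codes)
# frames that held control ids are now decodable silence codes instead of out-of-codebook ids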
@@ -296,15 +297,15 @@ def replace_control_speech_codes(speech_codes: torch.Tensor, control_codes: torc return torch.where(torch.isin(speech_codes, control_codes), silence_tokens_expanded, speech_codes) if torch.isin(speech_codes[:, :1], control_codes).any(): - return torch.where(torch.isin(speech_codes, control_codes), torch.zeros_like(speech_codes[:, :1]), speech_codes) + return torch.where( + torch.isin(speech_codes, control_codes), torch.zeros_like(speech_codes[:, :1]), speech_codes + ) else: return torch.where(torch.isin(speech_codes, control_codes), speech_codes[:, :1], speech_codes) def get_mask_from_lengths( - lengths: torch.Tensor = None, - x: torch.Tensor = None, - pad_to_factor: int = None + lengths: torch.Tensor = None, x: torch.Tensor = None, pad_to_factor: int = None ) -> torch.Tensor: """Constructs binary mask from a 1D torch tensor of input lengths Args: @@ -330,6 +331,7 @@ def get_mask_from_lengths( mask = ids < lengths.unsqueeze(1) return mask + def setup_rvq_audio_codec(model): """ Sets up an ``AudioCodecModel``, initializing it from pretrained weights. @@ -340,10 +342,13 @@ def setup_rvq_audio_codec(model): if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == torch.float: return # skip if already set up and has the right dtype with fp32_precision(): - model.audio_codec = RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, strict=False).eval().to(model.device) + model.audio_codec = ( + RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, strict=False).eval().to(model.device) + ) for p in model.audio_codec.parameters(): p.requires_grad = False + def setup_audio_codec(self): setup_rvq_audio_codec(self) assert callable(self.tts_model.set_rvq_embs) @@ -353,12 +358,8 @@ def setup_audio_codec(self): self.target_fps = self.target_sample_rate / self.audio_codec.config.wav_to_token_ratio self.target_samples_per_frame = self.audio_codec.config.wav_to_token_ratio -def rescale_state_dict( - state_dict, - target_std=0.02, - first_n_layers=None, - layer_prefix="tts_model.backbone.layers." -): + +def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_prefix="tts_model.backbone.layers."): """ Rescale trainable weights in a state_dict for BF16 stability. @@ -467,7 +468,7 @@ def __init__(self, cfg: dict) -> None: self.source_fps = self.source_sample_rate / ( self.source_sample_rate * cfg.data.frame_length ) # conver frame rate in fps - self.source_samples_per_frame = int(self.source_sample_rate//self.source_fps) + self.source_samples_per_frame = int(self.source_sample_rate // self.source_fps) # get codec silence tokens self.codec_silence_tokens = self.get_codec_silence_frame() @@ -476,7 +477,9 @@ def __init__(self, cfg: dict) -> None: if self.cfg.get("use_word_sep_tokenizer", False): self.tokenizer = WordSepTokenizer(self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True) else: - self.tokenizer = AutoTokenizer(self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True) # Note that we are using fast tokenizer + self.tokenizer = AutoTokenizer( + self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True + ) # Note that we are using fast tokenizer if 'Qwen2.5' in self.cfg.pretrained_lm_name: # For Qwen, '<|im_start|>' is a common choice for a BOS token. 
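# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the frame-rate
# arithmetic used above, source_fps = sample_rate / (sample_rate * frame_length),
# reduces to 1 / frame_length, and samples-per-frame to sample_rate * frame_length.
# The 16 kHz / 80 ms numbers below are assumed example values, not config defaults.
# ---------------------------------------------------------------------------
source_sample_rate = 16_000  # Hz (assumed)
frame_length = 0.08          # seconds per frame (assumed)

source_fps = source_sample_rate / (source_sample_rate * frame_length)  # 12.5 frames per second
source_samples_per_frame = int(source_sample_rate // source_fps)       # 1280 samples per frame

assert source_fps == 1 / frame_length
assert source_samples_per_frame == int(source_sample_rate * frame_length)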
@@ -503,14 +506,12 @@ def __init__(self, cfg: dict) -> None: self.init_model_from_another_checkpoint(self.cfg.pretrained_model) def get_codec_silence_frame_last_one(self): - audio = torch.zeros(1, 10*self.target_sample_rate).float().to(self.device) + audio = torch.zeros(1, 10 * self.target_sample_rate).float().to(self.device) audio_len = torch.tensor([audio.size(-1)]).long() audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) with fp32_precision(), torch.no_grad(): - sil_codes, sil_codes_lens = self.audio_codec.encode( - audio.unsqueeze(1), audio_len - ) + sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len) return sil_codes[0, -1] def get_codec_silence_frame(self): @@ -556,7 +557,9 @@ def _load_embed_tokens(self, cfg) -> nn.Embedding: def _load_tts_model(self, cfg) -> nn.Module: """Load TTS model for RVQ-EAR-TTS.""" if self.cfg.get("pretrained_tts_model", None): - self.tts_model = RVQEARTTSModel.from_pretrained(cfg.pretrained_tts_model, RVQEARTTSConfig(**cfg.tts_config), strict=False) + self.tts_model = RVQEARTTSModel.from_pretrained( + cfg.pretrained_tts_model, RVQEARTTSConfig(**cfg.tts_config), strict=False + ) else: # start the model from scratch self.tts_model = RVQEARTTSModel(RVQEARTTSConfig(**cfg.tts_config)) @@ -566,7 +569,9 @@ def _load_tts_model(self, cfg) -> nn.Module: def _load_language_model(self, cfg): """Load language model for RVQ-EAR-TTS.""" if cfg.pretrained_lm_name: - language_model = load_pretrained_hf(self.cfg.pretrained_lm_name, pretrained_weights=True, trust_remote_code=True).eval() + language_model = load_pretrained_hf( + self.cfg.pretrained_lm_name, pretrained_weights=True, trust_remote_code=True + ).eval() else: language_model = None return language_model @@ -595,7 +600,9 @@ def init_model_from_another_checkpoint(self, checkpoint_path): checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict()) if self.cfg.get("rescale_pretrained_weights", None): - checkpoint_state = rescale_state_dict(checkpoint_state, first_n_layers=self.cfg.get("rescale_first_n_layers", None)) + checkpoint_state = rescale_state_dict( + checkpoint_state, first_n_layers=self.cfg.get("rescale_first_n_layers", None) + ) self.load_state_dict(checkpoint_state, strict=True) @@ -606,7 +613,7 @@ def device(self): @property def speech_vocab_size(self): """Return the size of the audio codec codebook including extra speech BOS and EOS tokens.""" - if self.use_local_transformer and self.local_transformer_type == "nar": # add extra token for mask + if self.use_local_transformer and self.local_transformer_type == "nar": # add extra token for mask return self._codebook_size + 4 return self._codebook_size + 3 @@ -642,11 +649,13 @@ def text_bos_id(self) -> int: @property def text_zstts_task_id(self) -> int: - return self.tokenizer.text_to_ids("<|box_start|>") # uses <|box_start|> special token as zstts task id token + return self.tokenizer.text_to_ids("<|box_start|>") # uses <|box_start|> special token as zstts task id token @property def text_cont_task_id(self) -> int: - return self.tokenizer.text_to_ids("<|object_ref_start|>") # uses <|object_ref_start|> special token as cont task id token + return self.tokenizer.text_to_ids( + "<|object_ref_start|>" + ) # uses <|object_ref_start|> special token as cont task id token @property def text_eos_id(self) -> int: @@ -693,10 +702,9 @@ def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, downsampling_ num_padding = max_len - audio.shape[1] padded_audio = 
F.pad(audio, (0, num_padding)) return padded_audio, padded_len - + def prepare_inputs(self, batch: dict): - """ - """ + """ """ # check if audios has the same batch size assert batch["source_audio"].size(0) == batch["target_audio"].size(0) assert batch["speaker_reference_audio"].size(0) == batch["target_audio"].size(0) @@ -711,10 +719,10 @@ def prepare_inputs(self, batch: dict): # extract target audio codes with fp32_precision(), torch.no_grad(): - target_audio, target_audio_lens = self.pad_audio_to_factor(target_audio, target_audio_lens, self.target_samples_per_frame, 1) - target_codes, target_codes_lens = self.audio_codec.encode( - target_audio.unsqueeze(1), target_audio_lens + target_audio, target_audio_lens = self.pad_audio_to_factor( + target_audio, target_audio_lens, self.target_samples_per_frame, 1 ) + target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens) # ToDo: consider use the source audio """ @@ -764,9 +772,9 @@ def pad_or_truncate(x, pad_value=0): # ToDo: desc_mask is one for the end of the sequence, this is what cause the artifact issue in the end, fix it. # set the pad token when there is desc target_codes_aligned = torch.where( - desc_mask.unsqueeze(-1), # (B, T, 1) for broadcasting + desc_mask.unsqueeze(-1), # (B, T, 1) for broadcasting torch.full_like(target_codes, self.speech_pad_id), # fill with pad id - target_codes + target_codes, ) # set special token in the last audio prompt (it will works as a BOS token) @@ -800,7 +808,7 @@ def pad_or_truncate(x, pad_value=0): return { "code": target_codes_aligned, - "audio_mask": non_prompt_mask, # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation + "audio_mask": non_prompt_mask, # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation "attention_mask": aligned_attention_mask, "position_ids": aligned_position_ids, "subword_ids": subword_ids, @@ -808,11 +816,11 @@ def pad_or_truncate(x, pad_value=0): "context_hidden_state": context_hidden_state, "output_lens": target_codes_lens, "non_prompt_mask": non_prompt_mask, - "input_text_tokens": input_text_tokens + "input_text_tokens": input_text_tokens, } def training_step(self, batch: dict, batch_idx: int): - for m in (self.tts_model, ): + for m in (self.tts_model,): if is_frozen(m): m.eval() @@ -826,7 +834,7 @@ def training_step(self, batch: dict, batch_idx: int): context_hidden_state=inputs["context_hidden_state"], subword_ids=inputs["subword_ids"], subword_mask=inputs["subword_mask"], - non_prompt_mask=inputs["non_prompt_mask"] + non_prompt_mask=inputs["non_prompt_mask"], ) loss_dict = {"lm_loss": tts_output.lm_loss, "c_loss": tts_output.c_loss, "k_loss": tts_output.k_loss} backbone_out = tts_output.hidden_states @@ -883,21 +891,21 @@ def log_model_stats(self): total_g_params += g.numel() # L2 norms - weight_l2 = (total_w_sq ** 0.5) if total_w_sq > 0 else 0.0 - grad_l2 = (total_g_sq ** 0.5) if total_g_sq > 0 else 0.0 + weight_l2 = (total_w_sq**0.5) if total_w_sq > 0 else 0.0 + grad_l2 = (total_g_sq**0.5) if total_g_sq > 0 else 0.0 # RMS (global) weight_rms = ((total_w_sq / total_w_params) ** 0.5) if total_w_params > 0 else 0.0 - grad_rms = ((total_g_sq / total_g_params) ** 0.5) if total_g_params > 0 else 0.0 + grad_rms = ((total_g_sq / total_g_params) ** 0.5) if total_g_params > 0 else 0.0 # Mean weight_mean = sum_w / total_w_params if total_w_params > 0 else 0.0 # direct float logging avoids device sync penalty - self.log("weights/L2", weight_l2, on_epoch=True, sync_dist=True) - 
self.log("weights/RMS", weight_rms, on_epoch=True, sync_dist=True) - self.log("weights/max_abs", max_abs_w, on_epoch=True, sync_dist=True) - self.log("weights/mean", weight_mean, on_epoch=True, sync_dist=True) + self.log("weights/L2", weight_l2, on_epoch=True, sync_dist=True) + self.log("weights/RMS", weight_rms, on_epoch=True, sync_dist=True) + self.log("weights/max_abs", max_abs_w, on_epoch=True, sync_dist=True) + self.log("weights/mean", weight_mean, on_epoch=True, sync_dist=True) # ignore the grads stats for now # self.log("grads/L2", grad_l2, on_epoch=True, sync_dist=True) @@ -909,7 +917,7 @@ def on_validation_epoch_start(self) -> None: self.asr_bleu = ASRBLEU(self.cfg.scoring_asr).reset() self.intelligibility = Intelligibility(self.cfg.scoring_asr, reuse_asr_hyps=True).reset() self.secs = SECS(self.cfg.get("scoring_se", "titanet_large")).reset() - + def on_validation_epoch_end(self, prefix="val") -> None: asr_bleu = self.asr_bleu.compute() for k, m in asr_bleu.items(): @@ -935,16 +943,16 @@ def get_teacher_force_inference_audio(self, batch, guidance_enabled=True): non_prompt_mask=inputs["non_prompt_mask"], generation_config=self._get_generation_config(guidance_enabled=guidance_enabled), teacher_forcing_inference=True, - guidance_enabled=guidance_enabled + guidance_enabled=guidance_enabled, ) tf_audio_codes_pred = tts_output["codes"].squeeze(2) # decode audio - tf_audio_codes_pred = replace_control_speech_codes(tf_audio_codes_pred, self._control_codes, self.codec_silence_tokens) + tf_audio_codes_pred = replace_control_speech_codes( + tf_audio_codes_pred, self._control_codes, self.codec_silence_tokens + ) with fp32_precision(), torch.no_grad(): - audio_pred, audio_len = self.audio_codec.decode( - tf_audio_codes_pred, inputs["output_lens"] - ) + audio_pred, audio_len = self.audio_codec.decode(tf_audio_codes_pred, inputs["output_lens"]) return audio_pred.squeeze(1), audio_len @@ -958,40 +966,54 @@ def _get_generation_config(self, guidance_enabled: bool = False): "eos_threshold": -3.0, } - def offline_inference_with_custom_sentences(self, test_sentences: torch.Tensor, inference_speaker_reference: torch.Tensor, speech_text_ratio: float = 3.5): + def offline_inference_with_custom_sentences( + self, test_sentences: torch.Tensor, inference_speaker_reference: torch.Tensor, speech_text_ratio: float = 3.5 + ): # ToDo: split it in multiples batches to support long list of sentences B = len(test_sentences) # load and get speaker reference speaker_audio, sr = torchaudio.load(inference_speaker_reference) speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) - speaker_audio = speaker_audio.repeat(B, 1).to(self.device) + speaker_audio = speaker_audio.repeat(B, 1).to(self.device) # lengths -> [B] speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) # Tokenize sentences tokenized = [ - torch.as_tensor([self.tokenizer.bos] + self.tokenizer.text_to_ids(text), dtype=torch.long, device=self.device) + torch.as_tensor( + [self.tokenizer.bos] + self.tokenizer.text_to_ids(text), dtype=torch.long, device=self.device + ) for text in test_sentences ] # Get max length and target length max_len = max(len(t) for t in tokenized) # Pad each to double length - target_len = int(speech_text_ratio * max_len) # make text longer to ensures that we have enough steps for speech gen - next_subword_ids = torch.stack([ - torch.cat([ - torch.tensor([self.text_pad_id], dtype=torch.long, device=self.device), # shift right adding one padding token - t, - torch.full((target_len 
- len(t) - 1,), self.text_pad_id, dtype=torch.long, device=self.device) # remaining padding - ]) - for t in tokenized - ]) + target_len = int( + speech_text_ratio * max_len + ) # make text longer to ensures that we have enough steps for speech gen + next_subword_ids = torch.stack( + [ + torch.cat( + [ + torch.tensor( + [self.text_pad_id], dtype=torch.long, device=self.device + ), # shift right adding one padding token + t, + torch.full( + (target_len - len(t) - 1,), self.text_pad_id, dtype=torch.long, device=self.device + ), # remaining padding + ] + ) + for t in tokenized + ] + ) audio, audio_len = self.offline_inference( speaker_audio=speaker_audio, speaker_audio_lens=speaker_audio_lens, next_subword_ids=next_subword_ids, - guidance_enabled=self.cfg.get("inference_guidance_enabled", True) + guidance_enabled=self.cfg.get("inference_guidance_enabled", True), ) return audio, audio_len, speaker_audio, speaker_audio_lens @@ -999,12 +1021,21 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals results = {} inputs = self.prepare_inputs(dataset_batch) - # + # # exit() # first evaluation, make the model bf16 safe - if not self.model_16_precision_safe and self.cfg.get("ensures_16_safe", False) and str(self.trainer_config.precision) != str(32): + if ( + not self.model_16_precision_safe + and self.cfg.get("ensures_16_safe", False) + and str(self.trainer_config.precision) != str(32) + ): # ToDo: move it to a method - self.tts_model, summary = make_tts_model_mixed_precision_definite(self.tts_model, inputs, safety_factor=1.0, mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16) + self.tts_model, summary = make_tts_model_mixed_precision_definite( + self.tts_model, + inputs, + safety_factor=1.0, + mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16, + ) # self.tts_model, summary = make_tts_model_mixed_precision_safe(self.tts_model, inputs, safety_factor=1.0) self.model_16_precision_safe = True print("Current FP32 layers:", summary["fp32_layers"]) @@ -1018,27 +1049,30 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals "non_prompt_mask": inputs["non_prompt_mask"], "context_hidden_state": inputs["context_hidden_state"], "subword_ids": inputs["subword_ids"], - "subword_mask": inputs["subword_mask"] + "subword_mask": inputs["subword_mask"], } # cut init_inputs to consider only the prompt for key in init_inputs: if init_inputs[key] is not None: - init_inputs[key] = torch.stack([ - init_inputs[key][i, :l] - for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) - ]) + init_inputs[key] = torch.stack( + [init_inputs[key][i, :l] for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"])] + ) # remove the prompt from the input_text_tokens to emulate S2S connected inference - next_subword_ids = torch.stack([ - inputs["subword_ids"][i, l:] # slice each element - for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) - ]) + next_subword_ids = torch.stack( + [ + inputs["subword_ids"][i, l:] # slice each element + for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + ] + ) if self.cfg.get("use_asr_speech_tokens", False) and self.cfg.get("only_semantic_to_speech", False): - inp_asr_speech_tokens = torch.stack([ - inputs["target_asr_speech_tokens"][i, l:] # slice each element - for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) - ]) + inp_asr_speech_tokens = torch.stack( + [ + 
inputs["target_asr_speech_tokens"][i, l:] # slice each element + for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + ] + ) else: inp_asr_speech_tokens = None @@ -1052,13 +1086,24 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals ) # remove prompt padding from the user audio as autoregressive inference does not return the prompt - dataset_batch["source_audio"] = dataset_batch["source_audio"][:, -int(next_subword_ids.size(-1)*self.source_samples_per_frame):] + dataset_batch["source_audio"] = dataset_batch["source_audio"][ + :, -int(next_subword_ids.size(-1) * self.source_samples_per_frame) : + ] # clean prompt from the audio - results["audio_tf"] = results["audio_tf"][:, -int(next_subword_ids.size(-1)*self.target_samples_per_frame):] + results["audio_tf"] = results["audio_tf"][:, -int(next_subword_ids.size(-1) * self.target_samples_per_frame) :] # remove prompt from target audio - target_audio_no_prompt = dataset_batch["target_audio"][:, -int(next_subword_ids.size(-1)*self.target_samples_per_frame):] - target_audio_no_prompt_lens = dataset_batch["target_audio_lens"] - (torch.tensor(dataset_batch["desc_plus_audio_prompt_lens"], dtype=torch.long, device=dataset_batch["target_audio_lens"].device) * self.target_samples_per_frame) + target_audio_no_prompt = dataset_batch["target_audio"][ + :, -int(next_subword_ids.size(-1) * self.target_samples_per_frame) : + ] + target_audio_no_prompt_lens = dataset_batch["target_audio_lens"] - ( + torch.tensor( + dataset_batch["desc_plus_audio_prompt_lens"], + dtype=torch.long, + device=dataset_batch["target_audio_lens"].device, + ) + * self.target_samples_per_frame + ) # for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]): # results["audio_tf"][i, :l*self.target_samples_per_frame] = 0.0 @@ -1071,11 +1116,16 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals metric_audio_pred_lens = (metric_audio_pred_lens / self.target_sample_rate * 16000).to(torch.long) # reshape target audio without prompt target_audio_no_prompt_16khz = resample(target_audio_no_prompt, self.target_sample_rate, 16000) - target_audio_no_prompt_lens_16khz = (target_audio_no_prompt_lens / self.target_sample_rate * 16000).to(torch.long) + target_audio_no_prompt_lens_16khz = (target_audio_no_prompt_lens / self.target_sample_rate * 16000).to( + torch.long + ) if self.cfg.get("use_GT_transcriptions_for_metrics", True): # use target audio transcription for metrics target_asr_texts = self.asr_bleu.asr.transcribe( - [audio[:alen] for audio, alen in zip(target_audio_no_prompt_16khz, target_audio_no_prompt_lens_16khz)], + [ + audio[:alen] + for audio, alen in zip(target_audio_no_prompt_16khz, target_audio_no_prompt_lens_16khz) + ], batch_size=target_audio_no_prompt_16khz.shape[0], verbose=False, ) @@ -1097,20 +1147,24 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals pred_audio_lens=metric_audio_pred_lens, asr_hyps=asr_hyps, ) - + # add ground truth intelligibility metrics self.intelligibility.update( - name=name+"_gt", + name=name + "_gt", refs=dataset_batch["target_texts"], pred_audio=target_audio_no_prompt_16khz, pred_audio_lens=target_audio_no_prompt_lens_16khz, - asr_hyps=metric_text if self.cfg.get("use_GT_transcriptions_for_metrics", True) else None, # reuse GT transcription + asr_hyps=( + metric_text if self.cfg.get("use_GT_transcriptions_for_metrics", True) else None + ), # reuse GT transcription ) self.secs.update( name=name, 
target_audio=resample(dataset_batch["target_audio"], self.target_sample_rate, 16000), - target_audio_lens=(dataset_batch["target_audio_lens"] / self.target_sample_rate * 16000).to(torch.long), + target_audio_lens=(dataset_batch["target_audio_lens"] / self.target_sample_rate * 16000).to( + torch.long + ), pred_audio=resample(results["audio"], self.target_sample_rate, 16000), pred_audio_lens=(results["audio_len"] / self.target_sample_rate * 16000).to(torch.long), ) @@ -1145,7 +1199,9 @@ def validation_step(self, batch: dict, batch_idx: int): logging.info(f"Generating {name} custom sentences.") test_sentences = self.cfg.test_sentences[name] results = {} - results["audio"], results["audio_len"], speaker_audio, speaker_audio_lens = self.offline_inference_with_custom_sentences(test_sentences, self.cfg.inference_speaker_reference) + results["audio"], results["audio_len"], speaker_audio, speaker_audio_lens = ( + self.offline_inference_with_custom_sentences(test_sentences, self.cfg.inference_speaker_reference) + ) with fp32_precision(): # resample is fragile to bfloat16 default dtype metric_audio_pred = results["audio"] metric_audio_pred_lens = results["audio_len"] @@ -1204,7 +1260,9 @@ def validation_step(self, batch: dict, batch_idx: int): # run inference for multiples references if self.cfg.get("inference_speaker_reference_path", None): B = len(dataset_batch['sample_id']) - for inference_speaker_reference in glob.glob(os.path.join(self.cfg.inference_speaker_reference_path, "**"), recursive=True): + for inference_speaker_reference in glob.glob( + os.path.join(self.cfg.inference_speaker_reference_path, "**"), recursive=True + ): if not os.path.isfile(inference_speaker_reference): continue print("Generating sample for speaker refernce:", inference_speaker_reference) @@ -1212,12 +1270,10 @@ def validation_step(self, batch: dict, batch_idx: int): # Get only the file name ref_name = os.path.basename(inference_speaker_reference) # Append to each sample_id - new_dataset_batch['sample_id'] = [ - f"{sid}_{ref_name}" for sid in dataset_batch['sample_id'] - ] + new_dataset_batch['sample_id'] = [f"{sid}_{ref_name}" for sid in dataset_batch['sample_id']] speaker_audio, sr = torchaudio.load(inference_speaker_reference) speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) - speaker_audio = speaker_audio.repeat(B, 1).to(self.device) + speaker_audio = speaker_audio.repeat(B, 1).to(self.device) # lengths -> [B] speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) new_dataset_batch["speaker_reference_audio"] = speaker_audio @@ -1228,7 +1284,7 @@ def validation_step(self, batch: dict, batch_idx: int): new_dataset_batch = copy.deepcopy(dataset_batch) speaker_audio, sr = torchaudio.load(inference_speaker_reference) speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) - speaker_audio = speaker_audio.repeat(B, 1).to(self.device) + speaker_audio = speaker_audio.repeat(B, 1).to(self.device) # lengths -> [B] speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) new_dataset_batch["speaker_reference_audio"] = speaker_audio @@ -1258,7 +1314,7 @@ def get_system_prompt(self, system_prompt=None, user_prompt=None): "Your response should remain independent of any stylistic instructions." 
) messages.append({"role": "system", "content": system_prompt}) - + # ToDo: implement dataloading support for descriptions """for desc in example["descriptions"]: user_prompt = "" @@ -1295,8 +1351,8 @@ def get_system_prompt(self, system_prompt=None, user_prompt=None): def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, user_prompt=None): # compute prompt audio size and slice it with fp32_precision(): - """ - # old pad that can add long silences in the end + """ + # old pad that can add long silences in the end prompt_audio_size = int(((self.data_cfg.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) * self.target_samples_per_frame) B, T = speaker_audio.shape # [batch, time] if T >= prompt_audio_size: @@ -1310,8 +1366,7 @@ def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, """ # compute the exact number of samples for the prompt duration prompt_audio_size = int( - ((self.data_cfg.audio_prompt_duration * self.target_sample_rate) - // self.target_samples_per_frame) + ((self.data_cfg.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) * self.target_samples_per_frame ) @@ -1343,34 +1398,49 @@ def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, prompt_audio[b] = expanded[:prompt_audio_size] # add a silence in the end to smooth the transition between prompt and audio tokens - prompt_audio[:, -int(self.target_samples_per_frame * 2):] = 0 + prompt_audio[:, -int(self.target_samples_per_frame * 2) :] = 0 # get prompt audio size with fp32_precision(): prompt_audio_text_pad_size = int(prompt_audio_size // self.target_samples_per_frame) - + # get description tokens desc_tokens_ids = self.get_system_prompt(system_prompt=system_prompt, user_prompt=user_prompt) # create a padding tensor - prompt_audio_text_pad = torch.ones(prompt_audio_text_pad_size, device=self.device, dtype=desc_tokens_ids.dtype) * self.text_pad_id + prompt_audio_text_pad = ( + torch.ones(prompt_audio_text_pad_size, device=self.device, dtype=desc_tokens_ids.dtype) * self.text_pad_id + ) prompt_audio_text_pad[-1] = self.tokenizer.eos # Add eos to simulate the end of a turn as in EAR-TTS inference - desc_tokens_ids = torch.cat([desc_tokens_ids.squeeze(), torch.tensor([self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device)]) + desc_tokens_ids = torch.cat( + [ + desc_tokens_ids.squeeze(), + torch.tensor([self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device), + ] + ) # Add padding equivalent to the audio prompt size in number of tokens - input_text_tokens = torch.cat([desc_tokens_ids.to(desc_tokens_ids.dtype), prompt_audio_text_pad.to(desc_tokens_ids.dtype)]) + input_text_tokens = torch.cat( + [desc_tokens_ids.to(desc_tokens_ids.dtype), prompt_audio_text_pad.to(desc_tokens_ids.dtype)] + ) # create pad audio for the description pad_size = desc_tokens_ids.size(-1) * self.target_samples_per_frame - pad_audio = torch.zeros(pad_size, device=prompt_audio.device, dtype=prompt_audio.dtype).unsqueeze(0).repeat(prompt_audio.size(0), 1) + pad_audio = ( + torch.zeros(pad_size, device=prompt_audio.device, dtype=prompt_audio.dtype) + .unsqueeze(0) + .repeat(prompt_audio.size(0), 1) + ) # repeat to reaches the batch size input_text_tokens = input_text_tokens.unsqueeze(0).repeat(prompt_audio.size(0), 1) target_audio = torch.cat([pad_audio, prompt_audio], dim=1) # extract code codes - target_audio_len = torch.tensor([target_audio.size(-1)] * 
target_audio.size(0), dtype=torch.long, device=self.device) + target_audio_len = torch.tensor( + [target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device + ) with fp32_precision(), torch.no_grad(): code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len) @@ -1383,20 +1453,23 @@ def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, # create masks # non_prompt_mask is all zeros, because all processed is prompt non_prompt_mask = torch.zeros_like(input_text_tokens) - non_prompt_mask[:, -2:] = 1 # set last valid prompt frame as 1 to allow the addition of BOS in the right place - subword_mask = torch.zeros_like(input_text_tokens) # subword_mask is almost all zeros because on the warmup there is only the prompt - subword_mask[:, -3:] = 1 # -3 because of the it start right after the first valid prompt token and it is shifted by 1 + non_prompt_mask[:, -2:] = 1 # set last valid prompt frame as 1 to allow the addition of BOS in the right place + subword_mask = torch.zeros_like( + input_text_tokens + ) # subword_mask is almost all zeros because on the warmup there is only the prompt + subword_mask[:, -3:] = ( + 1 # -3 because of the it start right after the first valid prompt token and it is shifted by 1 + ) # desc mask is all zeros except the description desc_mask = torch.zeros_like(input_text_tokens) - desc_mask[:, :desc_tokens_ids.size(-1)] = 1 - + desc_mask[:, : desc_tokens_ids.size(-1)] = 1 if not self.cfg.get("disable_speech_pad", False): # add special tokens on audio codes code = torch.where( - desc_mask.unsqueeze(-1).bool(), # (B, T, 1) for broadcasting + desc_mask.unsqueeze(-1).bool(), # (B, T, 1) for broadcasting torch.full_like(code, self.speech_pad_id), # fill with pad id - code + code, ) # shift subword_ids @@ -1410,11 +1483,13 @@ def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, init_inputs = { "code": code[:, :-1], - "audio_mask": non_prompt_mask.bool()[:, :-1], # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation + "audio_mask": non_prompt_mask.bool()[ + :, :-1 + ], # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation "context_hidden_state": context_hidden_state[:, :-1] if context_hidden_state is not None else None, "subword_ids": subword_ids[:, :-1], "subword_mask": subword_mask.bool()[:, :-1], - "non_prompt_mask": non_prompt_mask.bool()[:, :-1] + "non_prompt_mask": non_prompt_mask.bool()[:, :-1], } return init_inputs @@ -1456,7 +1531,9 @@ def offline_inference( # ToDo: verify why codes differ from dataloader init_inputs when using nanocodec if init_inputs is None: - init_inputs = self.get_init_inputs(speaker_audio, speaker_audio_lens, system_prompt=system_prompt, user_prompt=user_prompt) + init_inputs = self.get_init_inputs( + speaker_audio, speaker_audio_lens, system_prompt=system_prompt, user_prompt=user_prompt + ) # compare_dicts(init_inputs_fn, init_inputs) if self.cfg.get("use_asr_speech_tokens", False) and self.cfg.get("only_semantic_to_speech", False): @@ -1486,10 +1563,12 @@ def offline_inference( if self.cfg.get("use_asr_speech_tokens", False): if self.cfg.get("only_semantic_to_speech", False): cur_asr_speech_tokens = inp_asr_speech_tokens[:, 0].unsqueeze(-1) - else: + else: if guidance_enabled and self.cfg.get("asr_speech_tokens_use_guidance", True): hidden_states, uncond_hidden_states = outputs.hidden_states.chunk(2, dim=0) - logits = self.asr_speech_tokens_head(hidden_states + 
(generation_config["guidance_scale"] * (hidden_states - uncond_hidden_states))) + logits = self.asr_speech_tokens_head( + hidden_states + (generation_config["guidance_scale"] * (hidden_states - uncond_hidden_states)) + ) else: hidden_states, _ = outputs.hidden_states.chunk(2, dim=0) logits = self.asr_speech_tokens_head(hidden_states) @@ -1499,14 +1578,16 @@ def offline_inference( # use the text tokens to stop generation max_steps = next_subword_ids.size(-1) # create variable to store the audios - gen_audio_codes = torch.zeros(B, max_steps, self.tts_model.config.num_quantizers, device=self.device, dtype=torch.long) + gen_audio_codes = torch.zeros( + B, max_steps, self.tts_model.config.num_quantizers, device=self.device, dtype=torch.long + ) # init subwork as all ones subword_mask = torch.ones(B, max_steps, device=self.device, dtype=torch.bool) # get first context subword_id, that is the last subword_ids from the warmup first_context_subword_id = init_inputs["subword_ids"][:, -1].unsqueeze(-1) - for i in range(max_steps-1): + for i in range(max_steps - 1): step_start = time.time() # current subword id is always seem current_subword_id = next_subword_ids[:, i].unsqueeze(-1) @@ -1517,7 +1598,7 @@ def offline_inference( if i == 0: context_subword_id = first_context_subword_id else: - context_subword_id = next_subword_ids[:, i-1].unsqueeze(-1) + context_subword_id = next_subword_ids[:, i - 1].unsqueeze(-1) context_hidden_state = self.embed_tokens(context_subword_id) else: @@ -1536,7 +1617,7 @@ def offline_inference( "use_cache": True, "guidance_enabled": guidance_enabled, "generation_config": generation_config, - "ignore_eos_flag_stop": True + "ignore_eos_flag_stop": True, } outputs = self.tts_model(**inputs) @@ -1544,19 +1625,21 @@ def offline_inference( code = outputs["codes"] past_key_values = outputs["past_key_values"] # ToDo: check why it is -1 - gen_audio_codes[:, i-1] = code.squeeze(1) + gen_audio_codes[:, i - 1] = code.squeeze(1) if self.cfg.get("use_asr_speech_tokens", False) and not self.cfg.get("only_semantic_to_speech", False): if guidance_enabled and self.cfg.get("asr_speech_tokens_use_guidance", True): hidden_states, uncond_hidden_states = outputs.hidden_states.chunk(2, dim=0) - logits = self.asr_speech_tokens_head(hidden_states + (generation_config["guidance_scale"] * (hidden_states - uncond_hidden_states))) + logits = self.asr_speech_tokens_head( + hidden_states + (generation_config["guidance_scale"] * (hidden_states - uncond_hidden_states)) + ) else: hidden_states, _ = outputs.hidden_states.chunk(2, dim=0) logits = self.asr_speech_tokens_head(hidden_states) cur_asr_speech_tokens = logits.argmax(dim=-1)[:, -1].unsqueeze(-1) - # force silence as next token + # force silence as next token if self.cfg.get('inference_force_speech_silence_on_eos', True): silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(code.shape) code = torch.where( @@ -1565,21 +1648,17 @@ def offline_inference( code, # keep original ) - step_time = time.time()-step_start + step_time = time.time() - step_start logging.info(f"Autoregressive inference step: {i} of {max_steps} take around {step_time}s") - gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * gen_audio_codes.shape[0]).to(self.device) # decode audio. 
Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety gen_audio_codes = replace_control_speech_codes(gen_audio_codes, self._control_codes, self.codec_silence_tokens) with fp32_precision(), torch.no_grad(): - audio_pred, audio_len = self.audio_codec.decode( - gen_audio_codes, gen_audio_codes_lens - ) + audio_pred, audio_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) return audio_pred.squeeze(1), audio_len - def backward(self, *args, **kwargs): with loss_parallel(): super().backward(*args, **kwargs) @@ -1664,7 +1743,15 @@ def configure_model(self) -> None: parallelize_module(transformer_block, tp_mesh, plan) - for m in (self.tts_model.mog_head, self.tts_model.embed_subword, self.tts_model.embed_context, self.tts_model.embed_code, self.tts_model.null_emb, self.tts_model.bos_emb, self.tts_model.lm_head): + for m in ( + self.tts_model.mog_head, + self.tts_model.embed_subword, + self.tts_model.embed_context, + self.tts_model.embed_code, + self.tts_model.null_emb, + self.tts_model.bos_emb, + self.tts_model.lm_head, + ): parallelize_module( m, tp_mesh, @@ -1686,10 +1773,12 @@ def configure_model(self) -> None: for idx in range(self.tts_model._num_codebooks): self.tts_model.audio_embeddings[idx] = fully_shard(self.tts_model.audio_embeddings[idx], **fsdp_config) - + if self.tts_model.use_local_transformer: self.tts_model.local_transformer = fully_shard(self.tts_model.local_transformer, **fsdp_config) - self.tts_model.local_transformer_in_projection = fully_shard(self.tts_model.local_transformer_in_projection, **fsdp_config) + self.tts_model.local_transformer_in_projection = fully_shard( + self.tts_model.local_transformer_in_projection, **fsdp_config + ) else: self.embed_text_tokens = fully_shard(self.embed_text_tokens, **fsdp_config) # self.tts_model = fully_shard(self.tts_model, **fsdp_config) diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index 03cdc9da298a..cddebcc35344 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -1,19 +1,19 @@ # Standard library +import argparse import glob +import importlib.machinery import json import os import re import shutil import subprocess import sys -import importlib.machinery from collections.abc import Mapping, MutableMapping from typing import Any -import argparse import torch -from torch import nn from safetensors import safe_open +from torch import nn from nemo.utils import logging @@ -174,9 +174,9 @@ def get_config_from_file(config_path: str) -> Config: config = json.load(f) else: config_module = importlib.machinery.SourceFileLoader("_config", config_path).load_module() - assert hasattr(config_module, PYTHON_CONFIG_GETTER_NAME), ( - f"Python config file must define a `{PYTHON_CONFIG_GETTER_NAME}` function." - ) + assert hasattr( + config_module, PYTHON_CONFIG_GETTER_NAME + ), f"Python config file must define a `{PYTHON_CONFIG_GETTER_NAME}` function." config = getattr(config_module, PYTHON_CONFIG_GETTER_NAME)(py_config_name) assert isinstance(config, Mapping), f"`{PYTHON_CONFIG_GETTER_NAME}` must return a dictionary-like object." 
cfg = Config(**config) @@ -242,12 +242,11 @@ def get_config_from_dir(workdir_path: str) -> Config: return cfg - - # ============================================================================== # Base Model Classes # ============================================================================== + class PreTrainedModel(nn.Module): config_class = Config @@ -369,6 +368,7 @@ def _get_weight_names(module): {"params": params_wo_decay, "weight_decay": 0.0}, ] + # ============================================================================== # IO and Checkpointing Utilities # ============================================================================== diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index 21928e66ad73..28ceebbc1341 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -7,40 +7,29 @@ import re import shutil import sys +import unicodedata from collections.abc import Mapping, MutableMapping from dataclasses import dataclass, field, fields from typing import Any -import unicodedata -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init import torch -from torch import Tensor, nn -from torch.nn import functional as F import transformers -from transformers import ( - AutoConfig, - AutoModel, - AutoModelForTextEncoding, - AutoTokenizer, - Cache, -) -from transformers.generation.logits_process import ( - TopKLogitsWarper, - TopPLogitsWarper, -) from safetensors import safe_open +from torch import Tensor, nn +from torch.nn import functional as F +from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, AutoTokenizer, Cache +from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper +from nemo.collections.speechlm2.modules.ear_tts_commons import Config, PreTrainedModel +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.utils import logging -from nemo.collections.speechlm2.modules.ear_tts_commons import ( - Config, - PreTrainedModel -) # ============================================================================== # MLP module and Norm # ============================================================================== + class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -329,9 +318,9 @@ def get_mask( if validate: # Reconstruct the expected contiguous mask and assert equality. expected_mask = sequence_mask(num_valid.view(-1), depth).view_as(code_mask) - assert torch.equal(code_mask, expected_mask), ( - "Input `code_mask` must have contiguous `True` values at the beginning." - ) + assert torch.equal( + code_mask, expected_mask + ), "Input `code_mask` must have contiguous `True` values at the beginning." # Calculate the target number of valid tokens. 
if not unmasking: @@ -345,7 +334,6 @@ def get_mask( return sequence_mask(num_to_keep.view(-1), depth).view_as(code_mask) - @dataclass class CASConfig(Config): pretrained_tokenizer_name: str = "meta-llama/Llama-3.1-8B-Instruct" @@ -542,6 +530,7 @@ def _build_char_vocab() -> dict[str, int]: subword_id_to_char_ids[subword_padding_idx] = (len(char_vocab),) return subword_id_to_char_ids, char_vocab, subword_padding_idx + def _split_ipa_symbols(text: str) -> list[str]: """ Split IPA text into grapheme clusters (true phoneme symbols) @@ -562,6 +551,7 @@ def _split_ipa_symbols(text: str) -> list[str]: phonemes.append(cluster) return phonemes + def build_phoneme_vocabs( pretrained_tokenizer_name: str, vocab_dir: str | None = None, @@ -583,6 +573,7 @@ def build_phoneme_vocabs( - subword_padding_idx: int """ from phonemizer import phonemize + tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name) def _phonemize_all_subwords() -> dict[str, list[str]]: @@ -592,7 +583,7 @@ def _phonemize_all_subwords() -> dict[str, list[str]]: phoneme_strings = phonemize( subwords, language=language, - backend="espeak", # use espeak-ng + backend="espeak", # use espeak-ng strip=True, njobs=1, preserve_punctuation=True, @@ -671,6 +662,7 @@ def depthsum_encoding_step( return code """ + @torch.compile def depthsum_encoding_step( embs: Tensor, @@ -690,10 +682,11 @@ def depthsum_encoding_step( r = r - emb_i # FIX: assign correctly without shape mismatch - code[..., i] = idx_sel + code[..., i] = idx_sel return code + class MoGHead(nn.Module): """ A Mixture of Gaussians (MoG) prediction head. @@ -771,7 +764,7 @@ def infer(self, x: Tensor, guidance_scale: float = 0.0, top_p_or_k: float | int # Apply top-p or top-k filtering to the mixture logits if top_p_or_k is not None: - + logits = ( TopPLogitsWarper(top_p_or_k)( None, @@ -858,7 +851,9 @@ def dist(self, mus: Tensor, mu: Tensor) -> Tensor: x, low_mat_sq.to(x), ) - ).sum(-1) # [b, t, n] + ).sum( + -1 + ) # [b, t, n] y_sq = y.pow(2).sum(-1, keepdim=True) # [b, t, 1] xwy = (x * torch.einsum("bti,nij->btnj", y, self.low_mat.to(y))).sum( -1 @@ -874,11 +869,13 @@ class NeMoSubwordFlagEmbedding(nn.Module): (subwords that do NOT start with Ġ or the word-boundary marker). Compatible with NeMo AutoTokenizer. """ + def __init__(self, model_name: str, d_model: int): super().__init__() # Load tokenizer from NeMo # self.tokenizer_hf = AutoTokenizer.from_pretrained(model_name) from nemo.collections.common.tokenizers import AutoTokenizer as NeMoAutoTokenizer + self.tokenizer = NeMoAutoTokenizer(model_name, use_fast=True, trust_remote_code=True) self.vocab_size = self.tokenizer.vocab_size self.d_model = d_model @@ -887,12 +884,13 @@ def __init__(self, model_name: str, d_model: int): tokens = [self.tokenizer.ids_to_tokens(i) for i in range(self.vocab_size)] self.register_buffer( 'is_continuation', - torch.tensor([1 if not (tok.startswith("Ġ") or tok.startswith("▁")) else 0 for tok in tokens], - dtype=torch.long) + torch.tensor( + [1 if not (tok.startswith("Ġ") or tok.startswith("▁")) else 0 for tok in tokens], dtype=torch.long + ), ) # Tiny embedding table: 0 = word-start, 1 = continuation - init_std = self.d_model ** -0.5 + init_std = self.d_model**-0.5 self.cont_emb = nn.Embedding(2, self.d_model) nn.init.normal_(self.cont_emb.weight, mean=0.0, std=init_std) @@ -914,6 +912,7 @@ class SubwordFlagEmbedding(nn.Module): Automatically adds a custom padding token at index vocab_size. Ignores special tokens (starting with '<') when computing continuation flags. 
""" + def __init__(self, model_name: str, d_model: int): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -928,35 +927,34 @@ def __init__(self, model_name: str, d_model: int): # Precompute continuation flags tokens = [self.tokenizer.convert_ids_to_tokens(i) for i in range(self.vocab_size)] cont_flags = [ - 1 if not (tok.startswith("Ġ") or tok.startswith("▁") or tok.startswith("<")) else 0 - for tok in tokens + 1 if not (tok.startswith("Ġ") or tok.startswith("▁") or tok.startswith("<")) else 0 for tok in tokens ] cont_flags.append(0) # for the custom pad token self.register_buffer("is_continuation", torch.tensor(cont_flags, dtype=torch.long)) # Continuation embedding - init_std = self.d_model ** -0.5 + init_std = self.d_model**-0.5 self.cont_emb = nn.Embedding(2, self.d_model) nn.init.normal_(self.cont_emb.weight, mean=0.0, std=init_std) self.cont_emb.weight.data[0].zero_() def forward(self, subword_embeds: torch.Tensor, token_ids: torch.LongTensor): # Replace OOV token IDs with pad_id safely - token_ids_clamped = torch.where(token_ids >= self.vocab_size, - self.pad_tensor, - token_ids) + token_ids_clamped = torch.where(token_ids >= self.vocab_size, self.pad_tensor, token_ids) # Continuation flags cont_flags = self.is_continuation[token_ids_clamped] # Add continuation embedding cont_emb = self.cont_emb(cont_flags) return subword_embeds + cont_emb + class BOSEOSEmbedding(nn.Module): """ Adds independent embeddings for BOS and EOS tokens using a single embedding table. Index 0 = regular token (ignored), 1 = BOS, 2 = EOS. Compatible with Hugging Face tokenizers that may or may not have BOS/EOS. """ + def __init__(self, model_name: str, d_model: int): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -990,7 +988,7 @@ def __init__(self, model_name: str, d_model: int): special_flags.append(0) # for custom pad token self.register_buffer("special_flags", torch.tensor(special_flags, dtype=torch.long)) # Embedding table: 0 = regular, 1 = BOS, 2 = EOS - init_std = self.d_model ** -0.5 + init_std = self.d_model**-0.5 self.special_emb = nn.Embedding(3, d_model) nn.init.normal_(self.special_emb.weight, mean=0.0, std=init_std) self.special_emb.weight.data[0].zero_() # regular tokens ignored @@ -1007,11 +1005,13 @@ def forward(self, token_embeds: torch.Tensor, token_ids: torch.LongTensor): flags = self.special_flags[safe_ids] return token_embeds + self.special_emb(flags) + class SubwordEmbedding(nn.Module): """ Produces subword embeddings from a Hugging Face tokenizer vocabulary. No special handling for OOVs or padding — assumes token_ids are valid. """ + def __init__(self, model_name: str, d_model: int): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -1022,7 +1022,7 @@ def __init__(self, model_name: str, d_model: int): self.d_model = d_model # Subword embedding table - init_std = d_model ** -0.5 + init_std = d_model**-0.5 self.subword_emb = nn.Embedding(self.vocab_size, d_model) nn.init.normal_(self.subword_emb.weight, mean=0.0, std=init_std) @@ -1073,7 +1073,8 @@ def __init__( # 1. 
Build or load the character vocabulary self.subword_id_to_char_ids, self.char_vocab, self.subword_padding_idx = build_vocabs( - pretrained_tokenizer_name, vocab_dir, + pretrained_tokenizer_name, + vocab_dir, ) self.char_padding_idx = len(self.char_vocab) @@ -1174,7 +1175,7 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te subword_embeds[subword_mask] = out_emb if self.use_subword_flag_emb: - subword_embeds = self.subword_flag_emb(subword_embeds, subword_ids) + subword_embeds = self.subword_flag_emb(subword_embeds, subword_ids) if self.use_bos_eos_emb: subword_embeds = self.bos_eos_emb(subword_embeds, subword_ids) @@ -1183,13 +1184,12 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te class GatedProjectedSumRMSNorm(nn.Module): - def __init__(self, audio_dim, text_dim, hidden_dim, - final_norm=True, num_codebooks=31, init_residual_scale=0.5): + def __init__(self, audio_dim, text_dim, hidden_dim, final_norm=True, num_codebooks=31, init_residual_scale=0.5): super().__init__() self.num_codebooks = num_codebooks self.audio_proj = nn.Linear(audio_dim, hidden_dim) - self.text_proj = nn.Linear(text_dim, hidden_dim) + self.text_proj = nn.Linear(text_dim, hidden_dim) nn.init.normal_(self.audio_proj.weight, mean=0.0, std=0.015) nn.init.zeros_(self.audio_proj.bias) @@ -1207,13 +1207,13 @@ def forward(self, audio_emb, text_emb): # projections run in model dtype (BF16) audio_h = self.audio_proj(audio_emb) - text_h = self.text_proj(text_emb) + text_h = self.text_proj(text_emb) dtype = audio_h.dtype with fp32_precision(): - gate = torch.sigmoid(self.gate) # FP32 - res = torch.sigmoid(self.residual_scale) # FP32 + gate = torch.sigmoid(self.gate) # FP32 + res = torch.sigmoid(self.residual_scale) # FP32 h = gate.to(dtype) * audio_h + (1 - gate).to(dtype) * text_h h = res.to(dtype) * h @@ -1221,6 +1221,7 @@ def forward(self, audio_emb, text_emb): return h + class RVQEARTTSModel(PreTrainedModel): """ The main RVQEARTTS model, which can be used for both training and inference. 
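For orientation, a minimal usage sketch of the GatedProjectedSumRMSNorm fusion added above (illustrative only: the feature sizes and batch/frame counts are assumptions, not values taken from the EAR-TTS configs, and the import assumes the module path introduced by this patch):

    import torch
    from nemo.collections.speechlm2.modules.rvq_ear_tts_model import GatedProjectedSumRMSNorm

    fusion = GatedProjectedSumRMSNorm(audio_dim=1024, text_dim=1024, hidden_dim=1024)
    audio_emb = torch.randn(2, 50, 1024)  # [batch, frames, audio_dim] -- assumed shapes
    text_emb = torch.randn(2, 50, 1024)   # [batch, frames, text_dim]
    fused = fusion(audio_emb, text_emb)   # sigmoid-gated mix of the two projections, residual-scaled

Note that the gate and residual scale are evaluated in FP32 and cast back to the activation dtype, so the learned mixing weights are not themselves quantized to BF16.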
@@ -1243,6 +1244,7 @@ def __init__(self, config: RVQEARTTSConfig | dict[str, Any]): if self.config.get("pretrained_text_name", None): # Load pretrained backbone from huggingface from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf + llm = load_pretrained_hf(self.config.pretrained_text_name, pretrained_weights=True).train() self.backbone = llm.model # fetch PretrainedBaseModel from model "ForCausalLM" else: @@ -1277,13 +1279,20 @@ def __init__(self, config: RVQEARTTSConfig | dict[str, Any]): ) self.embed_subword = ( - CharAwareSubwordEncoder(out_size=self.hidden_size, use_subword_flag_emb=self.config.use_subword_flag_emb, use_bos_eos_emb=self.config.use_bos_eos_emb, **self.config.cas_config) + CharAwareSubwordEncoder( + out_size=self.hidden_size, + use_subword_flag_emb=self.config.use_subword_flag_emb, + use_bos_eos_emb=self.config.use_bos_eos_emb, + **self.config.cas_config, + ) if self.config.cas_config else None ) if self.config.use_gated_fusion_for_text_audio: - self.gated_fusion_audio_text = GatedProjectedSumRMSNorm(self.hidden_size, self.hidden_size, self.hidden_size, self.config.num_quantizers) + self.gated_fusion_audio_text = GatedProjectedSumRMSNorm( + self.hidden_size, self.hidden_size, self.hidden_size, self.config.num_quantizers + ) # Prediction Heads if not self.config.disable_eos_prediction: @@ -1531,7 +1540,13 @@ def forward( uncond_dec_flag = torch.cat([uncond_dec_flag, torch.ones_like(uncond_dec_flag)], 0) # Prepare conditioning - cond = self._prepare_conditioning(context_hidden_state, subword_ids, subword_mask, uncond_dec_flag, asr_speech_tokens_emb=asr_speech_tokens_emb) + cond = self._prepare_conditioning( + context_hidden_state, + subword_ids, + subword_mask, + uncond_dec_flag, + asr_speech_tokens_emb=asr_speech_tokens_emb, + ) if self.config.use_gated_fusion_for_text_audio: inputs_embeds = self.gated_fusion_audio_text(code_embeds, cond) @@ -1566,7 +1581,9 @@ def forward( ) total_loss = lm_loss + c_loss + k_loss - return RVQEARTTSOutput(loss=total_loss, lm_loss=lm_loss, c_loss=c_loss, k_loss=k_loss, hidden_states=hidden_states) + return RVQEARTTSOutput( + loss=total_loss, lm_loss=lm_loss, c_loss=c_loss, k_loss=k_loss, hidden_states=hidden_states + ) else: # Inference if not generation_config: return RVQEARTTSOutput( @@ -1575,9 +1592,13 @@ def forward( ) else: if teacher_forcing_inference: - generated_codes, lm_logits, eos_flag = self.generate_teacher_forcing(hidden_states, generation_config) + generated_codes, lm_logits, eos_flag = self.generate_teacher_forcing( + hidden_states, generation_config + ) else: - generated_codes, lm_logits, eos_flag = self.generate_step(hidden_states, ignore_eos_flag_stop=ignore_eos_flag_stop, **generation_config) + generated_codes, lm_logits, eos_flag = self.generate_step( + hidden_states, ignore_eos_flag_stop=ignore_eos_flag_stop, **generation_config + ) return RVQEARTTSOutput( past_key_values=backbone_outputs.past_key_values, codes=generated_codes, @@ -1591,11 +1612,11 @@ def generate_teacher_forcing(self, hidden_states: Tensor, generation_config: dic """ Teacher-forcing wrapper for generate_step, processing all frames in parallel using a per-frame loop internally. - + Args: hidden_states: [B, T, H] hidden states generation_config: kwargs for self.generate_step() - + Returns: generated_codes: [B, T, ...] 
generated codes per frame lm_logits: [B, T, vocab_size] language model logits @@ -1615,8 +1636,7 @@ def generate_teacher_forcing(self, hidden_states: Tensor, generation_config: dic # call original generate_step generated_codes, lm_logits, eos_flag = self.generate_step( - frame_hidden.unsqueeze(1), # keep batch dim + frame dim - **generation_config + frame_hidden.unsqueeze(1), **generation_config # keep batch dim + frame dim ) if generated_codes is not None: # store in cache @@ -1627,8 +1647,8 @@ def generate_teacher_forcing(self, hidden_states: Tensor, generation_config: dic # Stack results along time dimension generated_codes = torch.stack(generated_codes_cache, dim=1) # [B, T, ...] if not self.config.disable_eos_prediction: - lm_logits = torch.stack(lm_logits_cache, dim=1) # [B, T, vocab_size] - eos_flag = torch.stack(eos_flag_cache, dim=1) # [B, T] + lm_logits = torch.stack(lm_logits_cache, dim=1) # [B, T, vocab_size] + eos_flag = torch.stack(eos_flag_cache, dim=1) # [B, T] else: lm_logits = None eos_flag = None @@ -1765,7 +1785,7 @@ def generate_step( code = depthsum_encoding_step(self.rvq_embs, z, code, cnt, k[0].item()) cnt += k[0].item() return code, lm_logits, eos_flag - + def load_state_dict(self, state_dict, strict: bool = True): try: super().load_state_dict(state_dict, strict=strict) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py index a9d67ec5ba8e..b3b2fe0cced7 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -2,6 +2,7 @@ import functools import math from collections.abc import Callable +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any, Concatenate, Literal, overload @@ -12,12 +13,8 @@ from torchaudio import functional as ta_F # Project -from nemo.collections.speechlm2.modules.ear_tts_commons import ( - Config, - PreTrainedModel -) +from nemo.collections.speechlm2.modules.ear_tts_commons import Config, PreTrainedModel -from contextlib import contextmanager @contextmanager def disable_tf32(): @@ -28,6 +25,7 @@ def disable_tf32(): finally: torch.backends.cudnn.allow_tf32 = prev + @dataclass class RVQVAEConfig(Config): model_type: str = "rvqvae" @@ -893,6 +891,7 @@ def forward(self, x: Tensor, cache=None, flush: bool = False, constrain_value_ra return x + class RVQVAEModel(PreTrainedModel): """ Residual Vector-Quantized Variational Autoencoder (RVQ-VAE) model. @@ -1076,7 +1075,7 @@ def forward(self, x: Tensor, constrain_value_range: bool = False) -> Tensor: Returns: Tensor: The reconstructed waveform. Shape: `[batch, 1, time]`. """ - + with torch.no_grad(): z_e = self.ae_encode(x) code = self.quantize(z_e) diff --git a/nemo/collections/speechlm2/parts/metrics/__init__.py b/nemo/collections/speechlm2/parts/metrics/__init__.py index 5952c1fb0b30..ca40c4cff5cd 100644 --- a/nemo/collections/speechlm2/parts/metrics/__init__.py +++ b/nemo/collections/speechlm2/parts/metrics/__init__.py @@ -13,8 +13,9 @@ # limitations under the License. 
from .asr_bleu import ASRBLEU from .bleu import BLEU -from .token_accuracy import TokenAccuracy from .results_logger import ResultsLogger +from .token_accuracy import TokenAccuracy + __all__ = [ 'ASRBLEU', 'BLEU', diff --git a/nemo/collections/speechlm2/parts/metrics/intelligibility.py b/nemo/collections/speechlm2/parts/metrics/intelligibility.py index 94e78b6f97c3..2598bb40a7be 100644 --- a/nemo/collections/speechlm2/parts/metrics/intelligibility.py +++ b/nemo/collections/speechlm2/parts/metrics/intelligibility.py @@ -16,14 +16,13 @@ import torch from whisper_normalizer.english import EnglishTextNormalizer +from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import ASRModel from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import load_pretrained_nemo from nemo.utils import logging -from nemo.collections.asr.metrics.wer import word_error_rate - class Intelligibility: """ @@ -31,7 +30,14 @@ class Intelligibility: By default, uses Whisper's EnglishTextNormalizer on hypotheses and references. """ - def __init__(self, pretrained_asr: str, normalize: bool = True, normalizer=None, verbose: bool = True, reuse_asr_hyps: bool = False) -> None: + def __init__( + self, + pretrained_asr: str, + normalize: bool = True, + normalizer=None, + verbose: bool = True, + reuse_asr_hyps: bool = False, + ) -> None: self.asr = None # load into memory on reset() self.pretrained_asr_name = pretrained_asr self.verbose = verbose @@ -59,7 +65,12 @@ def reset(self): return self def update( - self, name: str, refs: list[str], pred_audio: torch.Tensor, pred_audio_lens: torch.Tensor = None, asr_hyps: list[str] = None + self, + name: str, + refs: list[str], + pred_audio: torch.Tensor, + pred_audio_lens: torch.Tensor = None, + asr_hyps: list[str] = None, ) -> list[str]: if self.asr is None and not self.reuse_asr_hyps: self.reset() @@ -96,7 +107,7 @@ def compute(self) -> dict[str, torch.Tensor]: corpus_metric[f"wer_{name}"] = wer corpus_metric["cer"].append(cer) corpus_metric["wer"].append(wer) - + corpus_metric["cer"] = torch.stack(corpus_metric["cer"]).mean() corpus_metric["wer"] = torch.stack(corpus_metric["wer"]).mean() self._refs.clear() diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py index bef245555eb3..68b3a9bb780d 100644 --- a/nemo/collections/speechlm2/parts/metrics/results_logger.py +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -84,7 +84,7 @@ def merge_and_save_audio( ], dim=0, ).squeeze() - + else: combined_wav = pred_audio.unsqueeze(0).detach().cpu() @@ -119,15 +119,33 @@ def update( # save audio sample_id = samples_id[i][:150] # make sure that sample id is not too big out_audio_path = os.path.join(self.audio_save_path, f"{name}_{sample_id}.wav") - self.merge_and_save_audio(out_audio_path, pred_audio[i], pred_audio_sr, user_audio[i] if user_audio is not None else None, user_audio_sr) + self.merge_and_save_audio( + out_audio_path, + pred_audio[i], + pred_audio_sr, + user_audio[i] if user_audio is not None else None, + user_audio_sr, + ) if pred_audio_tf is not None: out_audio_path_tf = out_audio_path.replace(".wav", "_tf.wav") - self.merge_and_save_audio(out_audio_path_tf, pred_audio_tf[i], pred_audio_sr, user_audio[i] if user_audio is not None else None, user_audio_sr) + self.merge_and_save_audio( + 
out_audio_path_tf, + pred_audio_tf[i], + pred_audio_sr, + user_audio[i] if user_audio is not None else None, + user_audio_sr, + ) if target_audio is not None: out_audio_path_gt = out_audio_path.replace(".wav", "_GT.wav") - self.merge_and_save_audio(out_audio_path_gt, target_audio[i], pred_audio_sr, user_audio[i] if user_audio is not None else None, user_audio_sr) + self.merge_and_save_audio( + out_audio_path_gt, + target_audio[i], + pred_audio_sr, + user_audio[i] if user_audio is not None else None, + user_audio_sr, + ) # create a wav with eou prediction for debug purposes if eou_pred is not None: @@ -142,11 +160,15 @@ def update( if pre_audio_trimmed is not None: out_audio_path_trimmed = os.path.join(self.audio_save_path, f"{name}_{sample_id}_pred_trimmed.wav") - torchaudio.save(out_audio_path_trimmed, pre_audio_trimmed[i].squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr) + torchaudio.save( + out_audio_path_trimmed, pre_audio_trimmed[i].squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr + ) if reference_audio is not None: out_audio_path_ref = os.path.join(self.audio_save_path, f"{name}_{sample_id}_spk_reference.wav") - torchaudio.save(out_audio_path_ref, reference_audio[i].squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr) + torchaudio.save( + out_audio_path_ref, reference_audio[i].squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr + ) # cache metadata out_dict = { diff --git a/nemo/collections/speechlm2/parts/metrics/secs.py b/nemo/collections/speechlm2/parts/metrics/secs.py index a1a611b16fcd..fe300092d245 100644 --- a/nemo/collections/speechlm2/parts/metrics/secs.py +++ b/nemo/collections/speechlm2/parts/metrics/secs.py @@ -17,8 +17,7 @@ import torch from whisper_normalizer.english import EnglishTextNormalizer -from nemo.collections.asr.models import ASRModel -from nemo.collections.asr.models import EncDecSpeakerLabelModel +from nemo.collections.asr.models import ASRModel, EncDecSpeakerLabelModel from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import load_pretrained_nemo @@ -46,7 +45,12 @@ def reset(self): return self def update( - self, name: str, target_audio: torch.Tensor, target_audio_lens: torch.Tensor, pred_audio: torch.Tensor, pred_audio_lens: torch.Tensor + self, + name: str, + target_audio: torch.Tensor, + target_audio_lens: torch.Tensor, + pred_audio: torch.Tensor, + pred_audio_lens: torch.Tensor, ) -> None: if self.speaker_encoder is None: self.reset() @@ -57,7 +61,6 @@ def update( _, s_g = self.speaker_encoder(input_signal=pred_audio, input_signal_length=pred_audio_lens.long()) secs = torch.nn.functional.cosine_similarity(t_g, s_g, dim=-1).mean() - self._secs[name].append(secs) def compute(self) -> dict[str, torch.Tensor]: @@ -74,4 +77,3 @@ def compute(self) -> dict[str, torch.Tensor]: self.speaker_encoder = None # free up GPU memory torch.cuda.memory.empty_cache() return corpus_metric - diff --git a/nemo/collections/speechlm2/parts/metrics/token_accuracy.py b/nemo/collections/speechlm2/parts/metrics/token_accuracy.py index 231015d145d8..05b9ad192611 100644 --- a/nemo/collections/speechlm2/parts/metrics/token_accuracy.py +++ b/nemo/collections/speechlm2/parts/metrics/token_accuracy.py @@ -19,13 +19,13 @@ def compute_token_accuracy_with_tolerance(target, pred, token, tolerance=1): """ Computes the accuracy of `token` in `target` vs `pred` within a ±`tolerance` window. 
- + Args: target (torch.Tensor): Batch of target sequences (batch_size, seq_len) pred (torch.Tensor): Batch of predicted sequences (batch_size, seq_len) token (int): The token to compute accuracy for tolerance (int): Allowed index difference (window) for correct predictions - + Returns: float: Accuracy as correct / total occurrences of `token` in target """ @@ -85,4 +85,3 @@ def compute(self) -> dict[str, torch.Tensor]: corpus_metric[f"token_acc_{self.token_name}_{name}"] = metric self.scores.clear() return corpus_metric - diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index b56a2b455754..b8042718bd50 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -26,6 +26,7 @@ from nemo.collections.tts.models import AudioCodecModel from nemo.utils import logging + def load_pretrained_nemo(cls, model_path_or_name: str): """ Load pretrained NeMo 1.0 model (inheriting from ModelPT). Works with ASR, TTS, codec models. @@ -100,15 +101,16 @@ def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True else: model.perception = AudioPerceptionModule(model.cfg.perception).train() + def set_model_dict_for_partial_init(pretrained_dict, model_dict): # 1. filter out different size layers for k, v in list(pretrained_dict.items()): if k in model_dict and hasattr(model_dict[k], "numel") and v.numel() != model_dict[k].numel(): del pretrained_dict[k] - logging.info(" | > Layer with shape mismatach in the model definition: {}".format(k)) + logging.info(" | > Layer with shape mismatach in the model definition: {}".format(k)) # 2. filter out unnecessary keys pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 3. overwrite entries in the existing state dict model_dict.update(pretrained_dict) logging.info(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) - return model_dict \ No newline at end of file + return model_dict From 326afc3cf190a70a8ea919d1978e2a25db4b1adc Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 12 Nov 2025 07:04:20 -0800 Subject: [PATCH 004/102] Add set_init_inputs and get_init_input methods Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 11 +- .../speechlm2/models/duplex_ear_tts.py | 145 +++++++----------- 2 files changed, 64 insertions(+), 92 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index e1a207568a5f..18e37e0d1592 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -14,6 +14,10 @@ import logging import re +import io +import random +import numpy as np +import soundfile as sf import warnings from functools import partial from itertools import repeat @@ -25,6 +29,7 @@ from lhotse.array import Array, TemporalArray from lhotse.cut import Cut, MixedCut, PaddingCut from lhotse.serialization import load_yaml +from lhotse.utils import fastcopy from omegaconf import DictConfig, ListConfig, OmegaConf from nemo.collections.common.data.lhotse.nemo_adapters import ( @@ -44,7 +49,6 @@ ) from nemo.collections.common.parts.preprocessing.manifest import get_full_path - def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]: """ Reads NeMo configuration and creates a CutSet either from Lhotse or NeMo manifests. 
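As a concrete reference for the set_model_dict_for_partial_init helper added to parts/pretrained.py above, the call pattern used by the load_state_dict overrides in this series is, in sketch form (model is a placeholder nn.Module and state_dict a checkpoint state dictionary):

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Drop layers whose names or shapes no longer match, then restore the rest.
        model_dict = set_model_dict_for_partial_init(state_dict, model.state_dict())
        model.load_state_dict(model_dict, strict=False)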
@@ -595,6 +599,11 @@ def convert_overlap_cut(cut): @data_type_parser(["lhotse_magpietts_data_as_continuation"]) def read_lhotse_magpietts_data_as_continuation(config) -> tuple[CutSet, bool]: + def create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: + with io.BytesIO() as buffer: + sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') + buffer.seek(0) + return Recording.from_bytes(buffer.read(), recording_id=recording_id) def convert_lhotse_magpietts_data_as_cont(cut): # create a copy of agent supervision and original duration orig_agent_sup = fastcopy(cut.supervisions[0]) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index d39f73fb3f45..352f8fd2198d 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -42,7 +42,6 @@ ) from transformers import AutoModelForCausalLM, DynamicCache -from nemo.collections.asr.models import EncDecSpeakerLabelModel from nemo.collections.audio.parts.utils.resampling import resample from nemo.collections.common.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.common.tokenizers import AutoTokenizer @@ -576,17 +575,6 @@ def _load_language_model(self, cfg): language_model = None return language_model - def setup_speaker_encoder(self): - with fp32_precision(): - self.speaker_encoder = EncDecSpeakerLabelModel.from_pretrained(model_name=self.speaker_encoder_model_name) - - # freeze the pretrained speaker encoder - self.speaker_encoder.eval() - self.speaker_encoder.freeze() - - for p in self.speaker_encoder.parameters(): - p.requires_grad = False - def init_model_from_another_checkpoint(self, checkpoint_path): if checkpoint_path is not None: if '.nemo' in checkpoint_path: @@ -1009,11 +997,17 @@ def offline_inference_with_custom_sentences( ] ) - audio, audio_len = self.offline_inference( + # set init inputs and get it + self.set_init_inputs( speaker_audio=speaker_audio, speaker_audio_lens=speaker_audio_lens, + ) + init_inputs = self.get_init_inputs(B=inputs["subword_ids"].size(0)) + + audio, audio_len = self.offline_inference( next_subword_ids=next_subword_ids, guidance_enabled=self.cfg.get("inference_guidance_enabled", True), + init_inputs=init_inputs, ) return audio, audio_len, speaker_audio, speaker_audio_lens @@ -1021,8 +1015,6 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals results = {} inputs = self.prepare_inputs(dataset_batch) - # - # exit() # first evaluation, make the model bf16 safe if ( not self.model_16_precision_safe @@ -1057,6 +1049,13 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals init_inputs[key] = torch.stack( [init_inputs[key][i, :l] for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"])] ) + else: + # set init inputs and get it + self.set_init_inputs( + speaker_audio=dataset_batch["speaker_reference_audio"], + speaker_audio_lens=dataset_batch["speaker_reference_audio_lens"], + ) + init_inputs = self.get_init_inputs(B=inputs["subword_ids"].size(0)) # remove the prompt from the input_text_tokens to emulate S2S connected inference next_subword_ids = torch.stack( @@ -1066,23 +1065,10 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals ] ) - if self.cfg.get("use_asr_speech_tokens", False) and self.cfg.get("only_semantic_to_speech", False): - inp_asr_speech_tokens = torch.stack( - [ - 
inputs["target_asr_speech_tokens"][i, l:] # slice each element - for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) - ] - ) - else: - inp_asr_speech_tokens = None - results["audio"], results["audio_len"] = self.offline_inference( - speaker_audio=dataset_batch["speaker_reference_audio"], - speaker_audio_lens=dataset_batch["speaker_reference_audio_lens"], next_subword_ids=next_subword_ids, formatter=dataset_batch["formatter"][0], - inp_asr_speech_tokens=inp_asr_speech_tokens, - init_inputs=init_inputs if use_dataloader_init else None, + init_inputs=init_inputs, ) # remove prompt padding from the user audio as autoregressive inference does not return the prompt @@ -1348,22 +1334,9 @@ def get_system_prompt(self, system_prompt=None, user_prompt=None): input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).view(1, -1) return input_ids - def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, user_prompt=None): + def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, user_prompt=None): # compute prompt audio size and slice it with fp32_precision(): - """ - # old pad that can add long silences in the end - prompt_audio_size = int(((self.data_cfg.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) * self.target_samples_per_frame) - B, T = speaker_audio.shape # [batch, time] - if T >= prompt_audio_size: - # Just crop if longer - prompt_audio = speaker_audio[:, :prompt_audio_size] - else: - # Repeat along time until we have enough, then crop - repeat_factor = (prompt_audio_size + T - 1) // T # ceil division - expanded = speaker_audio.repeat(1, repeat_factor) - prompt_audio = expanded[:, :prompt_audio_size] - """ # compute the exact number of samples for the prompt duration prompt_audio_size = int( ((self.data_cfg.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) @@ -1491,22 +1464,55 @@ def get_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, "subword_mask": subword_mask.bool()[:, :-1], "non_prompt_mask": non_prompt_mask.bool()[:, :-1], } + # register to acess later + for k, v in init_inputs.items(): + name = f"init_input_{k}" + if v is not None: + self.register_buffer(name, v) return init_inputs + def get_init_inputs( + self, + B: int, + init_inputs_names=["code", "audio_mask", "context_hidden_state", "subword_ids", "subword_mask", "non_prompt_mask"], + ): + if init_inputs_names is None: + init_inputs_names = [ + "code", + "audio_mask", + "context_hidden_state", + "subword_ids", + "subword_mask", + "non_prompt_mask", + ] + + init_inputs = {} + for name in init_inputs_names: + buf_name = f"init_input_{name}" + buf = getattr(self, buf_name, None) + + if buf is None: + init_inputs[name] = None + continue + + # Use as-is if batch matches + if buf.shape[0] == B: + init_inputs[name] = buf + else: + # Otherwise, assume batch=1 and expand to target B + init_inputs[name] = buf[:1].expand(B, *buf.shape[1:]) + + return init_inputs + @torch.no_grad() def offline_inference( self, next_subword_ids: torch.Tensor, - speaker_audio: torch.Tensor, - speaker_audio_lens: torch.Tensor, formatter: str = "", - system_prompt: str = None, - user_prompt: str = None, guidance_enabled: bool = True, generation_config: dict = None, init_inputs: dict = None, - inp_asr_speech_tokens: torch.Tensor = None, ) -> dict[str, torch.Tensor]: """ Autoregressive prediction. 
@@ -1527,21 +1533,6 @@ def offline_inference( """ B = next_subword_ids.size(0) - # init_inputs, code, past_key_values = self.init_model_for_ar_inference(speaker_audio=speaker_audio, speaker_audio_lens=speaker_audio_lens, system_prompt=system_prompt, user_prompt=user_prompt, guidance_enabled=guidance_enabled, generation_config=generation_config) - - # ToDo: verify why codes differ from dataloader init_inputs when using nanocodec - if init_inputs is None: - init_inputs = self.get_init_inputs( - speaker_audio, speaker_audio_lens, system_prompt=system_prompt, user_prompt=user_prompt - ) - # compare_dicts(init_inputs_fn, init_inputs) - - if self.cfg.get("use_asr_speech_tokens", False) and self.cfg.get("only_semantic_to_speech", False): - # set mask to zero and subword ids to self.text_pad_id as in training - init_inputs["subword_mask"] = torch.full_like(init_inputs["subword_mask"], 0.0) - init_inputs["subword_ids"] = torch.full_like(init_inputs["subword_ids"], self.text_pad_id) - next_subword_ids = torch.full_like(next_subword_ids, self.text_pad_id) - if generation_config is None: generation_config = self._get_generation_config(guidance_enabled) logging.info(f"Doing inference using the following config: {generation_config} !") @@ -1559,22 +1550,6 @@ def offline_inference( past_key_values = outputs["past_key_values"] - # get current asr speech token - if self.cfg.get("use_asr_speech_tokens", False): - if self.cfg.get("only_semantic_to_speech", False): - cur_asr_speech_tokens = inp_asr_speech_tokens[:, 0].unsqueeze(-1) - else: - if guidance_enabled and self.cfg.get("asr_speech_tokens_use_guidance", True): - hidden_states, uncond_hidden_states = outputs.hidden_states.chunk(2, dim=0) - logits = self.asr_speech_tokens_head( - hidden_states + (generation_config["guidance_scale"] * (hidden_states - uncond_hidden_states)) - ) - else: - hidden_states, _ = outputs.hidden_states.chunk(2, dim=0) - logits = self.asr_speech_tokens_head(hidden_states) - - cur_asr_speech_tokens = logits.argmax(dim=-1)[:, -1].unsqueeze(-1) - # use the text tokens to stop generation max_steps = next_subword_ids.size(-1) # create variable to store the audios @@ -1627,18 +1602,6 @@ def offline_inference( # ToDo: check why it is -1 gen_audio_codes[:, i - 1] = code.squeeze(1) - if self.cfg.get("use_asr_speech_tokens", False) and not self.cfg.get("only_semantic_to_speech", False): - if guidance_enabled and self.cfg.get("asr_speech_tokens_use_guidance", True): - hidden_states, uncond_hidden_states = outputs.hidden_states.chunk(2, dim=0) - logits = self.asr_speech_tokens_head( - hidden_states + (generation_config["guidance_scale"] * (hidden_states - uncond_hidden_states)) - ) - else: - hidden_states, _ = outputs.hidden_states.chunk(2, dim=0) - logits = self.asr_speech_tokens_head(hidden_states) - - cur_asr_speech_tokens = logits.argmax(dim=-1)[:, -1].unsqueeze(-1) - # force silence as next token if self.cfg.get('inference_force_speech_silence_on_eos', True): silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(code.shape) From 34e0413c14bdce9b38cc921fd11e2b7fd3a99f99 Mon Sep 17 00:00:00 2001 From: Edresson Date: Wed, 12 Nov 2025 15:05:05 +0000 Subject: [PATCH 005/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 10 ++++++---- nemo/collections/speechlm2/models/duplex_ear_tts.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py 
b/nemo/collections/common/data/lhotse/cutset.py index 18e37e0d1592..b79e86e604b8 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -import re import io +import logging import random -import numpy as np -import soundfile as sf +import re import warnings from functools import partial from itertools import repeat from pathlib import Path from typing import KeysView, Mapping, Sequence, Tuple, Union +import numpy as np import omegaconf +import soundfile as sf from lhotse import CutSet, Features, MonoCut, Recording, SupervisionSegment from lhotse.array import Array, TemporalArray from lhotse.cut import Cut, MixedCut, PaddingCut @@ -49,6 +49,7 @@ ) from nemo.collections.common.parts.preprocessing.manifest import get_full_path + def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]: """ Reads NeMo configuration and creates a CutSet either from Lhotse or NeMo manifests. @@ -604,6 +605,7 @@ def create_recording_from_array(samples: np.ndarray, sampling_rate: int, recordi sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') buffer.seek(0) return Recording.from_bytes(buffer.read(), recording_id=recording_id) + def convert_lhotse_magpietts_data_as_cont(cut): # create a copy of agent supervision and original duration orig_agent_sup = fastcopy(cut.supervisions[0]) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 352f8fd2198d..f66c1a6ee8ee 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -1475,7 +1475,14 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, def get_init_inputs( self, B: int, - init_inputs_names=["code", "audio_mask", "context_hidden_state", "subword_ids", "subword_mask", "non_prompt_mask"], + init_inputs_names=[ + "code", + "audio_mask", + "context_hidden_state", + "subword_ids", + "subword_mask", + "non_prompt_mask", + ], ): if init_inputs_names is None: init_inputs_names = [ @@ -1504,7 +1511,7 @@ def get_init_inputs( init_inputs[name] = buf[:1].expand(B, *buf.shape[1:]) return init_inputs - + @torch.no_grad() def offline_inference( self, From ec72055095c68dbd8c828ccb8ed44e8902394b06 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 05:43:58 -0800 Subject: [PATCH 006/102] Add from config codec instanciation Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 10 ++++++++++ .../speechlm2/models/duplex_ear_tts.py | 17 ++++++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index b79e86e604b8..b465baf64b01 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -715,6 +715,12 @@ def filter_secs(example): else: return True + def filter_target_speaker(example): + if isinstance(example, Cut) and len(example.supervisions) > 0 and TARGET_SPEAKER is not None: + return TARGET_SPEAKER in example.supervisions[0].speaker + else: + return True + # load lhotse cuts cuts, is_tarred = read_cutset_from_config(config) @@ -733,6 +739,10 @@ def filter_secs(example): # filter based on context speaker similarity MIN_SECS = 
config.get("min_context_speaker_similarity", 0.6) cuts = cuts.filter(filter_secs) + + # filter speaker + TARGET_SPEAKER = config.get("target_speaker", None) + cuts = cuts.filter(filter_target_speaker) # convert cuts cuts = cuts.map(convert_lhotse_magpietts_data_as_cont) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index f66c1a6ee8ee..78fb04ff4e64 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -49,7 +49,7 @@ from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSConfig, RVQEARTTSModel -from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEModel +from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEConfig, RVQVAEModel from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin from nemo.collections.speechlm2.parts.lora import maybe_install_lora from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU @@ -84,7 +84,6 @@ def maybe_to(x, dtype): import torch - @contextmanager def ensures_16_precision(mixed_dtype): """ @@ -340,10 +339,18 @@ def setup_rvq_audio_codec(model): """ if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == torch.float: return # skip if already set up and has the right dtype + with fp32_precision(): - model.audio_codec = ( - RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, strict=False).eval().to(model.device) - ) + if model.cfg.get("pretrained_ae_dir", None): + model.audio_codec = ( + RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, strict=False).eval().to(model.device) + ) + else: + # init codec from config + model.audio_codec = ( + RVQVAEModel(RVQVAEConfig(**model.cfg.codec_config)) + ) + for p in model.audio_codec.parameters(): p.requires_grad = False From 3df84cc63a69d5d7a65ef3867bd7d164f6125619 Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 13 Nov 2025 13:45:12 +0000 Subject: [PATCH 007/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 2 +- nemo/collections/speechlm2/models/duplex_ear_tts.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index b465baf64b01..9b40974106b9 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -739,7 +739,7 @@ def filter_target_speaker(example): # filter based on context speaker similarity MIN_SECS = config.get("min_context_speaker_similarity", 0.6) cuts = cuts.filter(filter_secs) - + # filter speaker TARGET_SPEAKER = config.get("target_speaker", None) cuts = cuts.filter(filter_target_speaker) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 78fb04ff4e64..cb374f2fbed6 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -84,6 +84,7 @@ def maybe_to(x, dtype): import torch + @contextmanager def ensures_16_precision(mixed_dtype): """ @@ -347,9 +348,7 @@ def setup_rvq_audio_codec(model): ) else: # init codec from config - model.audio_codec = ( - RVQVAEModel(RVQVAEConfig(**model.cfg.codec_config)) - ) + 
model.audio_codec = RVQVAEModel(RVQVAEConfig(**model.cfg.codec_config)) for p in model.audio_codec.parameters(): p.requires_grad = False From 81513bf3f89f5c1353604be53abe1b191890e62d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 05:50:24 -0800 Subject: [PATCH 008/102] Remove unused imports Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 4 +--- .../speechlm2/models/duplex_ear_tts.py | 23 ++++--------------- .../speechlm2/modules/ear_tts_commons.py | 1 - .../speechlm2/modules/rvq_ear_tts_model.py | 11 ++------- .../parts/metrics/intelligibility.py | 1 - .../speechlm2/parts/metrics/results_logger.py | 1 - .../speechlm2/parts/metrics/secs.py | 7 +----- .../speechlm2/parts/metrics/token_accuracy.py | 1 - 8 files changed, 8 insertions(+), 41 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 0ba55afc7a88..d2a95aad3dd9 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -17,8 +17,7 @@ import torch import torch.nn.functional as F import torch.utils.data -import torchaudio -from lhotse import CutSet, MonoCut, Recording, Seconds, SupervisionSegment, compute_num_frames +from lhotse import CutSet, Seconds, compute_num_frames from lhotse.cut import Cut from lhotse.dataset.collation import collate_audio, collate_vectors from lhotse.utils import ifnone @@ -410,7 +409,6 @@ def __getitem__(self, cuts: CutSet) -> dict: # create a full zero desc mask desc_mask = torch.zeros_like(non_prompt_mask) - batch_size = len(target_token_lens) max_len = max(target_token_lens) # Segment IDs per sequence (padded) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index cb374f2fbed6..8b1d8641e7f7 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -13,23 +13,17 @@ # limitations under the License. 
import copy import glob -import math import os -import random import tempfile import time -from types import SimpleNamespace -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F import torchaudio from lightning import LightningModule -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from peft import PeftModel -from torch import Tensor, nn from torch.distributed.fsdp import fully_shard from torch.distributed.tensor import Replicate, Shard from torch.distributed.tensor.parallel import ( @@ -40,33 +34,25 @@ loss_parallel, parallelize_module, ) -from transformers import AutoModelForCausalLM, DynamicCache from nemo.collections.audio.parts.utils.resampling import resample from nemo.collections.common.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.speechlm2.data.utils import get_pad_id -from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSConfig, RVQEARTTSModel from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEConfig, RVQVAEModel from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin -from nemo.collections.speechlm2.parts.lora import maybe_install_lora from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU -from nemo.collections.speechlm2.parts.metrics.bleu import BLEU from nemo.collections.speechlm2.parts.metrics.intelligibility import Intelligibility from nemo.collections.speechlm2.parts.metrics.results_logger import ResultsLogger from nemo.collections.speechlm2.parts.metrics.secs import SECS -from nemo.collections.speechlm2.parts.metrics.token_accuracy import TokenAccuracy from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import ( load_pretrained_hf, set_model_dict_for_partial_init, - setup_speech_encoder, ) -from nemo.collections.tts.modules import transformer_2501 -from nemo.core.classes.module import NeuralModule from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging @@ -82,7 +68,6 @@ def maybe_to(x, dtype): from collections import Counter from contextlib import contextmanager -import torch @contextmanager @@ -1008,7 +993,7 @@ def offline_inference_with_custom_sentences( speaker_audio=speaker_audio, speaker_audio_lens=speaker_audio_lens, ) - init_inputs = self.get_init_inputs(B=inputs["subword_ids"].size(0)) + init_inputs = self.get_init_inputs(B=next_subword_ids.size(0)) audio, audio_len = self.offline_inference( next_subword_ids=next_subword_ids, @@ -1769,7 +1754,7 @@ def configure_model(self) -> None: def load_state_dict(self, state_dict, strict: bool = True): try: return super().load_state_dict(state_dict, strict=strict) - except RuntimeError as e: - logging.info(f"Error loading model state_dict !! Retrying with partial initialization!") + except RuntimeError: + logging.info("Error loading model state_dict !! 
Retrying with partial initialization!") model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) return super().load_state_dict(model_dict, strict=False) diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index cddebcc35344..95883159edda 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -11,7 +11,6 @@ from collections.abc import Mapping, MutableMapping from typing import Any -import torch from safetensors import safe_open from torch import nn diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index 28ceebbc1341..473536888ffc 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -1,20 +1,13 @@ # Standard library -import argparse -import glob import json import math import os -import re -import shutil -import sys import unicodedata -from collections.abc import Mapping, MutableMapping from dataclasses import dataclass, field, fields from typing import Any import torch import transformers -from safetensors import safe_open from torch import Tensor, nn from torch.nn import functional as F from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, AutoTokenizer, Cache @@ -1789,7 +1782,7 @@ def generate_step( def load_state_dict(self, state_dict, strict: bool = True): try: super().load_state_dict(state_dict, strict=strict) - except RuntimeError as e: - logging.info(f"Error loading model state_dict !! Retrying with partial initialization!") + except RuntimeError: + logging.info("Error loading model state_dict !! Retrying with partial initialization!") model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) super().load_state_dict(model_dict, strict=False) diff --git a/nemo/collections/speechlm2/parts/metrics/intelligibility.py b/nemo/collections/speechlm2/parts/metrics/intelligibility.py index 2598bb40a7be..8e00f71b6625 100644 --- a/nemo/collections/speechlm2/parts/metrics/intelligibility.py +++ b/nemo/collections/speechlm2/parts/metrics/intelligibility.py @@ -21,7 +21,6 @@ from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import load_pretrained_nemo -from nemo.utils import logging class Intelligibility: diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py index 68b3a9bb780d..65b073742845 100644 --- a/nemo/collections/speechlm2/parts/metrics/results_logger.py +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -14,7 +14,6 @@ import json import os import shutil -from collections import defaultdict import torch import torchaudio diff --git a/nemo/collections/speechlm2/parts/metrics/secs.py b/nemo/collections/speechlm2/parts/metrics/secs.py index fe300092d245..dfb9379ae2a2 100644 --- a/nemo/collections/speechlm2/parts/metrics/secs.py +++ b/nemo/collections/speechlm2/parts/metrics/secs.py @@ -13,15 +13,10 @@ # limitations under the License. 
from collections import defaultdict -import sacrebleu import torch -from whisper_normalizer.english import EnglishTextNormalizer -from nemo.collections.asr.models import ASRModel, EncDecSpeakerLabelModel -from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs +from nemo.collections.asr.models import EncDecSpeakerLabelModel from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.speechlm2.parts.pretrained import load_pretrained_nemo -from nemo.utils import logging class SECS: diff --git a/nemo/collections/speechlm2/parts/metrics/token_accuracy.py b/nemo/collections/speechlm2/parts/metrics/token_accuracy.py index 05b9ad192611..a10300dd5d14 100644 --- a/nemo/collections/speechlm2/parts/metrics/token_accuracy.py +++ b/nemo/collections/speechlm2/parts/metrics/token_accuracy.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import defaultdict import torch -from nemo.utils import logging def compute_token_accuracy_with_tolerance(target, pred, token, tolerance=1): From fc47181259ad260c9e9d66045c54a54d81530a74 Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 13 Nov 2025 13:51:20 +0000 Subject: [PATCH 009/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 8b1d8641e7f7..a2ba0cb63102 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -49,10 +49,7 @@ from nemo.collections.speechlm2.parts.metrics.secs import SECS from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.speechlm2.parts.pretrained import ( - load_pretrained_hf, - set_model_dict_for_partial_init, -) +from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf, set_model_dict_for_partial_init from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging @@ -69,7 +66,6 @@ def maybe_to(x, dtype): from contextlib import contextmanager - @contextmanager def ensures_16_precision(mixed_dtype): """ From fd9893fe938ea0a5e4822897c575e58dcf271937 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 06:06:49 -0800 Subject: [PATCH 010/102] Fix pylint issues Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 12 +- .../speechlm2/models/duplex_ear_tts.py | 29 ++-- .../speechlm2/modules/rvq_ear_tts_model.py | 137 ------------------ .../speechlm2/parts/metrics/results_logger.py | 5 +- 4 files changed, 17 insertions(+), 166 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index d2a95aad3dd9..793af3f294e7 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -414,8 +414,10 @@ def __getitem__(self, cuts: CutSet) -> dict: # Segment IDs per sequence (padded) aligned_segment_ids = torch.stack( [ - torch.nn.functional.pad(torch.full((l,), i), (0, max_len - l), value=-1) # -1 for padding - for i, l in enumerate(target_token_lens) + torch.nn.functional.pad( + torch.full((seq_len,), i), (0, max_len - 
seq_len), value=-1 + ) # -1 for padding + for i, seq_len in enumerate(target_token_lens) ], dim=0, ) # [B, max_len] @@ -429,13 +431,13 @@ def __getitem__(self, cuts: CutSet) -> dict: aligned_attention_mask = aligned_attention_mask.unsqueeze(1) # [B, 1, max_len, max_len] - # create pos ids from the aligned lenght + # create position IDs from the aligned length aligned_position_ids = torch.stack( [ torch.nn.functional.pad( - torch.arange(l), (0, max(target_token_lens) - l), value=0 + torch.arange(seq_len), (0, max(target_token_lens) - seq_len), value=0 ) # value=0 is safe for padding - for l in target_token_lens + for seq_len in target_token_lens ], dim=0, ) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index a2ba0cb63102..dde78cb3db06 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -53,6 +53,9 @@ from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging +from collections import Counter +from contextlib import contextmanager + def maybe_to(x, dtype): if x is None: @@ -62,10 +65,6 @@ def maybe_to(x, dtype): return x -from collections import Counter -from contextlib import contextmanager - - @contextmanager def ensures_16_precision(mixed_dtype): """ @@ -126,7 +125,6 @@ def hook(_, __, out): bf16_layers, fp32_layers = [], [] all_modules = list(model_patched.named_modules()) - num_modules = len(all_modules) # flag to propagate FP32 to next safe layers propagate_fp32 = False @@ -460,12 +458,9 @@ def __init__(self, cfg: dict) -> None: self.codec_silence_tokens = self.get_codec_silence_frame() # Load tokenizer - if self.cfg.get("use_word_sep_tokenizer", False): - self.tokenizer = WordSepTokenizer(self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True) - else: - self.tokenizer = AutoTokenizer( - self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True - ) # Note that we are using fast tokenizer + self.tokenizer = AutoTokenizer( + self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True + ) # Note that we are using fast tokenizer if 'Qwen2.5' in self.cfg.pretrained_lm_name: # For Qwen, '<|im_start|>' is a common choice for a BOS token. 
@@ -812,7 +807,6 @@ def training_step(self, batch: dict, batch_idx: int): non_prompt_mask=inputs["non_prompt_mask"], ) loss_dict = {"lm_loss": tts_output.lm_loss, "c_loss": tts_output.c_loss, "k_loss": tts_output.k_loss} - backbone_out = tts_output.hidden_states loss = sum(loss_dict.values()) num_frames = inputs["output_lens"].sum() @@ -867,11 +861,9 @@ def log_model_stats(self): # L2 norms weight_l2 = (total_w_sq**0.5) if total_w_sq > 0 else 0.0 - grad_l2 = (total_g_sq**0.5) if total_g_sq > 0 else 0.0 # RMS (global) weight_rms = ((total_w_sq / total_w_params) ** 0.5) if total_w_params > 0 else 0.0 - grad_rms = ((total_g_sq / total_g_params) ** 0.5) if total_g_params > 0 else 0.0 # Mean weight_mean = sum_w / total_w_params if total_w_params > 0 else 0.0 @@ -882,9 +874,6 @@ def log_model_stats(self): self.log("weights/max_abs", max_abs_w, on_epoch=True, sync_dist=True) self.log("weights/mean", weight_mean, on_epoch=True, sync_dist=True) - # ignore the grads stats for now - # self.log("grads/L2", grad_l2, on_epoch=True, sync_dist=True) - # self.log("grads/RMS", grad_rms, on_epoch=True, sync_dist=True) def on_validation_epoch_start(self) -> None: setup_audio_codec(self) @@ -1034,7 +1023,7 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals for key in init_inputs: if init_inputs[key] is not None: init_inputs[key] = torch.stack( - [init_inputs[key][i, :l] for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"])] + [init_inputs[key][i, :plen] for i, plen in enumerate(dataset_batch["desc_plus_audio_prompt_lens"])] ) else: # set init inputs and get it @@ -1047,8 +1036,8 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals # remove the prompt from the input_text_tokens to emulate S2S connected inference next_subword_ids = torch.stack( [ - inputs["subword_ids"][i, l:] # slice each element - for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + inputs["subword_ids"][i, plen:] # slice each element + for i, plen in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) ] ) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index 473536888ffc..35e160a2cf57 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -2,7 +2,6 @@ import json import math import os -import unicodedata from dataclasses import dataclass, field, fields from typing import Any @@ -17,7 +16,6 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.utils import logging - # ============================================================================== # MLP module and Norm # ============================================================================== @@ -487,11 +485,8 @@ def _build_char_vocab() -> dict[str, int]: # 1. 
Load or build the character vocabulary if vocab_dir: from filelock import FileLock - char_vocab_file = os.path.join(vocab_dir, "char_vocab.json") - os.makedirs(vocab_dir, exist_ok=True) - with FileLock(char_vocab_file + ".lock", timeout=60): if not os.path.exists(char_vocab_file): char_vocab = _build_char_vocab() @@ -524,138 +519,6 @@ def _build_char_vocab() -> dict[str, int]: return subword_id_to_char_ids, char_vocab, subword_padding_idx -def _split_ipa_symbols(text: str) -> list[str]: - """ - Split IPA text into grapheme clusters (true phoneme symbols) - without using regex. Combines base characters with diacritics. - """ - phonemes = [] - cluster = "" - for char in text: - if unicodedata.combining(char) == 0: - # Start a new cluster - if cluster: - phonemes.append(cluster) - cluster = char - else: - # Diacritic, append to current cluster - cluster += char - if cluster: - phonemes.append(cluster) - return phonemes - - -def build_phoneme_vocabs( - pretrained_tokenizer_name: str, - vocab_dir: str | None = None, - language: str = "en-us", -) -> tuple[dict[int, tuple[int, ...]], dict[str, int], int]: - """ - Build or load a phoneme-level vocabulary derived from a subword tokenizer, - using phonemizer with espeak-ng backend and IPA transcription. - - Args: - pretrained_tokenizer_name (str): Hugging Face tokenizer name or path. - vocab_dir (str | None, optional): Directory for saving/loading vocab. - language (str, optional): Language code for phonemizer (default: "en-us"). - - Returns: - tuple: - - subword_id_to_phoneme_ids: dict[int, tuple[int, ...]] - - phoneme_vocab: dict[str, int] - - subword_padding_idx: int - """ - from phonemizer import phonemize - - tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name) - - def _phonemize_all_subwords() -> dict[str, list[str]]: - """Phonemize all subwords once and return mapping {subword → [IPA phonemes]}.""" - subwords = list(tokenizer.vocab.keys()) - try: - phoneme_strings = phonemize( - subwords, - language=language, - backend="espeak", # use espeak-ng - strip=True, - njobs=1, - preserve_punctuation=True, - with_stress=True, - ) - # split each string into grapheme clusters (IPA symbols) - phoneme_lists = [_split_ipa_symbols(s) for s in phoneme_strings] - return {sw: phs for sw, phs in zip(subwords, phoneme_lists) if phs} - except Exception as e: - logging.error(f"[PHONEME-VOCAB] Failed to phonemize subwords: {e}") - return {} - - def _build_phoneme_vocab(subword_to_phonemes: dict[str, list[str]]) -> dict[str, int]: - phoneme_set = {p for phs in subword_to_phonemes.values() for p in phs} - sorted_phonemes = sorted(phoneme_set) - return {p: i for i, p in enumerate(sorted_phonemes)} - - # --- Load or build vocab --- - vocab_file_name = "phoneme_vocab.json" - if vocab_dir: - os.makedirs(vocab_dir, exist_ok=True) - vocab_file = os.path.join(vocab_dir, vocab_file_name) - - with FileLock(vocab_file + ".lock", timeout=60): - if not os.path.exists(vocab_file): - subword_to_phonemes = _phonemize_all_subwords() - phoneme_vocab = _build_phoneme_vocab(subword_to_phonemes) - cache = {"phoneme_vocab": phoneme_vocab, "subword_to_phonemes": subword_to_phonemes} - logging.info(f"[PHONEME-VOCAB] Saving → {vocab_file}") - with open(vocab_file, "w", encoding="utf-8") as f: - json.dump(cache, f, ensure_ascii=False, indent=2) - - logging.info(f"[PHONEME-VOCAB] Loading from {vocab_file}") - with open(vocab_file, encoding="utf-8") as f: - cache = json.load(f) - phoneme_vocab = cache["phoneme_vocab"] - subword_to_phonemes = cache["subword_to_phonemes"] - else: 
- logging.info(f"[PHONEME-VOCAB] Building from tokenizer '{pretrained_tokenizer_name}'") - subword_to_phonemes = _phonemize_all_subwords() - phoneme_vocab = _build_phoneme_vocab(subword_to_phonemes) - - # --- Build subword → phoneme ID mapping --- - subword_id_to_phoneme_ids = {} - for subword, subword_id in tokenizer.vocab.items(): - phonemes = subword_to_phonemes.get(subword, []) - phoneme_ids = [phoneme_vocab[p] for p in phonemes if p in phoneme_vocab] - if phoneme_ids: - subword_id_to_phoneme_ids[subword_id] = tuple(phoneme_ids) - - # Define a padding index for subwords - subword_padding_idx = len(tokenizer.vocab) - # The padding subword maps to a new phoneme padding ID - subword_id_to_phoneme_ids[subword_padding_idx] = (len(phoneme_vocab),) - - return subword_id_to_phoneme_ids, phoneme_vocab, subword_padding_idx - - -"""@torch.compile -def depthsum_encoding_step( - embs: Tensor, - r: Tensor, - code: Tensor, - depth_str: int = 0, - k: int = 72, -) -> Tensor: - for i in range(depth_str, depth_str + k): - idx_sel = ( - embs[i].pow(2).sum(-1) # [g?, v] - - 2 - * (r.unsqueeze(-2) @ embs[i].transpose(-1, -2)).squeeze(-2) # [b, ?, g?, h] , [g?, h, v] -> [b, ?, g?, v] - ).argmin(-1) - emb_i = F.embedding(idx_sel, embs[i]) - r = r - emb_i - code[..., i : i + 1] = idx_sel - return code -""" - - @torch.compile def depthsum_encoding_step( embs: Tensor, diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py index 65b073742845..4d083ac22345 100644 --- a/nemo/collections/speechlm2/parts/metrics/results_logger.py +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -22,10 +22,7 @@ def safe_remove_path(path): - try: - shutil.rmtree(path) - except: - pass # File was already deleted by another thread + shutil.rmtree(path, ignore_errors=True) class ResultsLogger: From d4b61f0ba08f620a6f8168fc50332930a8ee9477 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 06:27:42 -0800 Subject: [PATCH 011/102] Fix code scanning issues Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 186 +++++++++--------- .../speechlm2/models/duplex_ear_tts.py | 4 - 2 files changed, 92 insertions(+), 98 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 793af3f294e7..3f7a4469235f 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -173,7 +173,6 @@ def __init__( target_sample_rate: int, input_roles: list[str] = None, output_roles: list[str] = None, - add_description: bool = True, p_drop_description: float = 0.0, add_text_bos_and_eos_in_each_turn: bool = False, add_audio_prompt_after_description: bool = False, @@ -186,7 +185,6 @@ def __init__( self.target_sample_rate = target_sample_rate self.input_roles = set(ifnone(input_roles, ["user"])) self.output_roles = set(ifnone(output_roles, ["agent"])) - self.add_description = add_description self.p_drop_description = p_drop_description self.add_text_bos_and_eos_in_each_turn = add_text_bos_and_eos_in_each_turn self.add_audio_prompt_after_description = add_audio_prompt_after_description @@ -290,99 +288,99 @@ def __getitem__(self, cuts: CutSet) -> dict: source_audio = F.pad(source_audio, (0, extra_frames)) source_audio_lens = source_audio_lens + extra_frames - if self.add_description: - text_pad_id = get_pad_id(self.tokenizer) - input_text_tokens_ = [] - 
source_tokens_ = [] - source_audio_ = [] - target_audio_ = [] - desc_lens = [] - desc_plus_audio_prompt_lens = [] - # for each sample in the batch - for i in range(input_text_tokens.size(0)): - desc_tokens_ids = self.generate_prompt_description(device=input_text_tokens[i].device).squeeze(0) - if self.add_audio_prompt_after_description: - prompt_audio_size = int( - ((self.audio_prompt_duration * self.target_sample_rate) // target_samples_per_frame) - * target_samples_per_frame - ) - prompt_audio = sample_audio_segments_repeat( - speaker_reference_audio, speaker_reference_audio_lens, prompt_audio_size, sample=True - ) - # add a silence in the end to smooth the transition between prompt and audio tokens, keep one extra pad token due shift on subword_ids - prompt_audio[:, -int(target_samples_per_frame * 2) :] = 0 - - # create tensor to pad text channels with the same amount of frames added in audio channel (audio prompt) - prompt_audio_text_pad_size = prompt_audio_size // target_samples_per_frame - prompt_audio_text_pad = ( - torch.ones( - prompt_audio_text_pad_size, device=input_text_tokens.device, dtype=input_text_tokens.dtype - ) - * text_pad_id - ) - # set last prompt frame with eos in text channel - prompt_audio_text_pad[-1] = self.tokenizer.eos - - # Add eos to simulate the end of a turn as in EAR-TTS inference - desc_tokens_ids = torch.cat( - [ - desc_tokens_ids, - torch.tensor( - [self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device - ), - ] - ) - # Add padding equivalent to the audio prompt size in number of tokens - new_input_text_tokens = torch.cat( - [ - desc_tokens_ids.to(input_text_tokens.dtype), - prompt_audio_text_pad.to(input_text_tokens.dtype), - input_text_tokens[i], - ] + + text_pad_id = get_pad_id(self.tokenizer) + input_text_tokens_ = [] + source_tokens_ = [] + source_audio_ = [] + target_audio_ = [] + desc_lens = [] + desc_plus_audio_prompt_lens = [] + # for each sample in the batch + for i in range(input_text_tokens.size(0)): + desc_tokens_ids = self.generate_prompt_description(device=input_text_tokens[i].device).squeeze(0) + if self.add_audio_prompt_after_description: + prompt_audio_size = int( + ((self.audio_prompt_duration * self.target_sample_rate) // target_samples_per_frame) + * target_samples_per_frame + ) + prompt_audio = sample_audio_segments_repeat( + speaker_reference_audio, speaker_reference_audio_lens, prompt_audio_size, sample=True + ) + # add a silence in the end to smooth the transition between prompt and audio tokens, keep one extra pad token due shift on subword_ids + prompt_audio[:, -int(target_samples_per_frame * 2) :] = 0 + + # create tensor to pad text channels with the same amount of frames added in audio channel (audio prompt) + prompt_audio_text_pad_size = prompt_audio_size // target_samples_per_frame + prompt_audio_text_pad = ( + torch.ones( + prompt_audio_text_pad_size, device=input_text_tokens.device, dtype=input_text_tokens.dtype ) - # append to list and update lens - input_text_tokens_.append(new_input_text_tokens) - target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size - - # add description to source text tokens - source_tokens_.append(torch.cat([desc_tokens_ids, prompt_audio_text_pad, source_tokens[i]])) - source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size - # add silence in the source audio while the prompt is being processed - pad_size = (len(desc_tokens_ids) * source_samples_per_frame) + prompt_audio.size(1) - pad_audio = 
torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) - source_audio_.append(torch.cat([pad_audio, source_audio[i]])) - source_audio_lens[i] = source_audio_lens[i] + pad_size - # add silence in the target audio while the prompt is being processed - pad_size = len(desc_tokens_ids) * target_samples_per_frame - pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) - target_audio_.append(torch.cat([pad_audio, prompt_audio[i], target_audio[i]])) - target_audio_lens[i] = target_audio_lens[i] + pad_size + prompt_audio.size(1) - # desc duration - desc_lens.append(len(desc_tokens_ids)) - desc_plus_audio_prompt_lens.append( - len(desc_tokens_ids) + prompt_audio_text_pad_size - 1 - ) # -1 due the shift done in subword_ids - else: - # add description to target text tokens - input_text_tokens_.append(torch.cat([desc_tokens_ids, input_text_tokens[i]])) - target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) - # add description to source text tokens - source_tokens_.append(torch.cat([desc_tokens_ids, source_tokens[i]])) - source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) - # add silence in the source audio while the prompt is being processed - pad_size = len(desc_tokens_ids) * source_samples_per_frame - pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) - source_audio_.append(torch.cat([pad_audio, source_audio[i]])) - source_audio_lens[i] = source_audio_lens[i] + pad_size - # add silence in the target audio while the prompt is being processed - pad_size = len(desc_tokens_ids) * target_samples_per_frame - pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) - target_audio_.append(torch.cat([pad_audio, target_audio[i]])) - target_audio_lens[i] = target_audio_lens[i] + pad_size - - # des duration - desc_lens.append(len(desc_tokens_ids)) - desc_plus_audio_prompt_lens.append(len(desc_tokens_ids)) + * text_pad_id + ) + # set last prompt frame with eos in text channel + prompt_audio_text_pad[-1] = self.tokenizer.eos + + # Add eos to simulate the end of a turn as in EAR-TTS inference + desc_tokens_ids = torch.cat( + [ + desc_tokens_ids, + torch.tensor( + [self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device + ), + ] + ) + # Add padding equivalent to the audio prompt size in number of tokens + new_input_text_tokens = torch.cat( + [ + desc_tokens_ids.to(input_text_tokens.dtype), + prompt_audio_text_pad.to(input_text_tokens.dtype), + input_text_tokens[i], + ] + ) + # append to list and update lens + input_text_tokens_.append(new_input_text_tokens) + target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size + + # add description to source text tokens + source_tokens_.append(torch.cat([desc_tokens_ids, prompt_audio_text_pad, source_tokens[i]])) + source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size + # add silence in the source audio while the prompt is being processed + pad_size = (len(desc_tokens_ids) * source_samples_per_frame) + prompt_audio.size(1) + pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) + source_audio_.append(torch.cat([pad_audio, source_audio[i]])) + source_audio_lens[i] = source_audio_lens[i] + pad_size + # add silence in the target audio while the prompt is being processed + pad_size = len(desc_tokens_ids) * target_samples_per_frame + pad_audio = torch.zeros(pad_size, device=target_audio.device, 
dtype=target_audio.dtype) + target_audio_.append(torch.cat([pad_audio, prompt_audio[i], target_audio[i]])) + target_audio_lens[i] = target_audio_lens[i] + pad_size + prompt_audio.size(1) + # desc duration + desc_lens.append(len(desc_tokens_ids)) + desc_plus_audio_prompt_lens.append( + len(desc_tokens_ids) + prompt_audio_text_pad_size - 1 + ) # -1 due the shift done in subword_ids + else: + # add description to target text tokens + input_text_tokens_.append(torch.cat([desc_tokens_ids, input_text_tokens[i]])) + target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) + # add description to source text tokens + source_tokens_.append(torch.cat([desc_tokens_ids, source_tokens[i]])) + source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) + # add silence in the source audio while the prompt is being processed + pad_size = len(desc_tokens_ids) * source_samples_per_frame + pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) + source_audio_.append(torch.cat([pad_audio, source_audio[i]])) + source_audio_lens[i] = source_audio_lens[i] + pad_size + # add silence in the target audio while the prompt is being processed + pad_size = len(desc_tokens_ids) * target_samples_per_frame + pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) + target_audio_.append(torch.cat([pad_audio, target_audio[i]])) + target_audio_lens[i] = target_audio_lens[i] + pad_size + + # des duration + desc_lens.append(len(desc_tokens_ids)) + desc_plus_audio_prompt_lens.append(len(desc_tokens_ids)) # collate tensors input_text_tokens = collate_vectors(input_text_tokens_, padding_value=text_pad_id) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index dde78cb3db06..bf9bf4587215 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -252,7 +252,6 @@ def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int Note BOS is considered as speaking (1) and EOS as non speaking 0 """ - B, T = input_ids.shape device = input_ids.device bos_mask = (input_ids == bos_token_id).to(torch.int32).to(device) eos_mask = (input_ids == eos_token_id).to(torch.int32).to(device) @@ -753,8 +752,6 @@ def pad_or_truncate(x, pad_value=0): # set the extra self.speech_pad_id at first 1 position in non_prompt_mask target_codes_aligned[row_idx, pos] = self.speech_pad_id - B, T = input_text_tokens.shape - # shift text tokens subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1]) # note that we are using a text mask where we are ignoring the desc + audio prompt but we are keeping 1 until the audio ends to support duplex @@ -772,7 +769,6 @@ def pad_or_truncate(x, pad_value=0): input_text_tokens = input_text_tokens[:, :-remainder] target_codes_aligned = target_codes_aligned[:, :-remainder] target_codes_aligned = target_codes_aligned[:, :-remainder] - desc_mask = desc_mask[:, :-remainder] subword_ids = subword_ids[:, :-remainder] subword_mask = subword_mask[:, :-remainder] From 2d10b4e0cb1d14e0a82335d6bca99d5d18dbe0fd Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 06:33:04 -0800 Subject: [PATCH 012/102] Add missing copyright Signed-off-by: Edresson Casanova --- .../speechlm2/modules/ear_tts_commons.py | 15 ++++++++++++++- .../speechlm2/modules/rvq_ear_tts_model.py | 15 ++++++++++++++- .../speechlm2/modules/rvq_ear_tts_vae.py | 15 ++++++++++++++- 3 files changed, 42 insertions(+), 3 deletions(-) 
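Note on the frame-alignment arithmetic in the duplex_ear_tts_dataset.py changes above: the audio prompt is first snapped to a whole number of codec frames, and the text channel is then padded with exactly one token per prompt frame so both channels stay aligned. The short sketch below is illustrative only and is not part of the patch; the sample rate, frame length, and prompt duration are example values taken from the test config later in this series.

    # Illustrative sketch, not part of the patch.
    target_sample_rate = 22050       # assumed: cfg.data.target_sample_rate
    frame_length = 0.08              # assumed: cfg.data.frame_length, seconds per codec frame
    audio_prompt_duration = 3.0      # assumed: cfg.data.audio_prompt_duration, seconds

    target_samples_per_frame = int(target_sample_rate * frame_length)           # 1764 samples

    # Snap the prompt to an integer number of codec frames, as the dataset does.
    prompt_audio_size = int(
        ((audio_prompt_duration * target_sample_rate) // target_samples_per_frame)
        * target_samples_per_frame
    )                                                                            # 37 * 1764 = 65268 samples

    # One text pad token per prompt frame keeps the text and audio channels aligned.
    prompt_audio_text_pad_size = prompt_audio_size // target_samples_per_frame   # 37 tokens

    print(prompt_audio_size, prompt_audio_text_pad_size)                         # 65268 37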
diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index 95883159edda..6f878ff1360d 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -1,4 +1,17 @@ -# Standard library +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import glob import importlib.machinery diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index 35e160a2cf57..734a67b95683 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -1,4 +1,17 @@ -# Standard library +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import math import os diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py index b3b2fe0cced7..78906cdc27f2 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -1,4 +1,17 @@ -# Standard library +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import functools import math from collections.abc import Callable From 8dc232bb0488355af301da07df1a1d2578971204 Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 13 Nov 2025 16:55:53 +0000 Subject: [PATCH 013/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 9 ++------- nemo/collections/speechlm2/models/duplex_ear_tts.py | 11 ++++++----- .../speechlm2/modules/rvq_ear_tts_model.py | 2 ++ 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 3f7a4469235f..a09c4a61bf6a 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -288,7 +288,6 @@ def __getitem__(self, cuts: CutSet) -> dict: source_audio = F.pad(source_audio, (0, extra_frames)) source_audio_lens = source_audio_lens + extra_frames - text_pad_id = get_pad_id(self.tokenizer) input_text_tokens_ = [] source_tokens_ = [] @@ -325,9 +324,7 @@ def __getitem__(self, cuts: CutSet) -> dict: desc_tokens_ids = torch.cat( [ desc_tokens_ids, - torch.tensor( - [self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device - ), + torch.tensor([self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device), ] ) # Add padding equivalent to the audio prompt size in number of tokens @@ -412,9 +409,7 @@ def __getitem__(self, cuts: CutSet) -> dict: # Segment IDs per sequence (padded) aligned_segment_ids = torch.stack( [ - torch.nn.functional.pad( - torch.full((seq_len,), i), (0, max_len - seq_len), value=-1 - ) # -1 for padding + torch.nn.functional.pad(torch.full((seq_len,), i), (0, max_len - seq_len), value=-1) # -1 for padding for i, seq_len in enumerate(target_token_lens) ], dim=0, diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index bf9bf4587215..844e79b6f303 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -16,6 +16,8 @@ import os import tempfile import time +from collections import Counter +from contextlib import contextmanager import torch import torch.nn as nn @@ -53,9 +55,6 @@ from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging -from collections import Counter -from contextlib import contextmanager - def maybe_to(x, dtype): if x is None: @@ -870,7 +869,6 @@ def log_model_stats(self): self.log("weights/max_abs", max_abs_w, on_epoch=True, sync_dist=True) self.log("weights/mean", weight_mean, on_epoch=True, sync_dist=True) - def on_validation_epoch_start(self) -> None: setup_audio_codec(self) self.results_logger = ResultsLogger(self.validation_save_path).reset() @@ -1019,7 +1017,10 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals for key in init_inputs: if init_inputs[key] is not None: init_inputs[key] = torch.stack( - [init_inputs[key][i, :plen] for i, plen in enumerate(dataset_batch["desc_plus_audio_prompt_lens"])] + [ + init_inputs[key][i, :plen] + for i, plen in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + ] ) else: # set init inputs and get it diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index 734a67b95683..606ce7a86856 100644 --- 
a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -29,6 +29,7 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.utils import logging + # ============================================================================== # MLP module and Norm # ============================================================================== @@ -498,6 +499,7 @@ def _build_char_vocab() -> dict[str, int]: # 1. Load or build the character vocabulary if vocab_dir: from filelock import FileLock + char_vocab_file = os.path.join(vocab_dir, "char_vocab.json") os.makedirs(vocab_dir, exist_ok=True) with FileLock(char_vocab_file + ".lock", timeout=60): From 626b3283dbded62f88866a5bf6c010c3924c32cf Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 09:28:24 -0800 Subject: [PATCH 014/102] Fix code scanning issues Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 47 +++++++++---------- .../speechlm2/models/duplex_ear_tts.py | 2 - .../speechlm2/modules/ear_tts_commons.py | 7 ++- .../speechlm2/modules/rvq_ear_tts_model.py | 5 +- .../speechlm2/modules/rvq_ear_tts_vae.py | 6 --- 5 files changed, 31 insertions(+), 36 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index a09c4a61bf6a..8a2a92360cb7 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -288,6 +288,10 @@ def __getitem__(self, cuts: CutSet) -> dict: source_audio = F.pad(source_audio, (0, extra_frames)) source_audio_lens = source_audio_lens + extra_frames +<<<<<<< HEAD +======= + # Add audio and text prompts +>>>>>>> e5cc7fb09 (Fix code scanning issues) text_pad_id = get_pad_id(self.tokenizer) input_text_tokens_ = [] source_tokens_ = [] @@ -379,30 +383,25 @@ def __getitem__(self, cuts: CutSet) -> dict: desc_lens.append(len(desc_tokens_ids)) desc_plus_audio_prompt_lens.append(len(desc_tokens_ids)) - # collate tensors - input_text_tokens = collate_vectors(input_text_tokens_, padding_value=text_pad_id) - source_tokens = collate_vectors(source_tokens_, padding_value=text_pad_id) - source_audio = collate_vectors(source_audio_, padding_value=0) - target_audio = collate_vectors(target_audio_, padding_value=0) - - # recreate audio mask - non_desc_mask = get_mask_from_lengths(target_token_lens) - # ignore desc len in audio mask - for i, frame in enumerate(desc_lens): - non_desc_mask[i, :frame] = 0.0 - - # desc mask is totally the oposite of audio mask - desc_mask = ~non_desc_mask - - # create non_prompt_mask that should mask desc plus audio prompt if used - non_prompt_mask = get_mask_from_lengths(target_token_lens) - for i, frame in enumerate(desc_plus_audio_prompt_lens): - non_prompt_mask[i, : frame - 1] = 0.0 - else: - # create a mask for audio using target tokens that suppose to have the same size of the tokenized audio - non_prompt_mask = get_mask_from_lengths(target_token_lens) - # create a full zero desc mask - desc_mask = torch.zeros_like(non_prompt_mask) + # collate tensors + input_text_tokens = collate_vectors(input_text_tokens_, padding_value=text_pad_id) + source_tokens = collate_vectors(source_tokens_, padding_value=text_pad_id) + source_audio = collate_vectors(source_audio_, padding_value=0) + target_audio = collate_vectors(target_audio_, 
padding_value=0) + + # recreate audio mask + non_desc_mask = get_mask_from_lengths(target_token_lens) + # ignore desc len in audio mask + for i, frame in enumerate(desc_lens): + non_desc_mask[i, :frame] = 0.0 + + # desc mask is totally the oposite of audio mask + desc_mask = ~non_desc_mask + + # create non_prompt_mask that should mask desc plus audio prompt if used + non_prompt_mask = get_mask_from_lengths(target_token_lens) + for i, frame in enumerate(desc_plus_audio_prompt_lens): + non_prompt_mask[i, : frame - 1] = 0.0 max_len = max(target_token_lens) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 844e79b6f303..b9287e63f39f 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -1063,8 +1063,6 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals ) * self.target_samples_per_frame ) - # for i, l in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]): - # results["audio_tf"][i, :l*self.target_samples_per_frame] = 0.0 with fp32_precision(): # resample is fragile to bfloat16 default dtype metric_audio_pred = results["audio"] diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index 6f878ff1360d..b1720782c5c9 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -43,7 +43,6 @@ # Configuration Class and Utilities # ============================================================================== - class Config(MutableMapping): """ A dictionary-like configuration class that uses attributes for storage @@ -148,6 +147,12 @@ def __hash__(self): """Makes the object hashable if its contents are hashable.""" return hash(tuple(sorted(self.items()))) + def __eq__(self, other): + """Compares two Config objects for equality based on their contents.""" + if not isinstance(other, Config): + return NotImplemented + return dict(self.items()) == dict(other.items()) + def get_config_from_file(config_path: str) -> Config: """ diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index 606ce7a86856..cd8a4a40d072 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -376,8 +376,8 @@ class RVQEARTTSConfig(Config): codebook_size: int = 1024 num_quantizers: int = 72 context_hidden_size: int = 4096 - cas_config: CASConfig | None = field(default_factory=lambda: CASConfig()) - mog_head_config: MoGHeadConfig = field(default_factory=lambda: MoGHeadConfig()) + cas_config: CASConfig | None = field(default_factory=CASConfig) + mog_head_config: MoGHeadConfig = field(default_factory=MoGHeadConfig) # extra parameters used for compatibility with S2S disable_eos_prediction: bool = False @@ -714,7 +714,6 @@ def dist(self, mus: Tensor, mu: Tensor) -> Tensor: else: low_mat_sq = self.low_mat.transpose(-1, -2) @ self.low_mat x, y = mus, mu - b, t, n, d_l = x.size() wx_sq = ( x * torch.einsum( diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py index 78906cdc27f2..09a2ca95f947 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -677,12 +677,6 @@ def __init__( def log_std(self) -> Tensor: return 
torch.log(self._variance_list[-1]()) * 0.5 - @overload - def encode(self, z: Tensor, return_z_q: Literal[False]) -> list[Tensor]: ... - - @overload - def encode(self, z: Tensor, return_z_q: Literal[True]) -> tuple[list[Tensor], Tensor]: ... - def encode(self, z: Tensor, return_z_q: bool = False) -> list[Tensor] | tuple[list[Tensor], Tensor]: r = z ids_sel = [] From 89297365775bd4b7a7ec4bd8f2ac47c11175e16d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 09:33:18 -0800 Subject: [PATCH 015/102] Fix merge issues Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py | 4 ---- nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 8a2a92360cb7..f5158e3333fc 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -288,10 +288,6 @@ def __getitem__(self, cuts: CutSet) -> dict: source_audio = F.pad(source_audio, (0, extra_frames)) source_audio_lens = source_audio_lens + extra_frames -<<<<<<< HEAD -======= - # Add audio and text prompts ->>>>>>> e5cc7fb09 (Fix code scanning issues) text_pad_id = get_pad_id(self.tokenizer) input_text_tokens_ = [] source_tokens_ = [] diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py index 09a2ca95f947..ec6d1b7ae63e 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -17,7 +17,7 @@ from collections.abc import Callable from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any, Concatenate, Literal, overload +from typing import Any, Concatenate # Third-party import torch From 2f6ae33fe258a8364014599a35fd6fd9546c482c Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 13 Nov 2025 17:35:26 +0000 Subject: [PATCH 016/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/modules/ear_tts_commons.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index b1720782c5c9..b2e996ff5eb0 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -43,6 +43,7 @@ # Configuration Class and Utilities # ============================================================================== + class Config(MutableMapping): """ A dictionary-like configuration class that uses attributes for storage From 3297d0c43b8915978b4b4ce6b0a2a66f98a73d37 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 10:53:19 -0800 Subject: [PATCH 017/102] Remove EARTTS configs and use directly DictConfig Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 12 +- .../speechlm2/modules/ear_tts_commons.py | 139 ++---------------- .../speechlm2/modules/rvq_ear_tts_model.py | 88 ++--------- .../speechlm2/modules/rvq_ear_tts_vae.py | 38 +---- 4 files changed, 38 insertions(+), 239 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index b9287e63f39f..eebd1e37a9bb 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ 
b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -42,8 +42,8 @@ from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.speechlm2.data.utils import get_pad_id from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER -from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSConfig, RVQEARTTSModel -from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEConfig, RVQVAEModel +from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSModel +from nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEModel from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU from nemo.collections.speechlm2.parts.metrics.intelligibility import Intelligibility @@ -321,11 +321,11 @@ def setup_rvq_audio_codec(model): with fp32_precision(): if model.cfg.get("pretrained_ae_dir", None): model.audio_codec = ( - RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, strict=False).eval().to(model.device) + RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, cfg=DictConfig(model.cfg.codec_config) if model.cfg.get("codec_config", None) else None, strict=False).eval().to(model.device) ) else: # init codec from config - model.audio_codec = RVQVAEModel(RVQVAEConfig(**model.cfg.codec_config)) + model.audio_codec = RVQVAEModel(DictConfig(model.cfg.codec_config)) for p in model.audio_codec.parameters(): p.requires_grad = False @@ -537,11 +537,11 @@ def _load_tts_model(self, cfg) -> nn.Module: """Load TTS model for RVQ-EAR-TTS.""" if self.cfg.get("pretrained_tts_model", None): self.tts_model = RVQEARTTSModel.from_pretrained( - cfg.pretrained_tts_model, RVQEARTTSConfig(**cfg.tts_config), strict=False + cfg.pretrained_tts_model, DictConfig(cfg.tts_config), strict=False ) else: # start the model from scratch - self.tts_model = RVQEARTTSModel(RVQEARTTSConfig(**cfg.tts_config)) + self.tts_model = RVQEARTTSModel(DictConfig(cfg.tts_config)) setup_audio_codec(self) diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index b2e996ff5eb0..f0a53ee92bf4 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -21,13 +21,14 @@ import shutil import subprocess import sys -from collections.abc import Mapping, MutableMapping +from collections.abc import Mapping from typing import Any from safetensors import safe_open from torch import nn from nemo.utils import logging +from omegaconf import DictConfig # ============================================================================== # Contants @@ -42,120 +43,7 @@ # ============================================================================== # Configuration Class and Utilities # ============================================================================== - - -class Config(MutableMapping): - """ - A dictionary-like configuration class that uses attributes for storage - and supports both attribute and item-style access. - - This class inherits from `collections.abc.MutableMapping` and stores all - key-value pairs as instance attributes in its internal `__dict__`. - - Nested dictionaries are recursively converted into Config objects upon being set. - """ - - def __init__(self, **kwargs): - """ - Initializes the Config object from keyword arguments. 
- """ - # __setattr__ will handle the recursive conversion for each item - for key, value in kwargs.items(): - setattr(self, key, value) - - def to_dict(self): - """ - Recursively converts the Config object back into a standard dictionary. - - Returns: - dict: A standard dictionary representation of the configuration. - """ - result = {} - for key, value in self.items(): - if isinstance(value, Config): - # If the value is a Config object, recursively call to_dict() - result[key] = value.to_dict() - else: - result[key] = value - return result - - def to_json(self, indent=2): - """ - Serializes the configuration object to a formatted JSON string. - - Args: - indent (int, optional): The indentation level for the JSON output. - Defaults to 2. - - Returns: - str: The configuration as a JSON-formatted string. - """ - # Leverage the to_dict() method for clean serialization - return json.dumps(self.to_dict(), indent=indent) - - # --- Core MutableMapping Methods --- - - def __setattr__(self, key, value): - """ - Sets an attribute. Recursively converts dicts to Config objects. - This is the primary method for adding/modifying data. - """ - if isinstance(value, Mapping): - value = Config(**value) - # Use object's __setattr__ to avoid infinite recursion - object.__setattr__(self, key, value) - - def __setitem__(self, key, value): - """Allows setting items using dictionary syntax (e.g., `config['key'] = value`).""" - setattr(self, key, value) - - def __getattr__(self, key): - """Allows accessing items as attributes (e.g., `config.key`).""" - # This method is only called for attributes that don't already exist. - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") - - def __getitem__(self, key): - """Allows accessing items using dictionary syntax (e.g., `config['key']`).""" - try: - return getattr(self, key) - except AttributeError as e: - # Convert AttributeError to KeyError for dict-like behavior - raise KeyError(key) from e - - def __delitem__(self, key): - """Allows deleting items using dictionary syntax (e.g., `del config['key']`).""" - try: - delattr(self, key) - except AttributeError as e: - # Convert AttributeError to KeyError for dict-like behavior - raise KeyError(key) from e - - def __iter__(self): - """Returns an iterator over the keys (attributes) of the object.""" - return iter(self.__dict__) - - def __len__(self): - """Returns the number of items (attributes) in the object.""" - return len(self.__dict__) - - # --- Utility Methods --- - - def __repr__(self): - """Returns an informative string representation of the Config object.""" - return f"{self.__class__.__name__}({self.__dict__!r})" - - def __hash__(self): - """Makes the object hashable if its contents are hashable.""" - return hash(tuple(sorted(self.items()))) - - def __eq__(self, other): - """Compares two Config objects for equality based on their contents.""" - if not isinstance(other, Config): - return NotImplemented - return dict(self.items()) == dict(other.items()) - - -def get_config_from_file(config_path: str) -> Config: +def get_config_from_file(config_path: str) -> DictConfig: """ Loads a configuration from a JSON or Python file. @@ -197,11 +85,11 @@ def get_config_from_file(config_path: str) -> Config: ), f"Python config file must define a `{PYTHON_CONFIG_GETTER_NAME}` function." config = getattr(config_module, PYTHON_CONFIG_GETTER_NAME)(py_config_name) assert isinstance(config, Mapping), f"`{PYTHON_CONFIG_GETTER_NAME}` must return a dictionary-like object." 
- cfg = Config(**config) + cfg = DictConfig(config) return cfg -def get_config() -> Config: +def get_config() -> DictConfig: """ Parses command-line arguments to load the main configuration for a training run. @@ -244,7 +132,7 @@ def get_config() -> Config: return cfg -def get_config_from_dir(workdir_path: str) -> Config: +def get_config_from_dir(workdir_path: str) -> DictConfig: """ A simple utility to load the configuration directly from a work directory. @@ -252,7 +140,7 @@ def get_config_from_dir(workdir_path: str) -> Config: workdir_path (str): The path to the work directory containing a `config.json`. Returns: - Config: The loaded configuration object. + DictConfig: The loaded configuration object. """ config_save_path = os.path.join(workdir_path, CONFIG_NAME) cfg = get_config_from_file(config_save_path) @@ -264,9 +152,8 @@ def get_config_from_dir(workdir_path: str) -> Config: # Base Model Classes # ============================================================================== - class PreTrainedModel(nn.Module): - config_class = Config + config_class = DictConfig """ A base class for models to handle loading from pretrained checkpoints. @@ -276,18 +163,18 @@ class PreTrainedModel(nn.Module): like Hugging Face's Transformers. Args: - config (Config | dict[str, Any]): A configuration object containing model hyperparameters. + config (DictConfig | dict[str, Any]): A configuration object containing model hyperparameters. """ - def __init__(self, config: Config | dict[str, Any], *args, **kwargs): + def __init__(self, config: DictConfig | dict[str, Any], *args, **kwargs): super().__init__() - self.config = config if isinstance(config, self.config_class) else self.config_class(**config) + self.config = config if isinstance(config, self.config_class) else self.config_class(config) @classmethod def from_pretrained( cls, pretrained_dir: str, - cfg: Config | dict[str, Any] | None = None, + cfg: DictConfig | dict[str, Any] | None = None, checkpoint_regex: str = "checkpoint_*/ema.safetensors", strict: bool = False, **model_kwargs, @@ -303,7 +190,7 @@ def from_pretrained( cls (type): The model class to instantiate. pretrained_dir (str): The directory containing the pretrained model config and checkpoint files. - cfg (Config | dict[str, Any] | None, optional): An optional config object to override + cfg (DictConfig | dict[str, Any] | None, optional): An optional config object to override the loaded config. Defaults to None. checkpoint_regex (str, optional): A regex pattern to find the checkpoint file. Defaults to "checkpoint_*/ema.safetensors". 
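The change above replaces the hand-rolled Config dataclasses with omegaconf's DictConfig throughout ear_tts_commons.py; the sketch below shows the access patterns the refactored modules rely on (attribute access, .get() with a default, and OmegaConf.to_container for plain dicts). It is illustrative only, and the keys and values are a small assumed subset of the real tts_config.

    # Illustrative sketch, not part of the patch.
    from omegaconf import DictConfig, OmegaConf

    tts_config = DictConfig(
        {
            "backbone_type": "gemma3_text",   # assumed example values, not the full config
            "latent_size": 512,
            "codebook_size": 1024,
            "num_quantizers": 31,
        }
    )

    # Attribute access and .get() with a default work as the modules expect.
    assert tts_config.latent_size == 512
    assert tts_config.get("backbone_model_class", None) is None

    # Hugging Face configs are built from a plain dict, hence to_container().
    backbone_kwargs = OmegaConf.to_container(tts_config, resolve=True)
    assert isinstance(backbone_kwargs, dict)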
diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index cd8a4a40d072..561c24f4a93b 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -15,7 +15,7 @@ import json import math import os -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, fields from typing import Any import torch @@ -25,10 +25,11 @@ from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, AutoTokenizer, Cache from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper -from nemo.collections.speechlm2.modules.ear_tts_commons import Config, PreTrainedModel +from nemo.collections.speechlm2.modules.ear_tts_commons import PreTrainedModel from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.utils import logging +from omegaconf import DictConfig, OmegaConf # ============================================================================== # MLP module and Norm @@ -339,66 +340,6 @@ def get_mask( return sequence_mask(num_to_keep.view(-1), depth).view_as(code_mask) -@dataclass -class CASConfig(Config): - pretrained_tokenizer_name: str = "meta-llama/Llama-3.1-8B-Instruct" - vocab_dir: str | None = None - - # transformer backbone - backbone_type: str | None = "t5gemma" - backbone_model_class: str | None = None - backbone_config_class: str | None = None - backbone_config: Config | None = None - - -@dataclass -class MoGHeadConfig(Config): - intermediate_size: int = 4608 - num_layers: int = 3 - low_rank: int | None = 64 - num_predictions: int = 1024 - min_log_std: float = -4.0 - eps: float = 1e-6 - - -@dataclass -class RVQEARTTSConfig(Config): - model_type = "rvq_ear_tts" - - # transformer backbone - backbone_type: str | None = "gemma3_text" - backbone_model_class: str | None = None - backbone_config_class: str | None = None - backbone_config: Config | None = None - - # model specific configs - latent_size: int = 512 - codebook_size: int = 1024 - num_quantizers: int = 72 - context_hidden_size: int = 4096 - cas_config: CASConfig | None = field(default_factory=CASConfig) - mog_head_config: MoGHeadConfig = field(default_factory=MoGHeadConfig) - - # extra parameters used for compatibility with S2S - disable_eos_prediction: bool = False - use_subword_flag_emb: bool = True - use_bos_eos_emb: bool = True - pretrained_text_name: str | None = None - use_gated_fusion_for_text_audio: bool = True - - p_uncond: float = 0.1 - label_smoothing: float = 0.01 - max_training_rate: float = 0.8 - quantizer_dropout: float = 0.5 - random_target_masking: bool = False - exponent: float = 3.0 - - def __post_init__(self): - if self.cas_config is not None: - self.cas_config = CASConfig(**self.cas_config) - self.mog_head_config = MoGHeadConfig(**self.mog_head_config) - - # ============================================================================== # Model and Vocabulary Utilities # ============================================================================== @@ -923,7 +864,7 @@ class CharAwareSubwordEncoder(nn.Module): backbone_type (str | None): The type of backbone model from Hugging Face (e.g., "t5gemma"). backbone_model_class (str | None): The class name of the backbone model if not using AutoModel. backbone_config_class (str | None): The class name of the backbone config. 
- backbone_config (Config | None): A configuration for the backbone model. + backbone_config (DictConfig | None): A configuration for the backbone model. """ def __init__( @@ -934,7 +875,7 @@ def __init__( backbone_type: str | None = "t5gemma", backbone_model_class: str | None = None, backbone_config_class: str | None = None, - backbone_config: Config | None = None, + backbone_config: DictConfig | None = None, use_subword_flag_emb: bool = True, use_bos_eos_emb: bool = True, use_cumulative_word_emb: bool = False, @@ -953,13 +894,13 @@ def __init__( # 2. Initialize the backbone model if backbone_type: - config = AutoConfig.for_model(backbone_type, **(backbone_config.to_dict() if backbone_config else {})) + config = AutoConfig.for_model(backbone_type, **(OmegaConf.to_container(backbone_config, resolve=True) if backbone_config else {})) self.backbone = AutoModelForTextEncoding.from_config(config) else: assert backbone_model_class and backbone_config_class config_class = getattr(transformers, backbone_config_class) model_class = getattr(transformers, backbone_model_class) - config = config_class(**(backbone_config.to_dict() if backbone_config else {})) + config = config_class(**(OmegaConf.to_container(backbone_config, resolve=True) if backbone_config else {})) self.backbone = model_class(config) self.hidden_size = self.backbone.get_input_embeddings().weight.size(-1) @@ -1101,15 +1042,14 @@ class RVQEARTTSModel(PreTrainedModel): autoregressive inference. Args: - config (RVQEARTTSConfig | dict[str, Any]): The configuration object for the model. + config (DictConfig | dict[str, Any]): The configuration object for the model. """ - - config_class: type[Config] = RVQEARTTSConfig rvq_embs: Tensor - def __init__(self, config: RVQEARTTSConfig | dict[str, Any]): + def __init__(self, config: DictConfig | dict[str, Any]): super().__init__(config) + # Backbone module if self.config.get("pretrained_text_name", None): # Load pretrained backbone from huggingface @@ -1118,16 +1058,16 @@ def __init__(self, config: RVQEARTTSConfig | dict[str, Any]): llm = load_pretrained_hf(self.config.pretrained_text_name, pretrained_weights=True).train() self.backbone = llm.model # fetch PretrainedBaseModel from model "ForCausalLM" else: - if self.config.backbone_type is None: - assert self.config.backbone_model_class is not None and self.config.backbone_config_class is not None + if self.config.get("backbone_type", None) is None: + assert self.config.get("backbone_model_class", None) is not None and self.config.get("backbone_config_class", None) is not None backbone_config = getattr(transformers, self.config.backbone_config_class)( - **(self.config.backbone_config.to_dict() if self.config.backbone_config else {}), + **(OmegaConf.to_container(self.config.backbone_config, resolve=True) if self.config.backbone_config else {}), ) self.backbone = getattr(transformers, self.config.backbone_model_class)(backbone_config) else: backbone_config = AutoConfig.for_model( self.config.backbone_type, - **(self.config.backbone_config.to_dict() if self.config.backbone_config else {}), + **(OmegaConf.to_container(self.config.backbone_config, resolve=True) if self.config.backbone_config else {}), ) self.backbone = AutoModel.from_config(backbone_config) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py index ec6d1b7ae63e..522420462955 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -16,7 +16,6 
@@ import math from collections.abc import Callable from contextlib import contextmanager -from dataclasses import dataclass, field from typing import Any, Concatenate # Third-party @@ -26,8 +25,8 @@ from torchaudio import functional as ta_F # Project -from nemo.collections.speechlm2.modules.ear_tts_commons import Config, PreTrainedModel - +from nemo.collections.speechlm2.modules.ear_tts_commons import PreTrainedModel +from omegaconf import DictConfig @contextmanager def disable_tf32(): @@ -38,33 +37,6 @@ def disable_tf32(): finally: torch.backends.cudnn.allow_tf32 = prev - -@dataclass -class RVQVAEConfig(Config): - model_type: str = "rvqvae" - - # model specific configs - latent_size: int = 512 - wav_to_token_ratio: int = field(init=False) - - n_fft: int = 32 - hop_length: int = 8 - base_hidden_size: int = 384 - channel_mult: tuple[int, ...] = (1, 2, 4) - rates: tuple[int, ...] = (8, 8, 8) - num_blocks: int = 3 - kernel_size: int = 7 - groups: int = 1 - - # quantization - codebook_size: int = 1024 - num_quantizers: int = 72 - quantizer_dropout: float = 0.5 - - def __post_init__(self): - self.wav_to_token_ratio = self.hop_length * math.prod(self.rates) - - # ============================================================================== # Utility Functions # ============================================================================== @@ -908,12 +880,12 @@ class RVQVAEModel(PreTrainedModel): It consists of an encoder, a quantizer, and a decoder. Args: - config (RVQVAEConfig | dict[str, Any]): A configuration object with model hyperparameters. + config (DictConfig | dict[str, Any]): A configuration object with model hyperparameters. """ - config_class: type[Config] = RVQVAEConfig + config_class: type[DictConfig] = DictConfig - def __init__(self, config: RVQVAEConfig | dict[str, Any]): + def __init__(self, config: DictConfig | dict[str, Any]): super().__init__(config) self.encoder = Wav2Latent( From c58ebd0c64f088159972a20674312e8ea8602c7e Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 13 Nov 2025 18:54:56 +0000 Subject: [PATCH 018/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 8 +++++- .../speechlm2/modules/ear_tts_commons.py | 3 ++- .../speechlm2/modules/rvq_ear_tts_model.py | 25 ++++++++++++++----- .../speechlm2/modules/rvq_ear_tts_vae.py | 4 ++- 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index eebd1e37a9bb..64d9f17ba1e1 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -321,7 +321,13 @@ def setup_rvq_audio_codec(model): with fp32_precision(): if model.cfg.get("pretrained_ae_dir", None): model.audio_codec = ( - RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, cfg=DictConfig(model.cfg.codec_config) if model.cfg.get("codec_config", None) else None, strict=False).eval().to(model.device) + RVQVAEModel.from_pretrained( + model.cfg.pretrained_ae_dir, + cfg=DictConfig(model.cfg.codec_config) if model.cfg.get("codec_config", None) else None, + strict=False, + ) + .eval() + .to(model.device) ) else: # init codec from config diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index f0a53ee92bf4..25331d583481 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ 
b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -24,11 +24,11 @@ from collections.abc import Mapping from typing import Any +from omegaconf import DictConfig from safetensors import safe_open from torch import nn from nemo.utils import logging -from omegaconf import DictConfig # ============================================================================== # Contants @@ -152,6 +152,7 @@ def get_config_from_dir(workdir_path: str) -> DictConfig: # Base Model Classes # ============================================================================== + class PreTrainedModel(nn.Module): config_class = DictConfig diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index 561c24f4a93b..a94f580fc2ac 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -20,6 +20,7 @@ import torch import transformers +from omegaconf import DictConfig, OmegaConf from torch import Tensor, nn from torch.nn import functional as F from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, AutoTokenizer, Cache @@ -29,7 +30,6 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.utils import logging -from omegaconf import DictConfig, OmegaConf # ============================================================================== # MLP module and Norm @@ -894,7 +894,9 @@ def __init__( # 2. Initialize the backbone model if backbone_type: - config = AutoConfig.for_model(backbone_type, **(OmegaConf.to_container(backbone_config, resolve=True) if backbone_config else {})) + config = AutoConfig.for_model( + backbone_type, **(OmegaConf.to_container(backbone_config, resolve=True) if backbone_config else {}) + ) self.backbone = AutoModelForTextEncoding.from_config(config) else: assert backbone_model_class and backbone_config_class @@ -1044,12 +1046,12 @@ class RVQEARTTSModel(PreTrainedModel): Args: config (DictConfig | dict[str, Any]): The configuration object for the model. 
""" + rvq_embs: Tensor def __init__(self, config: DictConfig | dict[str, Any]): super().__init__(config) - # Backbone module if self.config.get("pretrained_text_name", None): # Load pretrained backbone from huggingface @@ -1059,15 +1061,26 @@ def __init__(self, config: DictConfig | dict[str, Any]): self.backbone = llm.model # fetch PretrainedBaseModel from model "ForCausalLM" else: if self.config.get("backbone_type", None) is None: - assert self.config.get("backbone_model_class", None) is not None and self.config.get("backbone_config_class", None) is not None + assert ( + self.config.get("backbone_model_class", None) is not None + and self.config.get("backbone_config_class", None) is not None + ) backbone_config = getattr(transformers, self.config.backbone_config_class)( - **(OmegaConf.to_container(self.config.backbone_config, resolve=True) if self.config.backbone_config else {}), + **( + OmegaConf.to_container(self.config.backbone_config, resolve=True) + if self.config.backbone_config + else {} + ), ) self.backbone = getattr(transformers, self.config.backbone_model_class)(backbone_config) else: backbone_config = AutoConfig.for_model( self.config.backbone_type, - **(OmegaConf.to_container(self.config.backbone_config, resolve=True) if self.config.backbone_config else {}), + **( + OmegaConf.to_container(self.config.backbone_config, resolve=True) + if self.config.backbone_config + else {} + ), ) self.backbone = AutoModel.from_config(backbone_config) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py index 522420462955..e412b82691df 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -20,13 +20,14 @@ # Third-party import torch +from omegaconf import DictConfig from torch import Tensor, nn from torch.nn import functional as F from torchaudio import functional as ta_F # Project from nemo.collections.speechlm2.modules.ear_tts_commons import PreTrainedModel -from omegaconf import DictConfig + @contextmanager def disable_tf32(): @@ -37,6 +38,7 @@ def disable_tf32(): finally: torch.backends.cudnn.allow_tf32 = prev + # ============================================================================== # Utility Functions # ============================================================================== From 13a39989ce44ad8774e11d353e1eb6687bd2439d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 13 Nov 2025 11:53:31 -0800 Subject: [PATCH 019/102] Update Signed-off-by: Edresson Casanova --- .../{duplex_eartts_train_infer.py => duplex_eartts_infer.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/speechlm2/{duplex_eartts_train_infer.py => duplex_eartts_infer.py} (100%) diff --git a/examples/speechlm2/duplex_eartts_train_infer.py b/examples/speechlm2/duplex_eartts_infer.py similarity index 100% rename from examples/speechlm2/duplex_eartts_train_infer.py rename to examples/speechlm2/duplex_eartts_infer.py From 9e8a721bac173bf4b2f5ea9ba6401f4979d10370 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 14 Nov 2025 04:01:50 -0800 Subject: [PATCH 020/102] Implement EARTTS unit tests Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/__init__.py | 2 + .../speechlm2/models/duplex_ear_tts.py | 3 +- .../speechlm2/test_duplex_eartts.py | 381 ++++++++++++++++++ 3 files changed, 385 insertions(+), 1 deletion(-) create mode 100644 tests/collections/speechlm2/test_duplex_eartts.py diff --git 
a/nemo/collections/speechlm2/models/__init__.py b/nemo/collections/speechlm2/models/__init__.py index 144ea7774a6a..f45a484664db 100644 --- a/nemo/collections/speechlm2/models/__init__.py +++ b/nemo/collections/speechlm2/models/__init__.py @@ -13,10 +13,12 @@ # limitations under the License. from .duplex_s2s_model import DuplexS2SModel from .duplex_s2s_speech_decoder_model import DuplexS2SSpeechDecoderModel +from .duplex_ear_tts import DuplexEARTTS from .salm import SALM __all__ = [ 'DuplexS2SModel', 'DuplexS2SSpeechDecoderModel', + 'DuplexEARTTS', 'SALM', ] diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 64d9f17ba1e1..bf22aa54b761 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -428,7 +428,7 @@ def __init__(self, cfg: dict) -> None: self.save_hyperparameters() # convert dict to config cfg = DictConfig(cfg) - self.trainer_config = cfg.trainer + self.trainer_config = cfg.get("trainer", None) self.data_cfg = cfg.data self.cfg = cfg.model self.target_sample_rate = cfg.data.target_sample_rate @@ -995,6 +995,7 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals if ( not self.model_16_precision_safe and self.cfg.get("ensures_16_safe", False) + and self.trainer_config is not None and str(self.trainer_config.precision) != str(32) ): # ToDo: move it to a method diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py new file mode 100644 index 000000000000..4dcf49593a9c --- /dev/null +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -0,0 +1,381 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from lhotse import CutSet, SupervisionSegment +from lhotse.testing.dummies import dummy_cut, dummy_recording + +from nemo.collections.common.data.utils import move_data_to_device +from nemo.collections.speechlm2.data import DuplexEARTTSDataset +from nemo.collections.speechlm2.models import DuplexEARTTS + +if torch.cuda.is_available(): + torch.set_default_device('cuda') + + +test_eartts_config = { + "model": { + "pretrained_lm_name": "nvidia/NVIDIA-Nemotron-Nano-9B-v2", + "pretrained_ae_dir": None, + "pretrained_tts_model": None, + "scoring_asr": "stt_en_fastconformer_transducer_large", + + "freeze_params": [ + r"^audio_codec\..+$", # Keep audio codec frozen as it only provides supervision for training. 
+ r"^embed_tokens\..+$", # Keep embed_tokens frozen as done in eartts + ], + + "prevent_freeze_params": [], + + "audio_save_path": "", + + "inference_guidance_scale": 0.5, + "inference_noise_scale": 0.8, + "inference_top_p_or_k": 0.8, + "inference_guidance_enabled": False, + + "subword_mask_exactly_as_eartts": False, + "context_hidden_mask_exactly_as_eartts": False, + + "optimizer": { + "_target_": "torch.optim.AdamW", + "lr": 4e-5, + "betas": [0.9, 0.98], + "weight_decay": 0, + "foreach": True, + }, + + "lr_scheduler": { + "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", + "warmup_steps": 2500, + "min_lr": 1e-6, + "max_steps": 100_000_000, + }, + + "codec_config": { + "latent_size": 512, + "n_fft": 16, + "hop_length": 4, + "base_hidden_size": 384, + "channel_mult": [1, 2, 4], + "rates": [7, 7, 9], + "num_blocks": 3, + "kernel_size": 7, + "groups": 1, + "codebook_size": 1024, + "num_quantizers": 31, + "wav_to_token_ratio": 1764, + }, + + "tts_config": { + "use_gated_fusion_for_text_audio": True, + "disable_eos_prediction": True, + "use_bos_eos_emb": True, + "use_subword_flag_emb": True, + "num_delay_speech_tokens": 2, + + "backbone_type": "gemma3_text", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "attention_dropout": 0.1, + "use_cache": False, + }, + + "latent_size": 512, + "codebook_size": 1024, + "num_quantizers": 31, + "context_hidden_size": None, + + "cas_config": { + "pretrained_tokenizer_name": "nvidia/NVIDIA-Nemotron-Nano-9B-v2", + "backbone_type": "t5gemma", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "is_encoder_decoder": False, + "encoder": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "use_cache": False, + "attention_dropout": 0.1, + }, + }, + }, + + "mog_head_config": { + "intermediate_size": 4608, + "num_layers": 3, + "low_rank": 64, + "num_predictions": 1024, + "min_log_std": -4.0, + "eps": 1e-6, + }, + + "p_uncond": 0.1, + "label_smoothing": 0.01, + "max_training_rate": 0.8, + "quantizer_dropout": 0.5, + "random_target_masking": False, + "exponent": 3.0, + }, + }, + + "trainer": { + "devices": -1, + "accelerator": "gpu", + "num_nodes": 1, + "precision": 32, + "logger": False, + "enable_checkpointing": False, + "use_distributed_sampler": False, + "max_steps": 100_000_000, + "val_check_interval": 1000, + "limit_train_batches": "${trainer.val_check_interval}", + "limit_val_batches": 2, + "log_every_n_steps": 20, + "num_sanity_val_steps": 0, + "gradient_clip_val": 1.0, + "accumulate_grad_batches": 1, + "strategy": { + "_target_": "lightning.pytorch.strategies.DDPStrategy", + "gradient_as_bucket_view": True, + "find_unused_parameters": True, + }, + }, + + "data": { + "add_text_bos_and_eos_in_each_turn": True, + "add_audio_prompt_after_description": True, + "audio_prompt_duration": 3.0, + "frame_length": 0.08, + "source_sample_rate": 22050, + "target_sample_rate": 22050, + "input_roles": ["user", "User"], + "output_roles": ["agent", "Assistant", "assistant", "Agent"], + }, + + "exp_manager": { + "exp_dir": None, + "explicit_log_dir": "", + "name": "eartts", + "create_tensorboard_logger": False, + "create_checkpoint_callback": True, + "use_datetime_version": True, + "max_time_per_run": "00:03:50:00", + 
"resume_from_checkpoint": None, + "resume_if_exists": True, + "resume_ignore_no_checkpoint": True, + "create_wandb_logger": True, + "wandb_logger_kwargs": { + "name": "duplex_eartts_test", + "project": "duplex_eartts", + "resume": True, + }, + }, +} + + +@pytest.fixture(scope="session") +def model(): + model = DuplexEARTTS(test_eartts_config) + if torch.cuda.is_available(): + model.to("cuda") + return model + + +@pytest.fixture(scope="session") +def dataset(model): + return DuplexEARTTSDataset( + model.tokenizer, + add_text_bos_and_eos_in_each_turn=True, + add_audio_prompt_after_description=True, + audio_prompt_duration=3.0, + frame_length=0.08, + source_sample_rate=22050, + target_sample_rate=22050, + input_roles=["user", "User"], + output_roles=["agent", "Assistant", "assistant", "Agent"], + ) + + +@pytest.fixture(scope="session") +def training_cutset_batch(): + cut = dummy_cut(0, recording=dummy_recording(0, with_data=True)) + cut.target_audio = dummy_recording(1, with_data=True) + cut.supervisions = [ + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=0.1, + text='hi', + speaker="user", + ), + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0.3, + duration=0.1, + text='hello', + speaker="assistant", + ), + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0.5, + duration=0.1, + text='ok', + speaker="user", + ), + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0.6, + duration=0.4, + text='okay', + speaker="assistant", + ), + ] + return CutSet([cut]) + + +def test_eartts_dataset(dataset, training_cutset_batch): + print(training_cutset_batch) + batch = dataset[training_cutset_batch] + # Keys that must be present in batch + expected_keys = { + "sample_id", + "non_prompt_mask", + "desc_mask", + "desc_lens", + "desc_plus_audio_prompt_lens", + "aligned_attention_mask", + "aligned_position_ids", + "source_audio", + "source_audio_lens", + "target_audio", + "target_audio_lens", + "input_text_tokens", + "target_token_lens", + "source_tokens", + "source_token_lens", + "target_texts", + "speaker_reference_audio", + "speaker_reference_audio_lens", + "formatter", + } + + # --- Presence + tensor sanity checks --- + for key in expected_keys: + assert key in batch, f"Missing key: {key}" + + # Tensor-only keys + tensor_keys = [ + "non_prompt_mask", + "desc_mask", + "aligned_attention_mask", + "aligned_position_ids", + "source_audio", + "source_audio_lens", + "target_audio", + "target_audio_lens", + "input_text_tokens", + "target_token_lens", + "source_tokens", + "source_token_lens", + "speaker_reference_audio", + "speaker_reference_audio_lens", + ] + + for key in tensor_keys: + assert torch.is_tensor(batch[key]), f"{key} must be a tensor" + + # --- Shape/value checks similar to the original test --- + + # Audio shapes (you can adjust if needed) + assert batch["source_audio"].shape == (1, 89082) + assert batch["target_audio"].shape == (1, 89082) + + # Target text consistency + assert batch["target_texts"] == ["hello okay"] + + # Token checks (same content as your old test) + assert batch["source_tokens"].tolist() == [[ 2, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 2, 1, 2, 12, 12, 12, 12, 1, 1662, 2, 12, + 12, 12, 12 + ]] + + assert batch["input_text_tokens"].tolist() == [[ 2, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 
12, 12, 12, 12, 12, 12, 12, 12, + 12, 2, 12, 12, 12, 12, 1, 2, 12, 12, 1, 1662, + 1417, 12, 12]] + + # Check formatter + assert batch["formatter"] == ["s2s_duplex"] + + +def test_eartts_training_step(model, dataset, training_cutset_batch): + model.train() + model.on_train_epoch_start() + batch = dataset[training_cutset_batch] + batch = move_data_to_device(batch, device=model.device) + results = model.training_step(batch, batch_idx=0) + assert torch.is_tensor(results["loss"]) + assert not torch.isnan(results["loss"]) + assert results["loss"] > 0 + + +def test_eartts_validation_step(model, dataset, training_cutset_batch): + model.eval() + model.on_validation_epoch_start() + batch = dataset[training_cutset_batch] + batch = move_data_to_device(batch, device=model.device) + results = model.validation_step({"dummy_val_set": batch}, batch_idx=0) + assert results is None # no return value + + +def test_eartts_offline_generation(model): + # generate random subword_ids + subword_ids = torch.ones(2, 10).long() + + # set init inputs and get it + model.set_init_inputs( + speaker_audio=torch.randn(1, 22050), + speaker_audio_lens=torch.tensor([22050]), + ) + init_inputs = model.get_init_inputs(B=subword_ids.size(0)) + model.eval() + gen_audio, gen_audio_len = model.offline_inference( + next_subword_ids=subword_ids, + init_inputs=init_inputs, + ) + + assert gen_audio.shape == (2, 17640) + assert gen_audio_len[0] == gen_audio.size(-1) + assert gen_audio.dtype == torch.float32 From 1c6af8def3699eed4d28c5c84501e1ed61be19ae Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 14 Nov 2025 12:02:40 +0000 Subject: [PATCH 021/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/__init__.py | 2 +- .../speechlm2/test_duplex_eartts.py | 138 ++++++++++++++---- 2 files changed, 111 insertions(+), 29 deletions(-) diff --git a/nemo/collections/speechlm2/models/__init__.py b/nemo/collections/speechlm2/models/__init__.py index f45a484664db..6fc06b6527ac 100644 --- a/nemo/collections/speechlm2/models/__init__.py +++ b/nemo/collections/speechlm2/models/__init__.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .duplex_ear_tts import DuplexEARTTS from .duplex_s2s_model import DuplexS2SModel from .duplex_s2s_speech_decoder_model import DuplexS2SSpeechDecoderModel -from .duplex_ear_tts import DuplexEARTTS from .salm import SALM __all__ = [ diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 4dcf49593a9c..7389102d8b0c 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -31,24 +31,18 @@ "pretrained_ae_dir": None, "pretrained_tts_model": None, "scoring_asr": "stt_en_fastconformer_transducer_large", - "freeze_params": [ r"^audio_codec\..+$", # Keep audio codec frozen as it only provides supervision for training. 
r"^embed_tokens\..+$", # Keep embed_tokens frozen as done in eartts ], - "prevent_freeze_params": [], - "audio_save_path": "", - "inference_guidance_scale": 0.5, "inference_noise_scale": 0.8, "inference_top_p_or_k": 0.8, "inference_guidance_enabled": False, - "subword_mask_exactly_as_eartts": False, "context_hidden_mask_exactly_as_eartts": False, - "optimizer": { "_target_": "torch.optim.AdamW", "lr": 4e-5, @@ -56,14 +50,12 @@ "weight_decay": 0, "foreach": True, }, - "lr_scheduler": { "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", "warmup_steps": 2500, "min_lr": 1e-6, "max_steps": 100_000_000, }, - "codec_config": { "latent_size": 512, "n_fft": 16, @@ -78,14 +70,12 @@ "num_quantizers": 31, "wav_to_token_ratio": 1764, }, - "tts_config": { "use_gated_fusion_for_text_audio": True, "disable_eos_prediction": True, "use_bos_eos_emb": True, "use_subword_flag_emb": True, "num_delay_speech_tokens": 2, - "backbone_type": "gemma3_text", "backbone_model_class": None, "backbone_config_class": None, @@ -99,12 +89,10 @@ "attention_dropout": 0.1, "use_cache": False, }, - "latent_size": 512, "codebook_size": 1024, "num_quantizers": 31, "context_hidden_size": None, - "cas_config": { "pretrained_tokenizer_name": "nvidia/NVIDIA-Nemotron-Nano-9B-v2", "backbone_type": "t5gemma", @@ -124,7 +112,6 @@ }, }, }, - "mog_head_config": { "intermediate_size": 4608, "num_layers": 3, @@ -133,7 +120,6 @@ "min_log_std": -4.0, "eps": 1e-6, }, - "p_uncond": 0.1, "label_smoothing": 0.01, "max_training_rate": 0.8, @@ -142,7 +128,6 @@ "exponent": 3.0, }, }, - "trainer": { "devices": -1, "accelerator": "gpu", @@ -165,7 +150,6 @@ "find_unused_parameters": True, }, }, - "data": { "add_text_bos_and_eos_in_each_turn": True, "add_audio_prompt_after_description": True, @@ -176,7 +160,6 @@ "input_roles": ["user", "User"], "output_roles": ["agent", "Assistant", "assistant", "Agent"], }, - "exp_manager": { "exp_dir": None, "explicit_log_dir": "", @@ -323,18 +306,117 @@ def test_eartts_dataset(dataset, training_cutset_batch): assert batch["target_texts"] == ["hello okay"] # Token checks (same content as your old test) - assert batch["source_tokens"].tolist() == [[ 2, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 2, 1, 2, 12, 12, 12, 12, 1, 1662, 2, 12, - 12, 12, 12 - ]] + assert batch["source_tokens"].tolist() == [ + [ + 2, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 2, + 1, + 2, + 12, + 12, + 12, + 12, + 1, + 1662, + 2, + 12, + 12, + 12, + 12, + ] + ] - assert batch["input_text_tokens"].tolist() == [[ 2, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 2, 12, 12, 12, 12, 1, 2, 12, 12, 1, 1662, - 1417, 12, 12]] + assert batch["input_text_tokens"].tolist() == [ + [ + 2, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 2, + 12, + 12, + 12, + 12, + 1, + 2, + 12, + 12, + 1, + 1662, + 1417, + 12, + 12, + ] + ] # Check formatter assert batch["formatter"] == ["s2s_duplex"] From ea050e8cdf72b152f4f44e5b695619cedcc6da01 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 14 Nov 
2025 09:48:18 -0800 Subject: [PATCH 022/102] Add incremental decoding and unit test for it Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 1 + .../speechlm2/models/duplex_ear_tts.py | 197 ++++++++++++------ .../speechlm2/test_duplex_eartts.py | 26 ++- 3 files changed, 163 insertions(+), 61 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index f5158e3333fc..27d9db748712 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -297,6 +297,7 @@ def __getitem__(self, cuts: CutSet) -> dict: desc_plus_audio_prompt_lens = [] # for each sample in the batch for i in range(input_text_tokens.size(0)): + # ToDo: Consider remove the prompt description, given that NanoV2 does not support it and curently it is only a single eos text token desc_tokens_ids = self.generate_prompt_description(device=input_text_tokens[i].device).squeeze(0) if self.add_audio_prompt_after_description: prompt_audio_size = int( diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index bf22aa54b761..b736e09dd6d3 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -1229,7 +1229,7 @@ def validation_step(self, batch: dict, batch_idx: int): ): if not os.path.isfile(inference_speaker_reference): continue - print("Generating sample for speaker refernce:", inference_speaker_reference) + new_dataset_batch = copy.deepcopy(dataset_batch) # Get only the file name ref_name = os.path.basename(inference_speaker_reference) @@ -1243,6 +1243,7 @@ def validation_step(self, batch: dict, batch_idx: int): new_dataset_batch["speaker_reference_audio"] = speaker_audio new_dataset_batch["speaker_reference_audio_lens"] = speaker_audio_lens self.run_evaluation_one_batch(name, new_dataset_batch, use_dataloader_init=False) + # run inference for a custom speaker reference elif self.cfg.get("inference_speaker_reference", None): new_dataset_batch = copy.deepcopy(dataset_batch) @@ -1254,6 +1255,7 @@ def validation_step(self, batch: dict, batch_idx: int): new_dataset_batch["speaker_reference_audio"] = speaker_audio new_dataset_batch["speaker_reference_audio_lens"] = speaker_audio_lens self.run_evaluation_one_batch(name, new_dataset_batch, use_dataloader_init=False) + # run inference using dataloader speaker references else: self.run_evaluation_one_batch(name, dataset_batch, use_dataloader_init=False) @@ -1355,7 +1357,8 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, with fp32_precision(): prompt_audio_text_pad_size = int(prompt_audio_size // self.target_samples_per_frame) - # get description tokens + # get description tokens + # ToDo: Consider remove the prompt description, given that NanoV2 does not support it and curently it is only a single eos text token desc_tokens_ids = self.get_system_prompt(system_prompt=system_prompt, user_prompt=user_prompt) # create a padding tensor @@ -1490,31 +1493,123 @@ def get_init_inputs( return init_inputs + @torch.no_grad() + def infer_codes_one_step(self, current_subword_id, prev_subword_id, current_subword_mask, prev_audio_tokens, past_key_values, guidance_enabled=True, generation_config=None, ignore_eos_flag_stop=True): + + if self.cfg.tts_config.context_hidden_size is not None: + # get context_hidden_state it is always one step behind 
current_subword_id + # for the first step uses the last step from warmup + context_hidden_state = self.embed_tokens(prev_subword_id) + else: + context_hidden_state = None + + # force silence as next token + if self.cfg.get('inference_force_speech_silence_on_eos', True): + silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(prev_audio_tokens.shape) + prev_audio_tokens = torch.where( + current_subword_id.unsqueeze(-1) == self.text_eos_id, + silence_codes, # silence + prev_audio_tokens, # keep original + ) + + # get subword_ids + inputs = { + "code": prev_audio_tokens, + "context_hidden_state": context_hidden_state, + "subword_ids": current_subword_id, + "subword_mask": current_subword_mask, + "past_key_values": past_key_values, + "use_cache": True, + "guidance_enabled": guidance_enabled, + "generation_config": generation_config, + "ignore_eos_flag_stop": ignore_eos_flag_stop, + } + + outputs = self.tts_model(**inputs) + + return outputs["codes"], outputs["past_key_values"] + + @torch.no_grad() + def decode_one_audio_step(self, gen_audio_codes_history, number_prev_tokens=None): + with fp32_precision(), torch.no_grad(): + if number_prev_tokens: + gen_audio_codes_history = gen_audio_codes_history[:, -number_prev_tokens:] + + gen_audio_codes_history = replace_control_speech_codes(gen_audio_codes_history, self._control_codes, self.codec_silence_tokens) + gen_audio_codes_lens = torch.tensor([gen_audio_codes_history.size(1)]*gen_audio_codes_history.size(0), device=self.device) + audio_pred, audio_len = self.audio_codec.decode(gen_audio_codes_history, gen_audio_codes_lens) + + # return only the current/lastest audio chunk + audio_pred_cur_step = audio_pred.squeeze(1)[:, -self.audio_codec.config.wav_to_token_ratio:] + audio_len[:] = self.audio_codec.config.wav_to_token_ratio + return audio_pred_cur_step, audio_len + + @torch.no_grad() def offline_inference( self, next_subword_ids: torch.Tensor, + init_inputs: dict, formatter: str = "", guidance_enabled: bool = True, generation_config: dict = None, - init_inputs: dict = None, + incremental_audio_decoding: bool = False, ) -> dict[str, torch.Tensor]: """ - Autoregressive prediction. + Runs offline autoregressive inference for the Duplex EAR-TTS speech decoder. + + This method performs **text-to-speech (TTS)** generation: given subword/text + tokens and prompt-initialization states, it autoregressively generates + audio codec tokens and decodes them into a waveform. Args: - input_signal: a batch of waveforms with shape (B, T) with source sampling rate. - input_signal_lens: example lengths as number of samples of shape (B,). - decode_audio: bool, whether to decode audio codes to waveform. + next_subword_ids (torch.Tensor): + Conditioning subword/text token IDs for the speech decoder. + Shape: (B, T_text). + + init_inputs (dict): + Dictionary of prompt-dependent initial states produced by + ``get_init_inputs()``. May include: + + • "code" — initial audio tokens (e.g., prompt audio) + • "audio_mask" — mask for prompt audio positions + • "context_hidden_state" — decoder hidden state at t = 0 + • "subword_ids" — prompt text tokens + • "subword_mask" — mask for prompt text + • "non_prompt_mask" — mask marking positions to be generated + + ``get_init_inputs()`` automatically expands batch-1 buffers to + batch size B. + + formatter (str, optional): + Optional formatter identifier used to customize the prompt structure. + + guidance_enabled (bool, optional): + Whether classifier-free guidance (CFG) is enabled. 
+ If enabled and ``generation_config`` is ``None``, guidance parameters + are taken from ``_get_generation_config()``. + + generation_config (dict, optional): + Settings controlling autoregressive generation, including sampling + strategy, noise scale, refinement iterations, and EOS rules. + If ``None``, defaults are taken from + ``_get_generation_config(guidance_enabled)``. + + incremental_audio_decoding (bool, optional): + If True, codec-to-waveform decoding is performed incrementally during + autoregressive generation. + If False, waveform decoding occurs only after all audio tokens are produced. Returns: - A dict with keys: - * "text": generated text, de-tokenized to strings, properly skipping text_pad_id; list of length B. - * "tokens_text": generated text tokens of shape (B, T2). - * "tokens_audio": generated audio codes of shape (B, T2, K) where `K=num_codebooks`. - * "tokens_len" output lengths as number of tokens of shape (B,). - * "audio": generated waveform of shape (B, T3) (`decode_audio=True`). - * "audio_len" output lengths as number of waveform samples of shape (B,) (when `decode_audio=True`). + dict[str, torch.Tensor]: + Contains: + + • **"audio"**: + Generated waveform of shape ``(B, T_audio)``, obtained via + ``audio_pred.squeeze(1)``. + + • **"audio_len"**: + Length of each generated waveform in samples, shape ``(B,)``. """ B = next_subword_ids.size(0) @@ -1541,71 +1636,53 @@ def offline_inference( gen_audio_codes = torch.zeros( B, max_steps, self.tts_model.config.num_quantizers, device=self.device, dtype=torch.long ) - + # init subwork as all ones subword_mask = torch.ones(B, max_steps, device=self.device, dtype=torch.bool) # get first context subword_id, that is the last subword_ids from the warmup first_context_subword_id = init_inputs["subword_ids"][:, -1].unsqueeze(-1) - for i in range(max_steps - 1): + # initialize variables used to save the output audio + audio_pred = None + audio_pred_len = torch.zeros(B, device=self.device, dtype=torch.long) + + for i in range(max_steps): step_start = time.time() # current subword id is always seem current_subword_id = next_subword_ids[:, i].unsqueeze(-1) - - if self.cfg.tts_config.context_hidden_size is not None: - # get context_hidden_state it is always one step behind current_subword_id - # for the first step uses the last step from warmup - if i == 0: - context_subword_id = first_context_subword_id - else: - context_subword_id = next_subword_ids[:, i - 1].unsqueeze(-1) - - context_hidden_state = self.embed_tokens(context_subword_id) + + if i == 0: + prev_subword_id = first_context_subword_id else: - context_hidden_state = None + prev_subword_id = next_subword_ids[:, i - 1].unsqueeze(-1) # create subword_mask current_subword_mask = subword_mask[:, i].unsqueeze(-1) - # get subword_ids - inputs = { - "code": code, - "context_hidden_state": context_hidden_state, - "subword_ids": current_subword_id, - "subword_mask": current_subword_mask, - "past_key_values": past_key_values, - "use_cache": True, - "guidance_enabled": guidance_enabled, - "generation_config": generation_config, - "ignore_eos_flag_stop": True, - } + code, past_key_values = self.infer_codes_one_step(current_subword_id=current_subword_id, prev_subword_id=prev_subword_id, current_subword_mask=current_subword_mask, prev_audio_tokens=code, past_key_values=past_key_values, guidance_enabled=guidance_enabled, generation_config=generation_config, ignore_eos_flag_stop=True) - outputs = self.tts_model(**inputs) - - code = outputs["codes"] - past_key_values = 
outputs["past_key_values"] - # ToDo: check why it is -1 - gen_audio_codes[:, i - 1] = code.squeeze(1) - - # force silence as next token - if self.cfg.get('inference_force_speech_silence_on_eos', True): - silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(code.shape) - code = torch.where( - current_subword_id.unsqueeze(-1) == self.text_eos_id, - silence_codes, # silence - code, # keep original - ) + # cache audio tokens + gen_audio_codes[:, i] = code.squeeze(1) + if incremental_audio_decoding: + audio_pred_i, audio_pred_i_len = self.decode_one_audio_step(gen_audio_codes[:, :i+1], number_prev_tokens=self.cfg.get("inference_codec_decoding_prev_tokens_number", None)) + if audio_pred is None: + audio_pred = audio_pred_i + else: + audio_pred = torch.cat([audio_pred, audio_pred_i], dim=1) + audio_pred_len += audio_pred_i_len + step_time = time.time() - step_start logging.info(f"Autoregressive inference step: {i} of {max_steps} take around {step_time}s") - gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * gen_audio_codes.shape[0]).to(self.device) - # decode audio. Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety - gen_audio_codes = replace_control_speech_codes(gen_audio_codes, self._control_codes, self.codec_silence_tokens) - with fp32_precision(), torch.no_grad(): - audio_pred, audio_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) + if not incremental_audio_decoding: + gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * gen_audio_codes.shape[0]).to(self.device) + # decode audio. Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety + gen_audio_codes = replace_control_speech_codes(gen_audio_codes, self._control_codes, self.codec_silence_tokens) + with fp32_precision(), torch.no_grad(): + audio_pred, audio_pred_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) - return audio_pred.squeeze(1), audio_len + return audio_pred.squeeze(1), audio_pred_len def backward(self, *args, **kwargs): with loss_parallel(): diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 7389102d8b0c..10860fb988e2 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -443,6 +443,7 @@ def test_eartts_validation_step(model, dataset, training_cutset_batch): def test_eartts_offline_generation(model): + model.eval() # generate random subword_ids subword_ids = torch.ones(2, 10).long() @@ -452,11 +453,34 @@ def test_eartts_offline_generation(model): speaker_audio_lens=torch.tensor([22050]), ) init_inputs = model.get_init_inputs(B=subword_ids.size(0)) - model.eval() + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + gen_audio, gen_audio_len = model.offline_inference( next_subword_ids=subword_ids, init_inputs=init_inputs, ) + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + + gen_audio_inc, gen_audio_len_inc = model.offline_inference( + next_subword_ids=subword_ids, + init_inputs=init_inputs, + incremental_audio_decoding=True + ) + + assert torch.equal(gen_audio_len, gen_audio_len_inc), \ + "Audio lengths differ between incremental and non-incremental decoding." 
+ + # compare waveform + torch.testing.assert_close( + gen_audio, + gen_audio_inc, + atol=1e-1, + rtol=0, + ) assert gen_audio.shape == (2, 17640) assert gen_audio_len[0] == gen_audio.size(-1) From 814bc9eb7130c1d4efe499ee67ca9aef2d584381 Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 14 Nov 2025 17:49:11 +0000 Subject: [PATCH 023/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 73 +++++++++++++------ .../speechlm2/test_duplex_eartts.py | 11 ++- 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index b736e09dd6d3..a295f0551927 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -1357,7 +1357,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, with fp32_precision(): prompt_audio_text_pad_size = int(prompt_audio_size // self.target_samples_per_frame) - # get description tokens + # get description tokens # ToDo: Consider remove the prompt description, given that NanoV2 does not support it and curently it is only a single eos text token desc_tokens_ids = self.get_system_prompt(system_prompt=system_prompt, user_prompt=user_prompt) @@ -1494,7 +1494,17 @@ def get_init_inputs( return init_inputs @torch.no_grad() - def infer_codes_one_step(self, current_subword_id, prev_subword_id, current_subword_mask, prev_audio_tokens, past_key_values, guidance_enabled=True, generation_config=None, ignore_eos_flag_stop=True): + def infer_codes_one_step( + self, + current_subword_id, + prev_subword_id, + current_subword_mask, + prev_audio_tokens, + past_key_values, + guidance_enabled=True, + generation_config=None, + ignore_eos_flag_stop=True, + ): if self.cfg.tts_config.context_hidden_size is not None: # get context_hidden_state it is always one step behind current_subword_id @@ -1534,17 +1544,20 @@ def decode_one_audio_step(self, gen_audio_codes_history, number_prev_tokens=None with fp32_precision(), torch.no_grad(): if number_prev_tokens: gen_audio_codes_history = gen_audio_codes_history[:, -number_prev_tokens:] - - gen_audio_codes_history = replace_control_speech_codes(gen_audio_codes_history, self._control_codes, self.codec_silence_tokens) - gen_audio_codes_lens = torch.tensor([gen_audio_codes_history.size(1)]*gen_audio_codes_history.size(0), device=self.device) + + gen_audio_codes_history = replace_control_speech_codes( + gen_audio_codes_history, self._control_codes, self.codec_silence_tokens + ) + gen_audio_codes_lens = torch.tensor( + [gen_audio_codes_history.size(1)] * gen_audio_codes_history.size(0), device=self.device + ) audio_pred, audio_len = self.audio_codec.decode(gen_audio_codes_history, gen_audio_codes_lens) # return only the current/lastest audio chunk - audio_pred_cur_step = audio_pred.squeeze(1)[:, -self.audio_codec.config.wav_to_token_ratio:] + audio_pred_cur_step = audio_pred.squeeze(1)[:, -self.audio_codec.config.wav_to_token_ratio :] audio_len[:] = self.audio_codec.config.wav_to_token_ratio return audio_pred_cur_step, audio_len - @torch.no_grad() def offline_inference( self, @@ -1571,12 +1584,12 @@ def offline_inference( Dictionary of prompt-dependent initial states produced by ``get_init_inputs()``. 
May include: - • "code" — initial audio tokens (e.g., prompt audio) - • "audio_mask" — mask for prompt audio positions - • "context_hidden_state" — decoder hidden state at t = 0 - • "subword_ids" — prompt text tokens - • "subword_mask" — mask for prompt text - • "non_prompt_mask" — mask marking positions to be generated + • "code" — initial audio tokens (e.g., prompt audio) + • "audio_mask" — mask for prompt audio positions + • "context_hidden_state" — decoder hidden state at t = 0 + • "subword_ids" — prompt text tokens + • "subword_mask" — mask for prompt text + • "non_prompt_mask" — mask marking positions to be generated ``get_init_inputs()`` automatically expands batch-1 buffers to batch size B. @@ -1597,18 +1610,18 @@ def offline_inference( incremental_audio_decoding (bool, optional): If True, codec-to-waveform decoding is performed incrementally during - autoregressive generation. + autoregressive generation. If False, waveform decoding occurs only after all audio tokens are produced. Returns: dict[str, torch.Tensor]: Contains: - • **"audio"**: + • **"audio"**: Generated waveform of shape ``(B, T_audio)``, obtained via ``audio_pred.squeeze(1)``. - • **"audio_len"**: + • **"audio_len"**: Length of each generated waveform in samples, shape ``(B,)``. """ B = next_subword_ids.size(0) @@ -1636,7 +1649,7 @@ def offline_inference( gen_audio_codes = torch.zeros( B, max_steps, self.tts_model.config.num_quantizers, device=self.device, dtype=torch.long ) - + # init subwork as all ones subword_mask = torch.ones(B, max_steps, device=self.device, dtype=torch.bool) # get first context subword_id, that is the last subword_ids from the warmup @@ -1650,7 +1663,7 @@ def offline_inference( step_start = time.time() # current subword id is always seem current_subword_id = next_subword_ids[:, i].unsqueeze(-1) - + if i == 0: prev_subword_id = first_context_subword_id else: @@ -1659,26 +1672,40 @@ def offline_inference( # create subword_mask current_subword_mask = subword_mask[:, i].unsqueeze(-1) - code, past_key_values = self.infer_codes_one_step(current_subword_id=current_subword_id, prev_subword_id=prev_subword_id, current_subword_mask=current_subword_mask, prev_audio_tokens=code, past_key_values=past_key_values, guidance_enabled=guidance_enabled, generation_config=generation_config, ignore_eos_flag_stop=True) + code, past_key_values = self.infer_codes_one_step( + current_subword_id=current_subword_id, + prev_subword_id=prev_subword_id, + current_subword_mask=current_subword_mask, + prev_audio_tokens=code, + past_key_values=past_key_values, + guidance_enabled=guidance_enabled, + generation_config=generation_config, + ignore_eos_flag_stop=True, + ) # cache audio tokens gen_audio_codes[:, i] = code.squeeze(1) if incremental_audio_decoding: - audio_pred_i, audio_pred_i_len = self.decode_one_audio_step(gen_audio_codes[:, :i+1], number_prev_tokens=self.cfg.get("inference_codec_decoding_prev_tokens_number", None)) + audio_pred_i, audio_pred_i_len = self.decode_one_audio_step( + gen_audio_codes[:, : i + 1], + number_prev_tokens=self.cfg.get("inference_codec_decoding_prev_tokens_number", None), + ) if audio_pred is None: audio_pred = audio_pred_i else: audio_pred = torch.cat([audio_pred, audio_pred_i], dim=1) - audio_pred_len += audio_pred_i_len - + audio_pred_len += audio_pred_i_len + step_time = time.time() - step_start logging.info(f"Autoregressive inference step: {i} of {max_steps} take around {step_time}s") if not incremental_audio_decoding: gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * 
gen_audio_codes.shape[0]).to(self.device) # decode audio. Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety - gen_audio_codes = replace_control_speech_codes(gen_audio_codes, self._control_codes, self.codec_silence_tokens) + gen_audio_codes = replace_control_speech_codes( + gen_audio_codes, self._control_codes, self.codec_silence_tokens + ) with fp32_precision(), torch.no_grad(): audio_pred, audio_pred_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 10860fb988e2..7f8ca884d06f 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -466,19 +466,18 @@ def test_eartts_offline_generation(model): torch.cuda.manual_seed_all(42) gen_audio_inc, gen_audio_len_inc = model.offline_inference( - next_subword_ids=subword_ids, - init_inputs=init_inputs, - incremental_audio_decoding=True + next_subword_ids=subword_ids, init_inputs=init_inputs, incremental_audio_decoding=True ) - assert torch.equal(gen_audio_len, gen_audio_len_inc), \ - "Audio lengths differ between incremental and non-incremental decoding." + assert torch.equal( + gen_audio_len, gen_audio_len_inc + ), "Audio lengths differ between incremental and non-incremental decoding." # compare waveform torch.testing.assert_close( gen_audio, gen_audio_inc, - atol=1e-1, + atol=1e-1, rtol=0, ) From 5c4f33b5292c863245026388a7ef2a27a3d952bf Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 14 Nov 2025 10:19:41 -0800 Subject: [PATCH 024/102] Add option to run codec in bf16 to speedup Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 38 ++++++++++--------- .../speechlm2/test_duplex_eartts.py | 1 + 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index a295f0551927..ceb0432f4a0a 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -65,16 +65,16 @@ def maybe_to(x, dtype): @contextmanager -def ensures_16_precision(mixed_dtype): +def ensures_target_precision(target_dtype): """ Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting. In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models. This context manager restores default float32 precision and runs the computation in float32 autocast context. """ default_dtype = torch.get_default_dtype() - torch.set_default_dtype(mixed_dtype) + torch.set_default_dtype(target_dtype) try: - with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=mixed_dtype): + with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=target_dtype): yield finally: torch.set_default_dtype(default_dtype) @@ -175,7 +175,7 @@ def new_forward(*args, **kwargs): for k, v in kwargs.items() } # with torch.cuda.amp.autocast(enabled=True, dtype=mixed_dtype): - with ensures_16_precision(mixed_dtype): + with ensures_target_precision(mixed_dtype): return module._original_forward(*new_args, **new_kwargs) module.forward = new_forward @@ -315,10 +315,10 @@ def setup_rvq_audio_codec(model): Includes a workaround for PTL auto-downcasting the codec model to bf16 with bf16-true precision. 
""" - if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == torch.float: + if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype: return # skip if already set up and has the right dtype - with fp32_precision(): + with ensures_target_precision(model.audio_codec_run_dtype): if model.cfg.get("pretrained_ae_dir", None): model.audio_codec = ( RVQVAEModel.from_pretrained( @@ -448,6 +448,8 @@ def __init__(self, cfg: dict) -> None: # delete llm because we use it only to get the embbeding tokens del self.language_model + self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "bfloat16"), torch.float32) + # instanciate eartts model and codec self._load_tts_model(self.cfg) self._codebook_size = self.tts_model.config.codebook_size @@ -495,7 +497,7 @@ def get_codec_silence_frame_last_one(self): audio_len = torch.tensor([audio.size(-1)]).long() audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len) return sil_codes[0, -1] @@ -507,7 +509,7 @@ def get_codec_silence_frame(self): audio_len = torch.tensor([audio.size(-1)]).long() audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1), audio_len) # [1, T, C] sil_codes = sil_codes[0] # [T, C] @@ -692,10 +694,10 @@ def prepare_inputs(self, batch: dict): aligned_position_ids = batch["aligned_position_ids"] # extract target audio codes - with fp32_precision(), torch.no_grad(): - target_audio, target_audio_lens = self.pad_audio_to_factor( - target_audio, target_audio_lens, self.target_samples_per_frame, 1 - ) + target_audio, target_audio_lens = self.pad_audio_to_factor( + target_audio, target_audio_lens, self.target_samples_per_frame, 1 + ) + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens) # ToDo: consider use the source audio @@ -707,8 +709,8 @@ def prepare_inputs(self, batch: dict): source_audio_lens = (source_audio_lens * (self.target_sample_rate/self.source_sample_rate)).to(lengths.dtype) # ToDo: Add a transformer encoder to help the model to better extract contextual information, replace the code bellow with it # extract embedding for context audios - with fp32_precision(), torch.no_grad(): - source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1) + + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): source_codes, source_codes_lens = self.audio_codec.encode( source_audio.unsqueeze(1), source_audio_lens ) @@ -915,7 +917,7 @@ def get_teacher_force_inference_audio(self, batch, guidance_enabled=True): tf_audio_codes_pred = replace_control_speech_codes( tf_audio_codes_pred, self._control_codes, self.codec_silence_tokens ) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): audio_pred, audio_len = self.audio_codec.decode(tf_audio_codes_pred, inputs["output_lens"]) return audio_pred.squeeze(1), audio_len @@ -1395,7 
+1397,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, target_audio_len = torch.tensor( [target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device ) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len) # get context hidden @@ -1541,7 +1543,7 @@ def infer_codes_one_step( @torch.no_grad() def decode_one_audio_step(self, gen_audio_codes_history, number_prev_tokens=None): - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): if number_prev_tokens: gen_audio_codes_history = gen_audio_codes_history[:, -number_prev_tokens:] @@ -1706,7 +1708,7 @@ def offline_inference( gen_audio_codes = replace_control_speech_codes( gen_audio_codes, self._control_codes, self.codec_silence_tokens ) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): audio_pred, audio_pred_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) return audio_pred.squeeze(1), audio_pred_len diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 7f8ca884d06f..511d7a021951 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -35,6 +35,7 @@ r"^audio_codec\..+$", # Keep audio codec frozen as it only provides supervision for training. r"^embed_tokens\..+$", # Keep embed_tokens frozen as done in eartts ], + "audio_codec_run_dtype": "float32", "prevent_freeze_params": [], "audio_save_path": "", "inference_guidance_scale": 0.5, From e1cf14494b745430534c4cfc5983d2539cc104ad Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 14 Nov 2025 11:55:21 -0800 Subject: [PATCH 025/102] Add docs Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 210 +++++++++++------- 1 file changed, 129 insertions(+), 81 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index ceb0432f4a0a..2b936e75ee72 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -65,16 +65,16 @@ def maybe_to(x, dtype): @contextmanager -def ensures_target_precision(target_dtype): +def ensures_16_precision(mixed_dtype): """ Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting. In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models. This context manager restores default float32 precision and runs the computation in float32 autocast context. 
""" default_dtype = torch.get_default_dtype() - torch.set_default_dtype(target_dtype) + torch.set_default_dtype(mixed_dtype) try: - with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=target_dtype): + with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=mixed_dtype): yield finally: torch.set_default_dtype(default_dtype) @@ -175,7 +175,7 @@ def new_forward(*args, **kwargs): for k, v in kwargs.items() } # with torch.cuda.amp.autocast(enabled=True, dtype=mixed_dtype): - with ensures_target_precision(mixed_dtype): + with ensures_16_precision(mixed_dtype): return module._original_forward(*new_args, **new_kwargs) module.forward = new_forward @@ -315,10 +315,10 @@ def setup_rvq_audio_codec(model): Includes a workaround for PTL auto-downcasting the codec model to bf16 with bf16-true precision. """ - if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype: + if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == torch.float: return # skip if already set up and has the right dtype - with ensures_target_precision(model.audio_codec_run_dtype): + with fp32_precision(): if model.cfg.get("pretrained_ae_dir", None): model.audio_codec = ( RVQVAEModel.from_pretrained( @@ -448,8 +448,6 @@ def __init__(self, cfg: dict) -> None: # delete llm because we use it only to get the embbeding tokens del self.language_model - self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "bfloat16"), torch.float32) - # instanciate eartts model and codec self._load_tts_model(self.cfg) self._codebook_size = self.tts_model.config.codebook_size @@ -497,7 +495,7 @@ def get_codec_silence_frame_last_one(self): audio_len = torch.tensor([audio.size(-1)]).long() audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + with fp32_precision(), torch.no_grad(): sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len) return sil_codes[0, -1] @@ -509,7 +507,7 @@ def get_codec_silence_frame(self): audio_len = torch.tensor([audio.size(-1)]).long() audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + with fp32_precision(), torch.no_grad(): sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1), audio_len) # [1, T, C] sil_codes = sil_codes[0] # [T, C] @@ -564,6 +562,15 @@ def _load_language_model(self, cfg): return language_model def init_model_from_another_checkpoint(self, checkpoint_path): + """ + Loads model weights and config from another checkpoint file, supporting .nemo and PyTorch formats. + + Args: + checkpoint_path (str): Path to checkpoint file. + + Returns: + None. The model is updated in-place. 
+ """ if checkpoint_path is not None: if '.nemo' in checkpoint_path: with tempfile.TemporaryDirectory() as tmpdir: @@ -623,16 +630,6 @@ def text_vocab_size(self): def text_bos_id(self) -> int: return self.tokenizer.bos_id - @property - def text_zstts_task_id(self) -> int: - return self.tokenizer.text_to_ids("<|box_start|>") # uses <|box_start|> special token as zstts task id token - - @property - def text_cont_task_id(self) -> int: - return self.tokenizer.text_to_ids( - "<|object_ref_start|>" - ) # uses <|object_ref_start|> special token as cont task id token - @property def text_eos_id(self) -> int: return self.tokenizer.eos_id @@ -680,7 +677,9 @@ def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, downsampling_ return padded_audio, padded_len def prepare_inputs(self, batch: dict): - """ """ + """ + Prepare inputs, extracting audio tokens and padding if needed. + """ # check if audios has the same batch size assert batch["source_audio"].size(0) == batch["target_audio"].size(0) assert batch["speaker_reference_audio"].size(0) == batch["target_audio"].size(0) @@ -694,10 +693,10 @@ def prepare_inputs(self, batch: dict): aligned_position_ids = batch["aligned_position_ids"] # extract target audio codes - target_audio, target_audio_lens = self.pad_audio_to_factor( - target_audio, target_audio_lens, self.target_samples_per_frame, 1 - ) - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + with fp32_precision(), torch.no_grad(): + target_audio, target_audio_lens = self.pad_audio_to_factor( + target_audio, target_audio_lens, self.target_samples_per_frame, 1 + ) target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens) # ToDo: consider use the source audio @@ -709,8 +708,8 @@ def prepare_inputs(self, batch: dict): source_audio_lens = (source_audio_lens * (self.target_sample_rate/self.source_sample_rate)).to(lengths.dtype) # ToDo: Add a transformer encoder to help the model to better extract contextual information, replace the code bellow with it # extract embedding for context audios - - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + with fp32_precision(), torch.no_grad(): + source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1) source_codes, source_codes_lens = self.audio_codec.encode( source_audio.unsqueeze(1), source_audio_lens ) @@ -917,7 +916,7 @@ def get_teacher_force_inference_audio(self, batch, guidance_enabled=True): tf_audio_codes_pred = replace_control_speech_codes( tf_audio_codes_pred, self._control_codes, self.codec_silence_tokens ) - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + with fp32_precision(), torch.no_grad(): audio_pred, audio_len = self.audio_codec.decode(tf_audio_codes_pred, inputs["output_lens"]) return audio_pred.squeeze(1), audio_len @@ -990,6 +989,17 @@ def offline_inference_with_custom_sentences( return audio, audio_len, speaker_audio, speaker_audio_lens def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=False): + """ + Runs evaluation and scoring for a single data batch, logging metrics and updating result buffers. + + Args: + name (str): Name/id for the batch (for logging). + dataset_batch (dict): Batch of data inputs, supports batched text/audio/etc. + use_dataloader_init (bool, optional): If True, use dataloader initialization for prompts. + + Returns: + None. Outputs are logged and stored in result buffers. 
+ """ results = {} inputs = self.prepare_inputs(dataset_batch) @@ -1272,6 +1282,16 @@ def test_step(self, *args, **kwargs): return self.validation_step(*args, **kwargs) def get_system_prompt(self, system_prompt=None, user_prompt=None): + """ + Constructs a prompt message pair (system, user, assistant) formatted for chat inference. + + Args: + system_prompt (str, optional): System message describing conversational policy. + user_prompt (str, optional): User message/content. + + Returns: + torch.Tensor: Tokenized prompt IDs, shape (1, T). + """ messages = [] if system_prompt is None: system_prompt = ( @@ -1317,6 +1337,18 @@ def get_system_prompt(self, system_prompt=None, user_prompt=None): return input_ids def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, user_prompt=None): + """ + Registers and prepares initial input buffers for text/audio prompt and context, to warm up AR inference. + + Args: + speaker_audio (torch.Tensor): Batch of prompt audio, (B, T). + speaker_audio_lens (torch.Tensor): Lengths for each sample in speaker_audio, (B,). + system_prompt (str, optional): System prompt for context. + user_prompt (str, optional): User message for context. + + Returns: + dict: Dictionary of input tensors to be passed to inference, with registered buffers. + """ # compute prompt audio size and slice it with fp32_precision(): # compute the exact number of samples for the prompt duration @@ -1359,7 +1391,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, with fp32_precision(): prompt_audio_text_pad_size = int(prompt_audio_size // self.target_samples_per_frame) - # get description tokens + # get description tokens # ToDo: Consider remove the prompt description, given that NanoV2 does not support it and curently it is only a single eos text token desc_tokens_ids = self.get_system_prompt(system_prompt=system_prompt, user_prompt=user_prompt) @@ -1397,7 +1429,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, target_audio_len = torch.tensor( [target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device ) - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + with fp32_precision(), torch.no_grad(): code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len) # get context hidden @@ -1467,6 +1499,19 @@ def get_init_inputs( "non_prompt_mask", ], ): + """ + Returns a dictionary of initial inputs for inference, using registered buffers. + + Args: + B (int): Required batch size. + init_inputs_names (List[str], optional): Names of input buffers to fetch. + + Returns: + dict: Each key is name from init_inputs_names, and value is tensor of appropriate shape (B, ...). + + Notes: + Expands batch-1 buffers to B if necessary. + """ if init_inputs_names is None: init_inputs_names = [ "code", @@ -1496,17 +1541,25 @@ def get_init_inputs( return init_inputs @torch.no_grad() - def infer_codes_one_step( - self, - current_subword_id, - prev_subword_id, - current_subword_mask, - prev_audio_tokens, - past_key_values, - guidance_enabled=True, - generation_config=None, - ignore_eos_flag_stop=True, - ): + def infer_codes_one_step(self, current_subword_id, prev_subword_id, current_subword_mask, prev_audio_tokens, past_key_values, guidance_enabled=True, generation_config=None, ignore_eos_flag_stop=True): + """ + Runs a single autoregressive prediction step to infer audio codec codes. 
+ + Args: + current_subword_id (torch.Tensor): Current text token IDs, shape (B, 1). + prev_subword_id (torch.Tensor): Previous text token IDs, shape (B, 1). + current_subword_mask (torch.Tensor): Current mask, shape (B, 1). + prev_audio_tokens (torch.Tensor): Previously generated audio tokens, shape (B, 1, C). + past_key_values: Key-value cache for transformer decoder state. + guidance_enabled (bool, optional): Enables classifier-free guidance. + generation_config (dict, optional): Generation hyperparameters. + ignore_eos_flag_stop (bool): If True, ignore EOS flag for stopping. + + Returns: + Tuple[torch.Tensor, Any]: + - Predicted audio codec token(s), shape (B, 1, C) + - Updated past_key_values for the next step. + """ if self.cfg.tts_config.context_hidden_size is not None: # get context_hidden_state it is always one step behind current_subword_id @@ -1543,23 +1596,32 @@ def infer_codes_one_step( @torch.no_grad() def decode_one_audio_step(self, gen_audio_codes_history, number_prev_tokens=None): - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + """ + Decodes one step of generated audio codec tokens to raw waveform. + + Args: + gen_audio_codes_history (torch.Tensor): Audio tokens history, shape (B, T, C). + number_prev_tokens (int, optional): Number of previous tokens to decode, for incremental decoding. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - audio_pred_cur_step: Latest decoded waveform chunk, shape (B, wav_to_token_ratio). + - audio_len: Lengths (number of samples), shape (B,). + """ + with fp32_precision(), torch.no_grad(): if number_prev_tokens: gen_audio_codes_history = gen_audio_codes_history[:, -number_prev_tokens:] - - gen_audio_codes_history = replace_control_speech_codes( - gen_audio_codes_history, self._control_codes, self.codec_silence_tokens - ) - gen_audio_codes_lens = torch.tensor( - [gen_audio_codes_history.size(1)] * gen_audio_codes_history.size(0), device=self.device - ) + + gen_audio_codes_history = replace_control_speech_codes(gen_audio_codes_history, self._control_codes, self.codec_silence_tokens) + gen_audio_codes_lens = torch.tensor([gen_audio_codes_history.size(1)]*gen_audio_codes_history.size(0), device=self.device) audio_pred, audio_len = self.audio_codec.decode(gen_audio_codes_history, gen_audio_codes_lens) # return only the current/lastest audio chunk - audio_pred_cur_step = audio_pred.squeeze(1)[:, -self.audio_codec.config.wav_to_token_ratio :] + audio_pred_cur_step = audio_pred.squeeze(1)[:, -self.audio_codec.config.wav_to_token_ratio:] audio_len[:] = self.audio_codec.config.wav_to_token_ratio return audio_pred_cur_step, audio_len + @torch.no_grad() def offline_inference( self, @@ -1586,12 +1648,12 @@ def offline_inference( Dictionary of prompt-dependent initial states produced by ``get_init_inputs()``. May include: - • "code" — initial audio tokens (e.g., prompt audio) - • "audio_mask" — mask for prompt audio positions - • "context_hidden_state" — decoder hidden state at t = 0 - • "subword_ids" — prompt text tokens - • "subword_mask" — mask for prompt text - • "non_prompt_mask" — mask marking positions to be generated + • "code" — initial audio tokens (e.g., prompt audio) + • "audio_mask" — mask for prompt audio positions + • "context_hidden_state" — decoder hidden state at t = 0 + • "subword_ids" — prompt text tokens + • "subword_mask" — mask for prompt text + • "non_prompt_mask" — mask marking positions to be generated ``get_init_inputs()`` automatically expands batch-1 buffers to batch size B. 
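The buffer-expansion behavior mentioned in the docstring above is easy to picture in isolation. The helper below is a minimal, illustrative sketch of what "expands batch-1 buffers to batch size B" amounts to; the function name is hypothetical and this is not the code added in this patch.

    import torch

    def expand_init_buffers(init_inputs: dict, B: int) -> dict:
        # Broadcast any batch-1 tensor buffer to batch size B; anything else
        # (already batch-B tensors, None values, scalars) passes through unchanged.
        expanded = {}
        for name, value in init_inputs.items():
            if torch.is_tensor(value) and value.dim() >= 1 and value.size(0) == 1 and B > 1:
                value = value.expand(B, *value.shape[1:]).contiguous()
            expanded[name] = value
        return expanded

With that behavior in mind, ``offline_inference`` can be driven with any batch size even though the prompt buffers registered by ``set_init_inputs`` were created with batch 1.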
@@ -1612,18 +1674,18 @@ def offline_inference( incremental_audio_decoding (bool, optional): If True, codec-to-waveform decoding is performed incrementally during - autoregressive generation. + autoregressive generation. If False, waveform decoding occurs only after all audio tokens are produced. Returns: dict[str, torch.Tensor]: Contains: - • **"audio"**: + • **"audio"**: Generated waveform of shape ``(B, T_audio)``, obtained via ``audio_pred.squeeze(1)``. - • **"audio_len"**: + • **"audio_len"**: Length of each generated waveform in samples, shape ``(B,)``. """ B = next_subword_ids.size(0) @@ -1651,7 +1713,7 @@ def offline_inference( gen_audio_codes = torch.zeros( B, max_steps, self.tts_model.config.num_quantizers, device=self.device, dtype=torch.long ) - + # init subwork as all ones subword_mask = torch.ones(B, max_steps, device=self.device, dtype=torch.bool) # get first context subword_id, that is the last subword_ids from the warmup @@ -1665,7 +1727,7 @@ def offline_inference( step_start = time.time() # current subword id is always seem current_subword_id = next_subword_ids[:, i].unsqueeze(-1) - + if i == 0: prev_subword_id = first_context_subword_id else: @@ -1674,41 +1736,27 @@ def offline_inference( # create subword_mask current_subword_mask = subword_mask[:, i].unsqueeze(-1) - code, past_key_values = self.infer_codes_one_step( - current_subword_id=current_subword_id, - prev_subword_id=prev_subword_id, - current_subword_mask=current_subword_mask, - prev_audio_tokens=code, - past_key_values=past_key_values, - guidance_enabled=guidance_enabled, - generation_config=generation_config, - ignore_eos_flag_stop=True, - ) + code, past_key_values = self.infer_codes_one_step(current_subword_id=current_subword_id, prev_subword_id=prev_subword_id, current_subword_mask=current_subword_mask, prev_audio_tokens=code, past_key_values=past_key_values, guidance_enabled=guidance_enabled, generation_config=generation_config, ignore_eos_flag_stop=True) # cache audio tokens gen_audio_codes[:, i] = code.squeeze(1) if incremental_audio_decoding: - audio_pred_i, audio_pred_i_len = self.decode_one_audio_step( - gen_audio_codes[:, : i + 1], - number_prev_tokens=self.cfg.get("inference_codec_decoding_prev_tokens_number", None), - ) + audio_pred_i, audio_pred_i_len = self.decode_one_audio_step(gen_audio_codes[:, :i+1], number_prev_tokens=self.cfg.get("inference_codec_decoding_prev_tokens_number", None)) if audio_pred is None: audio_pred = audio_pred_i else: audio_pred = torch.cat([audio_pred, audio_pred_i], dim=1) - audio_pred_len += audio_pred_i_len - + audio_pred_len += audio_pred_i_len + step_time = time.time() - step_start logging.info(f"Autoregressive inference step: {i} of {max_steps} take around {step_time}s") if not incremental_audio_decoding: gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * gen_audio_codes.shape[0]).to(self.device) # decode audio. 
Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety - gen_audio_codes = replace_control_speech_codes( - gen_audio_codes, self._control_codes, self.codec_silence_tokens - ) - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + gen_audio_codes = replace_control_speech_codes(gen_audio_codes, self._control_codes, self.codec_silence_tokens) + with fp32_precision(), torch.no_grad(): audio_pred, audio_pred_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) return audio_pred.squeeze(1), audio_pred_len From b30bbdb24c4e8d57e7fa86b994f88b92e39da364 Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 14 Nov 2025 19:56:11 +0000 Subject: [PATCH 026/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 231 ++++++++++-------- 1 file changed, 129 insertions(+), 102 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 2b936e75ee72..a4630f55758f 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -563,14 +563,14 @@ def _load_language_model(self, cfg): def init_model_from_another_checkpoint(self, checkpoint_path): """ - Loads model weights and config from another checkpoint file, supporting .nemo and PyTorch formats. - - Args: - checkpoint_path (str): Path to checkpoint file. - - Returns: - None. The model is updated in-place. - """ + Loads model weights and config from another checkpoint file, supporting .nemo and PyTorch formats. + + Args: + checkpoint_path (str): Path to checkpoint file. + + Returns: + None. The model is updated in-place. + """ if checkpoint_path is not None: if '.nemo' in checkpoint_path: with tempfile.TemporaryDirectory() as tmpdir: @@ -677,7 +677,7 @@ def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, downsampling_ return padded_audio, padded_len def prepare_inputs(self, batch: dict): - """ + """ Prepare inputs, extracting audio tokens and padding if needed. """ # check if audios has the same batch size @@ -990,16 +990,16 @@ def offline_inference_with_custom_sentences( def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=False): """ - Runs evaluation and scoring for a single data batch, logging metrics and updating result buffers. - - Args: - name (str): Name/id for the batch (for logging). - dataset_batch (dict): Batch of data inputs, supports batched text/audio/etc. - use_dataloader_init (bool, optional): If True, use dataloader initialization for prompts. - - Returns: - None. Outputs are logged and stored in result buffers. - """ + Runs evaluation and scoring for a single data batch, logging metrics and updating result buffers. + + Args: + name (str): Name/id for the batch (for logging). + dataset_batch (dict): Batch of data inputs, supports batched text/audio/etc. + use_dataloader_init (bool, optional): If True, use dataloader initialization for prompts. + + Returns: + None. Outputs are logged and stored in result buffers. + """ results = {} inputs = self.prepare_inputs(dataset_batch) @@ -1283,15 +1283,15 @@ def test_step(self, *args, **kwargs): def get_system_prompt(self, system_prompt=None, user_prompt=None): """ - Constructs a prompt message pair (system, user, assistant) formatted for chat inference. 
- - Args: - system_prompt (str, optional): System message describing conversational policy. - user_prompt (str, optional): User message/content. - - Returns: - torch.Tensor: Tokenized prompt IDs, shape (1, T). - """ + Constructs a prompt message pair (system, user, assistant) formatted for chat inference. + + Args: + system_prompt (str, optional): System message describing conversational policy. + user_prompt (str, optional): User message/content. + + Returns: + torch.Tensor: Tokenized prompt IDs, shape (1, T). + """ messages = [] if system_prompt is None: system_prompt = ( @@ -1338,17 +1338,17 @@ def get_system_prompt(self, system_prompt=None, user_prompt=None): def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, user_prompt=None): """ - Registers and prepares initial input buffers for text/audio prompt and context, to warm up AR inference. - - Args: - speaker_audio (torch.Tensor): Batch of prompt audio, (B, T). - speaker_audio_lens (torch.Tensor): Lengths for each sample in speaker_audio, (B,). - system_prompt (str, optional): System prompt for context. - user_prompt (str, optional): User message for context. - - Returns: - dict: Dictionary of input tensors to be passed to inference, with registered buffers. - """ + Registers and prepares initial input buffers for text/audio prompt and context, to warm up AR inference. + + Args: + speaker_audio (torch.Tensor): Batch of prompt audio, (B, T). + speaker_audio_lens (torch.Tensor): Lengths for each sample in speaker_audio, (B,). + system_prompt (str, optional): System prompt for context. + user_prompt (str, optional): User message for context. + + Returns: + dict: Dictionary of input tensors to be passed to inference, with registered buffers. + """ # compute prompt audio size and slice it with fp32_precision(): # compute the exact number of samples for the prompt duration @@ -1391,7 +1391,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, with fp32_precision(): prompt_audio_text_pad_size = int(prompt_audio_size // self.target_samples_per_frame) - # get description tokens + # get description tokens # ToDo: Consider remove the prompt description, given that NanoV2 does not support it and curently it is only a single eos text token desc_tokens_ids = self.get_system_prompt(system_prompt=system_prompt, user_prompt=user_prompt) @@ -1500,18 +1500,18 @@ def get_init_inputs( ], ): """ - Returns a dictionary of initial inputs for inference, using registered buffers. - - Args: - B (int): Required batch size. - init_inputs_names (List[str], optional): Names of input buffers to fetch. - - Returns: - dict: Each key is name from init_inputs_names, and value is tensor of appropriate shape (B, ...). - - Notes: - Expands batch-1 buffers to B if necessary. - """ + Returns a dictionary of initial inputs for inference, using registered buffers. + + Args: + B (int): Required batch size. + init_inputs_names (List[str], optional): Names of input buffers to fetch. + + Returns: + dict: Each key is name from init_inputs_names, and value is tensor of appropriate shape (B, ...). + + Notes: + Expands batch-1 buffers to B if necessary. 
+ """ if init_inputs_names is None: init_inputs_names = [ "code", @@ -1541,25 +1541,35 @@ def get_init_inputs( return init_inputs @torch.no_grad() - def infer_codes_one_step(self, current_subword_id, prev_subword_id, current_subword_mask, prev_audio_tokens, past_key_values, guidance_enabled=True, generation_config=None, ignore_eos_flag_stop=True): + def infer_codes_one_step( + self, + current_subword_id, + prev_subword_id, + current_subword_mask, + prev_audio_tokens, + past_key_values, + guidance_enabled=True, + generation_config=None, + ignore_eos_flag_stop=True, + ): + """ + Runs a single autoregressive prediction step to infer audio codec codes. + + Args: + current_subword_id (torch.Tensor): Current text token IDs, shape (B, 1). + prev_subword_id (torch.Tensor): Previous text token IDs, shape (B, 1). + current_subword_mask (torch.Tensor): Current mask, shape (B, 1). + prev_audio_tokens (torch.Tensor): Previously generated audio tokens, shape (B, 1, C). + past_key_values: Key-value cache for transformer decoder state. + guidance_enabled (bool, optional): Enables classifier-free guidance. + generation_config (dict, optional): Generation hyperparameters. + ignore_eos_flag_stop (bool): If True, ignore EOS flag for stopping. + + Returns: + Tuple[torch.Tensor, Any]: + - Predicted audio codec token(s), shape (B, 1, C) + - Updated past_key_values for the next step. """ - Runs a single autoregressive prediction step to infer audio codec codes. - - Args: - current_subword_id (torch.Tensor): Current text token IDs, shape (B, 1). - prev_subword_id (torch.Tensor): Previous text token IDs, shape (B, 1). - current_subword_mask (torch.Tensor): Current mask, shape (B, 1). - prev_audio_tokens (torch.Tensor): Previously generated audio tokens, shape (B, 1, C). - past_key_values: Key-value cache for transformer decoder state. - guidance_enabled (bool, optional): Enables classifier-free guidance. - generation_config (dict, optional): Generation hyperparameters. - ignore_eos_flag_stop (bool): If True, ignore EOS flag for stopping. - - Returns: - Tuple[torch.Tensor, Any]: - - Predicted audio codec token(s), shape (B, 1, C) - - Updated past_key_values for the next step. - """ if self.cfg.tts_config.context_hidden_size is not None: # get context_hidden_state it is always one step behind current_subword_id @@ -1597,31 +1607,34 @@ def infer_codes_one_step(self, current_subword_id, prev_subword_id, current_subw @torch.no_grad() def decode_one_audio_step(self, gen_audio_codes_history, number_prev_tokens=None): """ - Decodes one step of generated audio codec tokens to raw waveform. - - Args: - gen_audio_codes_history (torch.Tensor): Audio tokens history, shape (B, T, C). - number_prev_tokens (int, optional): Number of previous tokens to decode, for incremental decoding. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - audio_pred_cur_step: Latest decoded waveform chunk, shape (B, wav_to_token_ratio). - - audio_len: Lengths (number of samples), shape (B,). - """ + Decodes one step of generated audio codec tokens to raw waveform. + + Args: + gen_audio_codes_history (torch.Tensor): Audio tokens history, shape (B, T, C). + number_prev_tokens (int, optional): Number of previous tokens to decode, for incremental decoding. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - audio_pred_cur_step: Latest decoded waveform chunk, shape (B, wav_to_token_ratio). + - audio_len: Lengths (number of samples), shape (B,). 
+ """ with fp32_precision(), torch.no_grad(): if number_prev_tokens: gen_audio_codes_history = gen_audio_codes_history[:, -number_prev_tokens:] - - gen_audio_codes_history = replace_control_speech_codes(gen_audio_codes_history, self._control_codes, self.codec_silence_tokens) - gen_audio_codes_lens = torch.tensor([gen_audio_codes_history.size(1)]*gen_audio_codes_history.size(0), device=self.device) + + gen_audio_codes_history = replace_control_speech_codes( + gen_audio_codes_history, self._control_codes, self.codec_silence_tokens + ) + gen_audio_codes_lens = torch.tensor( + [gen_audio_codes_history.size(1)] * gen_audio_codes_history.size(0), device=self.device + ) audio_pred, audio_len = self.audio_codec.decode(gen_audio_codes_history, gen_audio_codes_lens) # return only the current/lastest audio chunk - audio_pred_cur_step = audio_pred.squeeze(1)[:, -self.audio_codec.config.wav_to_token_ratio:] + audio_pred_cur_step = audio_pred.squeeze(1)[:, -self.audio_codec.config.wav_to_token_ratio :] audio_len[:] = self.audio_codec.config.wav_to_token_ratio return audio_pred_cur_step, audio_len - @torch.no_grad() def offline_inference( self, @@ -1648,12 +1661,12 @@ def offline_inference( Dictionary of prompt-dependent initial states produced by ``get_init_inputs()``. May include: - • "code" — initial audio tokens (e.g., prompt audio) - • "audio_mask" — mask for prompt audio positions - • "context_hidden_state" — decoder hidden state at t = 0 - • "subword_ids" — prompt text tokens - • "subword_mask" — mask for prompt text - • "non_prompt_mask" — mask marking positions to be generated + • "code" — initial audio tokens (e.g., prompt audio) + • "audio_mask" — mask for prompt audio positions + • "context_hidden_state" — decoder hidden state at t = 0 + • "subword_ids" — prompt text tokens + • "subword_mask" — mask for prompt text + • "non_prompt_mask" — mask marking positions to be generated ``get_init_inputs()`` automatically expands batch-1 buffers to batch size B. @@ -1674,18 +1687,18 @@ def offline_inference( incremental_audio_decoding (bool, optional): If True, codec-to-waveform decoding is performed incrementally during - autoregressive generation. + autoregressive generation. If False, waveform decoding occurs only after all audio tokens are produced. Returns: dict[str, torch.Tensor]: Contains: - • **"audio"**: + • **"audio"**: Generated waveform of shape ``(B, T_audio)``, obtained via ``audio_pred.squeeze(1)``. - • **"audio_len"**: + • **"audio_len"**: Length of each generated waveform in samples, shape ``(B,)``. 
""" B = next_subword_ids.size(0) @@ -1713,7 +1726,7 @@ def offline_inference( gen_audio_codes = torch.zeros( B, max_steps, self.tts_model.config.num_quantizers, device=self.device, dtype=torch.long ) - + # init subwork as all ones subword_mask = torch.ones(B, max_steps, device=self.device, dtype=torch.bool) # get first context subword_id, that is the last subword_ids from the warmup @@ -1727,7 +1740,7 @@ def offline_inference( step_start = time.time() # current subword id is always seem current_subword_id = next_subword_ids[:, i].unsqueeze(-1) - + if i == 0: prev_subword_id = first_context_subword_id else: @@ -1736,26 +1749,40 @@ def offline_inference( # create subword_mask current_subword_mask = subword_mask[:, i].unsqueeze(-1) - code, past_key_values = self.infer_codes_one_step(current_subword_id=current_subword_id, prev_subword_id=prev_subword_id, current_subword_mask=current_subword_mask, prev_audio_tokens=code, past_key_values=past_key_values, guidance_enabled=guidance_enabled, generation_config=generation_config, ignore_eos_flag_stop=True) + code, past_key_values = self.infer_codes_one_step( + current_subword_id=current_subword_id, + prev_subword_id=prev_subword_id, + current_subword_mask=current_subword_mask, + prev_audio_tokens=code, + past_key_values=past_key_values, + guidance_enabled=guidance_enabled, + generation_config=generation_config, + ignore_eos_flag_stop=True, + ) # cache audio tokens gen_audio_codes[:, i] = code.squeeze(1) if incremental_audio_decoding: - audio_pred_i, audio_pred_i_len = self.decode_one_audio_step(gen_audio_codes[:, :i+1], number_prev_tokens=self.cfg.get("inference_codec_decoding_prev_tokens_number", None)) + audio_pred_i, audio_pred_i_len = self.decode_one_audio_step( + gen_audio_codes[:, : i + 1], + number_prev_tokens=self.cfg.get("inference_codec_decoding_prev_tokens_number", None), + ) if audio_pred is None: audio_pred = audio_pred_i else: audio_pred = torch.cat([audio_pred, audio_pred_i], dim=1) - audio_pred_len += audio_pred_i_len - + audio_pred_len += audio_pred_i_len + step_time = time.time() - step_start logging.info(f"Autoregressive inference step: {i} of {max_steps} take around {step_time}s") if not incremental_audio_decoding: gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * gen_audio_codes.shape[0]).to(self.device) # decode audio. 
Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety - gen_audio_codes = replace_control_speech_codes(gen_audio_codes, self._control_codes, self.codec_silence_tokens) + gen_audio_codes = replace_control_speech_codes( + gen_audio_codes, self._control_codes, self.codec_silence_tokens + ) with fp32_precision(), torch.no_grad(): audio_pred, audio_pred_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) From 433823e1b425a71344f322d5b9b7c6762d4fa1c4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 14 Nov 2025 12:07:42 -0800 Subject: [PATCH 027/102] rename codec context manager precision function Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index a4630f55758f..3eec2c149252 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -65,16 +65,16 @@ def maybe_to(x, dtype): @contextmanager -def ensures_16_precision(mixed_dtype): +def ensures_target_precision(target_dtype): """ Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting. In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models. This context manager restores default float32 precision and runs the computation in float32 autocast context. """ default_dtype = torch.get_default_dtype() - torch.set_default_dtype(mixed_dtype) + torch.set_default_dtype(target_dtype) try: - with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=mixed_dtype): + with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=target_dtype): yield finally: torch.set_default_dtype(default_dtype) @@ -175,7 +175,7 @@ def new_forward(*args, **kwargs): for k, v in kwargs.items() } # with torch.cuda.amp.autocast(enabled=True, dtype=mixed_dtype): - with ensures_16_precision(mixed_dtype): + with ensures_target_precision(mixed_dtype): return module._original_forward(*new_args, **new_kwargs) module.forward = new_forward @@ -315,10 +315,10 @@ def setup_rvq_audio_codec(model): Includes a workaround for PTL auto-downcasting the codec model to bf16 with bf16-true precision. 
""" - if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == torch.float: + if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype: return # skip if already set up and has the right dtype - with fp32_precision(): + with ensures_target_precision(model.audio_codec_run_dtype): if model.cfg.get("pretrained_ae_dir", None): model.audio_codec = ( RVQVAEModel.from_pretrained( @@ -448,6 +448,9 @@ def __init__(self, cfg: dict) -> None: # delete llm because we use it only to get the embbeding tokens del self.language_model + # get codec run precision + self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "bfloat16"), torch.float32) + # instanciate eartts model and codec self._load_tts_model(self.cfg) self._codebook_size = self.tts_model.config.codebook_size @@ -495,7 +498,7 @@ def get_codec_silence_frame_last_one(self): audio_len = torch.tensor([audio.size(-1)]).long() audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len) return sil_codes[0, -1] @@ -507,7 +510,7 @@ def get_codec_silence_frame(self): audio_len = torch.tensor([audio.size(-1)]).long() audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1), audio_len) # [1, T, C] sil_codes = sil_codes[0] # [T, C] @@ -693,10 +696,10 @@ def prepare_inputs(self, batch: dict): aligned_position_ids = batch["aligned_position_ids"] # extract target audio codes - with fp32_precision(), torch.no_grad(): - target_audio, target_audio_lens = self.pad_audio_to_factor( - target_audio, target_audio_lens, self.target_samples_per_frame, 1 - ) + target_audio, target_audio_lens = self.pad_audio_to_factor( + target_audio, target_audio_lens, self.target_samples_per_frame, 1 + ) + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens) # ToDo: consider use the source audio @@ -708,8 +711,8 @@ def prepare_inputs(self, batch: dict): source_audio_lens = (source_audio_lens * (self.target_sample_rate/self.source_sample_rate)).to(lengths.dtype) # ToDo: Add a transformer encoder to help the model to better extract contextual information, replace the code bellow with it # extract embedding for context audios - with fp32_precision(), torch.no_grad(): - source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1) + source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1) + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): source_codes, source_codes_lens = self.audio_codec.encode( source_audio.unsqueeze(1), source_audio_lens ) @@ -916,7 +919,7 @@ def get_teacher_force_inference_audio(self, batch, guidance_enabled=True): tf_audio_codes_pred = replace_control_speech_codes( tf_audio_codes_pred, self._control_codes, self.codec_silence_tokens ) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), 
torch.no_grad(): audio_pred, audio_len = self.audio_codec.decode(tf_audio_codes_pred, inputs["output_lens"]) return audio_pred.squeeze(1), audio_len @@ -1429,7 +1432,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, target_audio_len = torch.tensor( [target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device ) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len) # get context hidden @@ -1783,7 +1786,7 @@ def offline_inference( gen_audio_codes = replace_control_speech_codes( gen_audio_codes, self._control_codes, self.codec_silence_tokens ) - with fp32_precision(), torch.no_grad(): + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): audio_pred, audio_pred_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) return audio_pred.squeeze(1), audio_pred_len From d9e74ab865d750fe80dafc77ccc66e9cd8689cc5 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 14 Nov 2025 12:19:52 -0800 Subject: [PATCH 028/102] Rename init_model_from_another_checkpoint to restore_from_pretrained_checkpoint Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 3eec2c149252..d0e08e4ff102 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -491,7 +491,7 @@ def __init__(self, cfg: dict) -> None: self._use_fsdp = False self._use_tp = False if self.cfg.get("pretrained_model", None): - self.init_model_from_another_checkpoint(self.cfg.pretrained_model) + self.restore_from_pretrained_checkpoint(self.cfg.pretrained_model) def get_codec_silence_frame_last_one(self): audio = torch.zeros(1, 10 * self.target_sample_rate).float().to(self.device) @@ -564,9 +564,9 @@ def _load_language_model(self, cfg): language_model = None return language_model - def init_model_from_another_checkpoint(self, checkpoint_path): + def restore_from_pretrained_checkpoint(self, checkpoint_path): """ - Loads model weights and config from another checkpoint file, supporting .nemo and PyTorch formats. + Loads model weights a pretrained checkpoint file, supporting partial loading from .nemo and PyTorch formats. Args: checkpoint_path (str): Path to checkpoint file. 
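The renamed ``restore_from_pretrained_checkpoint`` advertises partial loading. As background, the snippet below sketches the usual pattern for that (filter the checkpoint's state_dict down to matching names and shapes, then load with ``strict=False``); the helper name and the filtering rule are assumptions for illustration, not the NeMo implementation.

    import torch

    def restore_partial_state(model: torch.nn.Module, checkpoint_path: str) -> None:
        # Load on CPU and unwrap a Lightning-style "state_dict" nesting if present.
        state = torch.load(checkpoint_path, map_location="cpu")
        state = state.get("state_dict", state)
        model_state = model.state_dict()
        # Keep only tensors whose name and shape match the current model.
        matched = {k: v for k, v in state.items()
                   if k in model_state and model_state[k].shape == v.shape}
        model.load_state_dict(matched, strict=False)
        print(f"Restored {len(matched)}/{len(model_state)} tensors; the rest keep their initial values.")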
From 89fcb32fe3e0d986b91ce9ccc0f72fbddb16bb75 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sun, 16 Nov 2025 09:34:44 -0800 Subject: [PATCH 029/102] Replace torchaudio with librosa Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index d0e08e4ff102..a313be0d173c 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torchaudio +import librosa from lightning import LightningModule from omegaconf import DictConfig from peft import PeftModel @@ -55,6 +55,25 @@ from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging +def load_audio_librosa(path, sr=None): + """ + Load audio using librosa with torchaudio-like behavior. + + Returns: + audio_tensor: torch.FloatTensor of shape [channels, time] + sr: sampling rate + """ + # Load with librosa (preserve original sampling rate) + audio, sr = librosa.load(path, sr=sr, mono=False) + + # Ensure shape is [channels, time] + if audio.ndim == 1: + # Mono: (time,) -> (1, time) + audio = audio[None, :] + + # Convert to torch float32 (torchaudio behavior) + audio_tensor = torch.from_numpy(audio).float() + return audio_tensor, sr def maybe_to(x, dtype): if x is None: @@ -940,7 +959,7 @@ def offline_inference_with_custom_sentences( # ToDo: split it in multiples batches to support long list of sentences B = len(test_sentences) # load and get speaker reference - speaker_audio, sr = torchaudio.load(inference_speaker_reference) + speaker_audio, sr = load_audio_librosa(inference_speaker_reference) speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) speaker_audio = speaker_audio.repeat(B, 1).to(self.device) # lengths -> [B] @@ -1250,7 +1269,7 @@ def validation_step(self, batch: dict, batch_idx: int): ref_name = os.path.basename(inference_speaker_reference) # Append to each sample_id new_dataset_batch['sample_id'] = [f"{sid}_{ref_name}" for sid in dataset_batch['sample_id']] - speaker_audio, sr = torchaudio.load(inference_speaker_reference) + speaker_audio, sr = load_audio_librosa(inference_speaker_reference) speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) speaker_audio = speaker_audio.repeat(B, 1).to(self.device) # lengths -> [B] @@ -1262,7 +1281,7 @@ def validation_step(self, batch: dict, batch_idx: int): # run inference for a custom speaker reference elif self.cfg.get("inference_speaker_reference", None): new_dataset_batch = copy.deepcopy(dataset_batch) - speaker_audio, sr = torchaudio.load(inference_speaker_reference) + speaker_audio, sr = load_audio_librosa(inference_speaker_reference) speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) speaker_audio = speaker_audio.repeat(B, 1).to(self.device) # lengths -> [B] From 10c95ff7e5a23adb4844aa6039718aa1925a3c1c Mon Sep 17 00:00:00 2001 From: Edresson Date: Sun, 16 Nov 2025 17:35:28 +0000 Subject: [PATCH 030/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py 
b/nemo/collections/speechlm2/models/duplex_ear_tts.py index a313be0d173c..eb6cca3b2d81 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -19,10 +19,10 @@ from collections import Counter from contextlib import contextmanager +import librosa import torch import torch.nn as nn import torch.nn.functional as F -import librosa from lightning import LightningModule from omegaconf import DictConfig from peft import PeftModel @@ -55,10 +55,11 @@ from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging + def load_audio_librosa(path, sr=None): """ Load audio using librosa with torchaudio-like behavior. - + Returns: audio_tensor: torch.FloatTensor of shape [channels, time] sr: sampling rate @@ -75,6 +76,7 @@ def load_audio_librosa(path, sr=None): audio_tensor = torch.from_numpy(audio).float() return audio_tensor, sr + def maybe_to(x, dtype): if x is None: return None From 4a451d011424dad7e6dde0ad36acd95b1c73e7b7 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 03:30:34 -0800 Subject: [PATCH 031/102] Replace torchaudio with librosa on codec Signed-off-by: Edresson Casanova --- .../speechlm2/modules/rvq_ear_tts_vae.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py index e412b82691df..192012739b3d 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -23,7 +23,7 @@ from omegaconf import DictConfig from torch import Tensor, nn from torch.nn import functional as F -from torchaudio import functional as ta_F +import librosa # Project from nemo.collections.speechlm2.modules.ear_tts_commons import PreTrainedModel @@ -318,7 +318,7 @@ def get_fbanks( are only computed once for a given set of parameters, improving efficiency when the function is called multiple times with the same arguments. - Note: This implementation only supports Mel filterbanks via torchaudio. + Note: This implementation only supports Mel filterbanks via librosa. Args: sample_rate (int): The sample rate of the audio. @@ -336,16 +336,17 @@ def get_fbanks( Tensor: The Mel filterbank matrix. 
Shape: [n_mels, n_fft // 2 + 1] """ - # Generate Mel filterbanks using torchaudio's functional API - fb = ta_F.melscale_fbanks( - n_freqs=n_fft // 2 + 1, - f_min=f_min, - f_max=f_max, + # Generate Mel filterbanks using librosa's functional API + fb = librosa.filters.mel( + sr=sample_rate, + n_fft=n_fft, n_mels=n_mels, - sample_rate=sample_rate, + fmin=f_min, + fmax=f_max, norm=norm, - mel_scale=mel_scale, - ).T # Transpose to get the shape [n_mels, n_freqs] + htk=(mel_scale == "htk"), + ) # [n_mels, n_freqs] + fb = torch.from_numpy(fb).float() return fb From ff662edc2924f9a6e5cde92804498b471a7a16d8 Mon Sep 17 00:00:00 2001 From: Edresson Date: Mon, 17 Nov 2025 11:31:28 +0000 Subject: [PATCH 032/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py index 192012739b3d..df0a6356b172 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py @@ -18,12 +18,13 @@ from contextlib import contextmanager from typing import Any, Concatenate +import librosa + # Third-party import torch from omegaconf import DictConfig from torch import Tensor, nn from torch.nn import functional as F -import librosa # Project from nemo.collections.speechlm2.modules.ear_tts_commons import PreTrainedModel @@ -345,7 +346,7 @@ def get_fbanks( fmax=f_max, norm=norm, htk=(mel_scale == "htk"), - ) # [n_mels, n_freqs] + ) # [n_mels, n_freqs] fb = torch.from_numpy(fb).float() return fb From 36b09f810b681749fc96469945434bb31966038f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 05:12:38 -0800 Subject: [PATCH 033/102] Add sensitive_layers parameter Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 272 ++++++++++-------- 1 file changed, 158 insertions(+), 114 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index eb6cca3b2d81..fefdc116db05 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -101,18 +101,26 @@ def ensures_target_precision(target_dtype): torch.set_default_dtype(default_dtype) -def make_tts_model_mixed_precision_definite( - model, inputs, mixed_dtype=torch.bfloat16, bf16_min=1e-2, bf16_max=1e2, safety_factor=1.0 -): - safe_min = bf16_min * safety_factor - safe_max = bf16_max * safety_factor +def collect_activation_stats(model: nn.Module, inputs: dict) -> dict: + """ + Collect per-layer activation statistics (min and max) for Linear, LayerNorm, and Embedding modules. - # 1️⃣ Collect activation stats in FP32 - model_fp32 = copy.deepcopy(model).eval().to(torch.float32) + This performs a forward pass in FP32 and registers hooks to record + the min and max values of each layer's output. These statistics are + used to decide which layers are safe for mixed precision. + + Args: + model (nn.Module): Model to analyze. + inputs (dict): Input arguments for the model forward pass. 
+ + Returns: + dict: Mapping from layer names to activation stats: + {"layer_name": {"min": value, "max": value}} + """ stats = {} hooks = [] - def _activation_hook(name): + def _make_hook(name: str): def hook(_, __, out): if isinstance(out, tuple): out = out[0] @@ -121,12 +129,14 @@ def hook(_, __, out): return hook - for name, module in model_fp32.named_modules(): + # Register hooks + for name, module in model.named_modules(): if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)): - hooks.append(module.register_forward_hook(_activation_hook(name))) + hooks.append(module.register_forward_hook(_make_hook(name))) + # Forward pass with torch.no_grad(): - _ = model_fp32( + _ = model( code=inputs["code"], audio_mask=maybe_to(inputs["audio_mask"], torch.float32), attention_mask=maybe_to(inputs["attention_mask"], torch.float32), @@ -137,124 +147,141 @@ def hook(_, __, out): non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32), ) + # Remove hooks for h in hooks: h.remove() - # 2️⃣ Patch model for mixed precision with safe propagation - model_patched = copy.deepcopy(model).eval() - bf16_layers, fp32_layers = [], [] + return stats + - all_modules = list(model_patched.named_modules()) +def classify_precision_layers( + model: nn.Module, stats: dict, safe_min: float, safe_max: float +) -> list: + """ + Determine which layers must remain FP32 for numerical stability. - # flag to propagate FP32 to next safe layers + Sensitive layers (LayerNorm, Embedding, or Linear layers with out-of-range activations) + are forced to FP32. FP32 can propagate to the next safe layer to prevent instability. + + Args: + model (nn.Module): Model to classify. + stats (dict): Activation statistics from `collect_activation_stats`. + safe_min (float): Minimum threshold for safe activations. + safe_max (float): Maximum threshold for safe activations. + + Returns: + list: Names of layers that should remain FP32. + """ + fp32_layers = [] propagate_fp32 = False - for idx, (name, module) in enumerate(all_modules): + for name, module in model.named_modules(): if name not in stats: continue + mn, mx = stats[name]["min"], stats[name]["max"] - safe = abs(mn) < safe_max and abs(mx) < safe_max and not (abs(mn) < safe_min and abs(mx) < safe_min) + safe_range = abs(mn) < safe_max and abs(mx) < safe_max + not_tiny = not (abs(mn) < safe_min and abs(mx) < safe_min) + safe = safe_range and not_tiny - is_sensitive = False - if isinstance(module, (nn.LayerNorm, nn.Embedding)): + # Determine if layer is FP32-sensitive + is_sensitive = isinstance(module, (nn.LayerNorm, nn.Embedding)) + if isinstance(module, nn.Linear) and not safe: is_sensitive = True - elif isinstance(module, nn.Linear): - if not safe: - is_sensitive = True - # mark this layer if is_sensitive: - if name not in fp32_layers: - fp32_layers.append(name) - propagate_fp32 = True # propagate FP32 to next layers if safe + fp32_layers.append(name) + propagate_fp32 = True + elif propagate_fp32: + # Propagate FP32 to next safe layer + fp32_layers.append(name) + propagate_fp32 = False + + return fp32_layers + + +def wrap_module_precision(module: nn.Module, force_fp32: bool, mixed_dtype=torch.bfloat16): + """ + Wrap a module's forward to enforce mixed precision or FP32. + + Args: + module (nn.Module): Module to wrap. + force_fp32 (bool): If True, module runs in FP32. + mixed_dtype (torch.dtype): Target dtype for mixed precision layers. 
+ """ + if hasattr(module, "_original_forward"): + return + + module._original_forward = module.forward + + def new_forward(*args, **kwargs): + if force_fp32: + with fp32_precision(): + return module._original_forward(*args, **kwargs) else: - if propagate_fp32: - # next layer is safe but preceded by FP32-sensitive -> still FP32 - fp32_layers.append(name) - propagate_fp32 = False # stop propagation after one safe layer - else: - # layer itself is safe and no FP32 propagation -> use BF16/FP16 - if isinstance(module, nn.Linear): - bf16_layers.append(name) + new_args = tuple( + a.to(mixed_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a + for a in args + ) + new_kwargs = { + k: v.to(mixed_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in kwargs.items() + } + with ensures_target_precision(mixed_dtype): + return module._original_forward(*new_args, **new_kwargs) - # 3️⃣ Wrap forwards to enforce precision - def wrap_forward(module, is_fp32_sensitive): - if hasattr(module, "_original_forward"): - return - module._original_forward = module.forward + module.forward = new_forward - def new_forward(*args, **kwargs): - if is_fp32_sensitive: - with fp32_precision(): - return module._original_forward(*args, **kwargs) - else: - new_args = tuple( - a.to(mixed_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a for a in args - ) - new_kwargs = { - k: v.to(mixed_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v - for k, v in kwargs.items() - } - # with torch.cuda.amp.autocast(enabled=True, dtype=mixed_dtype): - with ensures_target_precision(mixed_dtype): - return module._original_forward(*new_args, **new_kwargs) - module.forward = new_forward +def find_sensitive_layers( + model: nn.Module, + inputs: dict, + bf16_min: float = 1e-2, + bf16_max: float = 1e2, + safety_factor: float = 1.0, +) -> list: + """ + Identify FP32-sensitive layers for a TTS model. - for name, module in model_patched.named_modules(): - if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)): - wrap_forward(module, name in fp32_layers) - - # 4️⃣ Count actual running dtype - running_dtypes = Counter() - hook_handles = [] - - def dtype_counter_hook(module, inputs, outputs): - for x in inputs: - if isinstance(x, torch.Tensor): - running_dtypes[str(x.dtype)] += 1 - outputs_list = outputs if isinstance(outputs, (tuple, list)) else [outputs] - for x in outputs_list: - if isinstance(x, torch.Tensor): - running_dtypes[str(x.dtype)] += 1 - - for name, module in model_patched.named_modules(): - if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)): - hook_handles.append(module.register_forward_hook(dtype_counter_hook)) + Steps: + 1. Run FP32 forward pass to collect activation stats. + 2. Classify layers that must remain FP32. - with torch.no_grad(): - _ = model_patched( - code=inputs["code"], - audio_mask=maybe_to(inputs["audio_mask"], torch.float32), - attention_mask=maybe_to(inputs["attention_mask"], torch.float32), - position_ids=inputs["position_ids"], - context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), - subword_ids=inputs["subword_ids"], - subword_mask=maybe_to(inputs["subword_mask"], torch.float32), - non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32), - ) + Args: + model (nn.Module): TTS model. + inputs (dict): Inputs for forward pass. + bf16_min (float): Minimum safe activation for BF16. + bf16_max (float): Maximum safe activation for BF16. 
+ safety_factor (float): Safety factor for thresholds. - for h in hook_handles: - h.remove() + Returns: + list: Names of FP32-sensitive layers. + """ + safe_min = bf16_min * safety_factor + safe_max = bf16_max * safety_factor - num_bf16_fp16 = running_dtypes.get("torch.bfloat16", 0) + running_dtypes.get("torch.float16", 0) - num_fp32 = running_dtypes.get("torch.float32", 0) + # FP32 reference forward + model_fp32 = copy.deepcopy(model).eval().to(torch.float32) + stats = collect_activation_stats(model_fp32, inputs) - summary = { - "bf16_layers": bf16_layers, - "fp32_layers": fp32_layers, - "num_bf16_fp16": num_bf16_fp16, - "num_fp32": num_fp32, - "stats": stats, - "safe_min": safe_min, - "safe_max": safe_max, - "safety_factor": safety_factor, - } + # Identify FP32 layers + model_patched = copy.deepcopy(model).eval() + fp32_layers = classify_precision_layers(model_patched, stats, safe_min, safe_max) - print("Num. BF16/FP16 candidate layers:", len(bf16_layers)) - print("Num. FP32 layers (sensitive + propagated):", len(fp32_layers)) + # Count total relevant layers + total_layers = sum( + 1 + for _, module in model.named_modules() + if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)) + ) + half_precision_layers = total_layers - len(fp32_layers) - return model_patched, summary + print( + f"Total sensitive layers (FP32): {len(fp32_layers)}, " + f"Half precision layers: {half_precision_layers}" + ) + + return fp32_layers def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int = 0, eos_token_id: int = 1): @@ -524,7 +551,6 @@ def get_codec_silence_frame_last_one(self): return sil_codes[0, -1] def get_codec_silence_frame(self): - from collections import Counter # Generate long zero waveform (silence) audio = torch.zeros(1, 10 * self.target_sample_rate).float().to(self.device) @@ -1012,6 +1038,24 @@ def offline_inference_with_custom_sentences( ) return audio, audio_len, speaker_audio, speaker_audio_lens + def apply_mixed_precision_wrapping_on_tts_model( + self, fp32_layers: list, mixed_dtype=torch.bfloat16 + ): + """ + Apply mixed precision to TTS model layers, keeping FP32 layers intact. + + Args: + fp32_layers (list): Names of layers to keep FP32. + mixed_dtype (torch.dtype): Target dtype for mixed precision layers. + """ + logging.info( + f"Converting TTS model to mixed precision. FP32 layers: {fp32_layers}" + ) + for name, module in self.tts_model.named_modules(): + if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)): + force_fp32 = name in fp32_layers + wrap_module_precision(module, force_fp32, mixed_dtype) + def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=False): """ Runs evaluation and scoring for a single data batch, logging metrics and updating result buffers. 
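Taken together, ``find_sensitive_layers`` and ``wrap_module_precision`` support a two-phase workflow: probe once in FP32, then wrap everything. The snippet below is a rough usage sketch assuming both helpers are importable from this module and that ``probe_inputs`` comes from ``prepare_inputs``; it is not the exact call site wired up in the following hunk.

    import torch
    import torch.nn as nn

    def make_tts_mixed_precision(tts_model: nn.Module, probe_inputs: dict,
                                 pinned_fp32_layers=None, mixed_dtype=torch.bfloat16):
        # Phase 1: discover FP32-sensitive layers, or reuse a pinned list
        # (e.g. from a "sensitive_layers" config entry) to skip the FP32 probe.
        fp32_layers = pinned_fp32_layers
        if fp32_layers is None:
            fp32_layers = find_sensitive_layers(tts_model, probe_inputs, safety_factor=1.0)
        # Phase 2: wrap every Linear/LayerNorm/Embedding; only the sensitive
        # layers (and the ones FP32 propagated to) keep running in FP32.
        for name, module in tts_model.named_modules():
            if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)):
                wrap_module_precision(module, force_fp32=name in fp32_layers, mixed_dtype=mixed_dtype)
        return fp32_layers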
@@ -1034,16 +1078,16 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals and self.trainer_config is not None and str(self.trainer_config.precision) != str(32) ): - # ToDo: move it to a method - self.tts_model, summary = make_tts_model_mixed_precision_definite( - self.tts_model, - inputs, - safety_factor=1.0, - mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16, - ) - # self.tts_model, summary = make_tts_model_mixed_precision_safe(self.tts_model, inputs, safety_factor=1.0) + if self.cfg.get("sensitive_layers", None): + self.apply_mixed_precision_wrapping_on_tts_model(self.cfg.sensitive_layers, mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16,) + else: + sensitive_layers = find_sensitive_layers( + self.tts_model, + inputs, + safety_factor=1.0, + ) + self.apply_mixed_precision_wrapping_on_tts_model(sensitive_layers, mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16,) self.model_16_precision_safe = True - print("Current FP32 layers:", summary["fp32_layers"]) results["audio_tf"], results["audio_tf_len"] = self.get_teacher_force_inference_audio(dataset_batch) if use_dataloader_init: From 54b5418846dc7bf103b8e9b8b3f9c3a6a2d93d7b Mon Sep 17 00:00:00 2001 From: Edresson Date: Mon, 17 Nov 2025 13:13:33 +0000 Subject: [PATCH 034/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index fefdc116db05..488a5e398e6a 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -154,9 +154,7 @@ def hook(_, __, out): return stats -def classify_precision_layers( - model: nn.Module, stats: dict, safe_min: float, safe_max: float -) -> list: +def classify_precision_layers(model: nn.Module, stats: dict, safe_min: float, safe_max: float) -> list: """ Determine which layers must remain FP32 for numerical stability. 
@@ -220,8 +218,7 @@ def new_forward(*args, **kwargs): return module._original_forward(*args, **kwargs) else: new_args = tuple( - a.to(mixed_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a - for a in args + a.to(mixed_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a for a in args ) new_kwargs = { k: v.to(mixed_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v @@ -270,16 +267,11 @@ def find_sensitive_layers( # Count total relevant layers total_layers = sum( - 1 - for _, module in model.named_modules() - if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)) + 1 for _, module in model.named_modules() if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)) ) half_precision_layers = total_layers - len(fp32_layers) - print( - f"Total sensitive layers (FP32): {len(fp32_layers)}, " - f"Half precision layers: {half_precision_layers}" - ) + print(f"Total sensitive layers (FP32): {len(fp32_layers)}, " f"Half precision layers: {half_precision_layers}") return fp32_layers @@ -1038,9 +1030,7 @@ def offline_inference_with_custom_sentences( ) return audio, audio_len, speaker_audio, speaker_audio_lens - def apply_mixed_precision_wrapping_on_tts_model( - self, fp32_layers: list, mixed_dtype=torch.bfloat16 - ): + def apply_mixed_precision_wrapping_on_tts_model(self, fp32_layers: list, mixed_dtype=torch.bfloat16): """ Apply mixed precision to TTS model layers, keeping FP32 layers intact. @@ -1048,9 +1038,7 @@ def apply_mixed_precision_wrapping_on_tts_model( fp32_layers (list): Names of layers to keep FP32. mixed_dtype (torch.dtype): Target dtype for mixed precision layers. """ - logging.info( - f"Converting TTS model to mixed precision. FP32 layers: {fp32_layers}" - ) + logging.info(f"Converting TTS model to mixed precision. 
FP32 layers: {fp32_layers}") for name, module in self.tts_model.named_modules(): if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)): force_fp32 = name in fp32_layers @@ -1079,14 +1067,20 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals and str(self.trainer_config.precision) != str(32) ): if self.cfg.get("sensitive_layers", None): - self.apply_mixed_precision_wrapping_on_tts_model(self.cfg.sensitive_layers, mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16,) + self.apply_mixed_precision_wrapping_on_tts_model( + self.cfg.sensitive_layers, + mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16, + ) else: sensitive_layers = find_sensitive_layers( self.tts_model, inputs, safety_factor=1.0, ) - self.apply_mixed_precision_wrapping_on_tts_model(sensitive_layers, mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16,) + self.apply_mixed_precision_wrapping_on_tts_model( + sensitive_layers, + mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16, + ) self.model_16_precision_safe = True results["audio_tf"], results["audio_tf_len"] = self.get_teacher_force_inference_audio(dataset_batch) From 71be84ab6f78802060f658915aa1d84784dc8710 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 07:28:38 -0800 Subject: [PATCH 035/102] Remove torchaudio from metrics Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 2 +- .../speechlm2/parts/metrics/results_logger.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 488a5e398e6a..130892f7a99f 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -1306,7 +1306,7 @@ def validation_step(self, batch: dict, batch_idx: int): new_dataset_batch = copy.deepcopy(dataset_batch) # Get only the file name - ref_name = os.path.basename(inference_speaker_reference) + ref_name = os.path.basename(inference_speaker_reference).split(".")[0] # Append to each sample_id new_dataset_batch['sample_id'] = [f"{sid}_{ref_name}" for sid in dataset_batch['sample_id']] speaker_audio, sr = load_audio_librosa(inference_speaker_reference) diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py index 4d083ac22345..b49855a86dda 100644 --- a/nemo/collections/speechlm2/parts/metrics/results_logger.py +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -15,12 +15,12 @@ import os import shutil +import soundfile as sf import torch -import torchaudio +from nemo.collections.audio.parts.utils.resampling import resample from nemo.utils import logging - def safe_remove_path(path): shutil.rmtree(path, ignore_errors=True) @@ -66,7 +66,7 @@ def merge_and_save_audio( ) -> None: # if user_audio is None ignore it if user_audio is not None: - user_audio = torchaudio.functional.resample(user_audio.float(), user_audio_sr, pred_audio_sr) + user_audio = resample(user_audio.float(), user_audio_sr, pred_audio_sr) T1, T2 = pred_audio.shape[0], user_audio.shape[0] max_len = max(T1, T2) pred_audio_padded = torch.nn.functional.pad(pred_audio, (0, max_len - T1), mode='constant', value=0) @@ -85,7 +85,8 @@ def merge_and_save_audio( combined_wav = 
pred_audio.unsqueeze(0).detach().cpu() # save audio - torchaudio.save(out_audio_path, combined_wav, pred_audio_sr) + os.makedirs(os.path.dirname(out_audio_path), exist_ok=True) + sf.write(out_audio_path, combined_wav.numpy().astype('float32').T, pred_audio_sr) logging.info(f"Audio saved at: {out_audio_path}") def update( @@ -152,18 +153,18 @@ def update( ) # (B, T, repeat_factor) eou_pred_wav = eou_pred_wav.view(1, -1) # (B, T * repeat_factor) eou_pred_wav = eou_pred_wav.float() * 0.8 # make 1 audible and keep 0 as total silence - torchaudio.save(out_audio_path_eou, eou_pred_wav.squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr) + sf.write(out_audio_path_eou, eou_pred_wav.squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, pred_audio_sr) if pre_audio_trimmed is not None: out_audio_path_trimmed = os.path.join(self.audio_save_path, f"{name}_{sample_id}_pred_trimmed.wav") - torchaudio.save( - out_audio_path_trimmed, pre_audio_trimmed[i].squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr + sf.write( + out_audio_path_trimmed, pre_audio_trimmed[i].squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, pred_audio_sr ) if reference_audio is not None: out_audio_path_ref = os.path.join(self.audio_save_path, f"{name}_{sample_id}_spk_reference.wav") - torchaudio.save( - out_audio_path_ref, reference_audio[i].squeeze().unsqueeze(0).detach().cpu(), pred_audio_sr + sf.write( + out_audio_path_ref, reference_audio[i].squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, pred_audio_sr ) # cache metadata From 65c084d7ca7d8c58d99d3a4d851af394669989ac Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 07:29:38 -0800 Subject: [PATCH 036/102] Update lhotse formmaters Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 226 ++++++++---------- 1 file changed, 100 insertions(+), 126 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 9b40974106b9..228d0230349c 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -540,23 +540,40 @@ def cut_to_conversation( ) + @data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) -def read_s2s_duplex_overlap_as_s2s_duplex(config) -> tuple[CutSet, bool]: - def filter_cuts_starting_with_agent(cuts: CutSet, agent_roles=("agent", "assistant", "Assistant")) -> CutSet: - def filter_cut_fn(cut): - # sort supervisions by start - cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) - if len(cut.supervisions): - return cut.supervisions[0].speaker not in agent_roles - else: - return False # filter emptly supervisions +def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: + """ + Convert a CutSet with overlapping agent/user segments into a standard S2S duplex format. + + Args: + config: Dictionary containing parser options: + - move_agent_text_back_by (float): Time offset to shift agent text back. + - filter_samples_starting_with_agent (bool): Whether to remove samples starting with agent. + - agent_roles (List[str]): Roles considered as agent. + + Returns: + Tuple[CutSet, bool]: Converted cuts and a flag indicating if the data was tarred. 
+ """ + move_agent_text_back_by = config.get("move_agent_text_back_by", 0) + filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) + agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) - return cuts.filter(filter_cut_fn) + cuts, is_tarred = read_cutset_from_config(config) + + def filter_cuts_starting_with_agent_fn(cuts: CutSet, agent_roles: Tuple[str, ...]) -> CutSet: + """Remove cuts where the first supervision belongs to an agent role.""" + def _filter_fn(cut: Cut) -> bool: + if not cut.supervisions: + return False + cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) + return cut.supervisions[0].speaker not in agent_roles + return cuts.filter(_filter_fn) - def convert_overlap_cut(cut): - agent_segments = [] - for seg in cut.agent_segments: - ss = SupervisionSegment( + def convert_overlap_cut_fn(cut: Cut) -> Cut: + """Convert agent/user overlapping segments into sequential SupervisionSegments.""" + agent_segments = [ + SupervisionSegment( id=cut.id, recording_id=cut.id, start=seg["start"] - move_agent_text_back_by, @@ -564,11 +581,11 @@ def convert_overlap_cut(cut): text=seg["text"], speaker="agent", ) - agent_segments.append(ss) + for seg in cut.agent_segments + ] - user_segments = [] - for seg in cut.user_segments: - ss = SupervisionSegment( + user_segments = [ + SupervisionSegment( id=cut.id, recording_id=cut.id, start=seg["start"], @@ -576,176 +593,133 @@ def convert_overlap_cut(cut): text=seg["text"], speaker="user", ) - user_segments.append(ss) + for seg in cut.user_segments + ] cut.supervisions = sorted(agent_segments + user_segments, key=lambda s: s.start) cut.formatter = "s2s_duplex_overlap_as_s2s_duplex" return cut - # load lhotse cuts - cuts, is_tarred = read_cutset_from_config(config) - move_agent_text_back_by = config.get("move_agent_text_back_by", 0) - filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) - agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) - - # convert cuts - cuts = cuts.map(convert_overlap_cut) - - # Filter cuts where the first supervision is agent + cuts = cuts.map(convert_overlap_cut_fn) if filter_samples_starting_with_agent: - cuts = filter_cuts_starting_with_agent(cuts, agent_roles) + cuts = filter_cuts_starting_with_agent_fn(cuts, tuple(agent_roles)) return cuts, is_tarred @data_type_parser(["lhotse_magpietts_data_as_continuation"]) -def read_lhotse_magpietts_data_as_continuation(config) -> tuple[CutSet, bool]: +def read_lhotse_magpietts_data_as_continuation(config) -> Tuple[CutSet, bool]: + """ + Convert MagpieTTS dataset cuts to a continuation format suitable for S2S training. + + Args: + config: Dictionary containing parser options: + - add_extra_end_silence (bool): Whether to add extra silence at the end. + - extra_end_silence_range (List[float]): Range of extra silence duration. + - max_cer (float): Maximum allowed character error rate. + - min_context_speaker_similarity (float): Minimum similarity score. + - target_speaker (str, optional): Target speaker filter. + - sample_rate (int): Audio sample rate for resampling. + + Returns: + Tuple[CutSet, bool]: Converted cuts and a flag indicating if data was tarred. 
+ """ + cuts, is_tarred = read_cutset_from_config(config) + + add_extra_end_sil = config.get("add_extra_end_silence", False) + extra_end_silence_range = config.get("extra_end_silence_range", [0.5, 6.0]) + sample_rate = config.get("sample_rate", 22050) + + max_cer = config.get("max_cer", 0.03) + min_context_speaker_similarity = config.get("min_context_speaker_similarity", 0.6) + target_speaker = config.get("target_speaker", None) + keep_flag = "pass" + def create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: + """Convert a numpy array into a Lhotse Recording object.""" with io.BytesIO() as buffer: sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') buffer.seek(0) return Recording.from_bytes(buffer.read(), recording_id=recording_id) - def convert_lhotse_magpietts_data_as_cont(cut): - # create a copy of agent supervision and original duration + def convert_cut_fn(cut: Cut) -> Cut: + """Convert a single cut into the continuation format.""" orig_agent_sup = fastcopy(cut.supervisions[0]) - target_audio_org_dur = cut.target_audio.duration + target_audio_orig_dur = cut.target_audio.duration - # Resample both to match sample_rate + # Resample audios cut.target_audio = cut.target_audio.resample(sample_rate) cut.context_audio = cut.context_audio.resample(sample_rate) - - # Compute total duration total_duration = cut.target_audio.duration - # Convert target_audio (Recording) into MonoCut so we can pad it + # Prepare MonoCuts cut_target = MonoCut( id=f"{cut.id}_target", start=0.0, - duration=cut.target_audio.duration, + duration=total_duration, channel=0, recording=cut.target_audio, supervisions=[], ) - # create silence audio - num_samples = int(total_duration * sample_rate) - zero_audio = np.zeros((1, num_samples), dtype=np.float32) - source_recording = create_recording_from_array( - zero_audio, - sampling_rate=sample_rate, - recording_id=f"{cut.id}_source", - ) + zero_audio = np.zeros((1, int(total_duration * sample_rate)), dtype=np.float32) + source_recording = create_recording_from_array(zero_audio, sample_rate, recording_id=f"{cut.id}_source") cut_source = MonoCut( id=f"{cut.id}_source", start=0.0, - duration=cut.target_audio.duration, + duration=total_duration, channel=0, recording=source_recording, supervisions=[], ) - # Save both to memory + # Save to memory cut_source = cut_source.move_to_memory(audio_format='wav') cut_target = cut_target.move_to_memory(audio_format='wav') - # user starts on zeros with dummy text - user_sup = fastcopy( - orig_agent_sup, - start=0.0, - duration=0.08, # keep only on frame to the user - speaker="user", - text="dummy text", - ) - # agent starts when user turn finish and has target_audio_dur - agent_sup = fastcopy( - orig_agent_sup, - start=0.0, - duration=target_audio_org_dur - 0.08, - speaker="agent", - ) + # Create user and agent supervisions + user_sup = fastcopy(orig_agent_sup, start=0.0, duration=0.08, speaker="user", text="dummy text") + agent_sup = fastcopy(orig_agent_sup, start=0.0, duration=target_audio_orig_dur - 0.08, speaker="agent") - # Add extra sil in the end of the audio to force the model to produce silence if it receives zeros and the was all processed - if ADD_EXTRA_END_SIL: - sil_duration = random.uniform(*SILENCE_RANGE) - # pad audios + # Optionally add extra silence + if add_extra_end_sil: + sil_duration = random.uniform(*extra_end_silence_range) cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right") cut_source = 
cut_source.pad(duration=total_duration + sil_duration, direction="right") - # Save both to memory cut_source = cut_source.to_mono().move_to_memory(audio_format='wav') cut_target = cut_target.to_mono().move_to_memory(audio_format='wav') - agent_sup.duration = ( - agent_sup.duration + sil_duration + 1.0 - ) # added here 1.0 seconds to not have text EOS for this dataset to avoid conflicts with S2S, text EOS is the interruption token on duplex - user_sup.duration = user_sup.duration + sil_duration + agent_sup.duration += sil_duration + 1.0 + user_sup.duration += sil_duration # Assemble final cut cut_source.supervisions = [user_sup, agent_sup] - cut_source.recording = cut_source.recording # remains the resampled context_audio cut_source.target_audio = cut_target.recording cut_source.duration = cut_target.duration - cut_source.formatter = "lhotse_magpietts_data_as_continuation" cut_source.context_audio = cut.context_audio - return cut_source + cut_source.formatter = "lhotse_magpietts_data_as_continuation" - def filter_cer(example): - if isinstance(example, Cut) and len(example.supervisions) > 0 and example.supervisions[0].has_custom("cer"): - return example.supervisions[0].cer <= MAX_CER - else: - return True - - def filter_val_flag(example): - if ( - isinstance(example, Cut) - and example.has_custom("validation_status") - and example.validation_status != KEEP_FLAG - ): - return False - else: - return True - - def filter_secs(example): - if ( - isinstance(example, Cut) - and len(example.supervisions) > 0 - and example.supervisions[0].has_custom("context_speaker_similarity") - ): - return example.supervisions[0].context_speaker_similarity >= MIN_SECS - else: - return True + return cut_source - def filter_target_speaker(example): - if isinstance(example, Cut) and len(example.supervisions) > 0 and TARGET_SPEAKER is not None: - return TARGET_SPEAKER in example.supervisions[0].speaker - else: - return True + # Filters + def filter_cer_fn(cut: Cut) -> bool: + return len(cut.supervisions) == 0 or not cut.supervisions[0].has_custom("cer") or cut.supervisions[0].cer <= max_cer - # load lhotse cuts - cuts, is_tarred = read_cutset_from_config(config) + def filter_val_flag_fn(cut: Cut) -> bool: + return not cut.has_custom("validation_status") or cut.validation_status == keep_flag - ADD_EXTRA_END_SIL = config.get("add_extra_end_silence", False) - SILENCE_RANGE = config.get("extra_end_silence_range", [0.5, 6.0]) + def filter_secs_fn(cut: Cut) -> bool: + return len(cut.supervisions) == 0 or not cut.supervisions[0].has_custom("context_speaker_similarity") or cut.supervisions[0].context_speaker_similarity >= min_context_speaker_similarity - # load prompt cut - sample_rate = 22050 + def filter_target_speaker_fn(cut: Cut) -> bool: + return len(cut.supervisions) == 0 or target_speaker is None or target_speaker in cut.supervisions[0].speaker - # filter dataset - MAX_CER = config.get("max_cer", 0.03) - cuts = cuts.filter(filter_cer) - # filter invalid samples - KEEP_FLAG = "pass" - cuts = cuts.filter(filter_val_flag) - # filter based on context speaker similarity - MIN_SECS = config.get("min_context_speaker_similarity", 0.6) - cuts = cuts.filter(filter_secs) + # Apply filters + cuts = cuts.filter(filter_cer_fn).filter(filter_val_flag_fn).filter(filter_secs_fn).filter(filter_target_speaker_fn) - # filter speaker - TARGET_SPEAKER = config.get("target_speaker", None) - cuts = cuts.filter(filter_target_speaker) + # Convert cuts + cuts = cuts.map(convert_cut_fn) - # convert cuts - cuts = 
cuts.map(convert_lhotse_magpietts_data_as_cont) return cuts, is_tarred From 42729680c76b69db604958cf50fd2efb17a23cf4 Mon Sep 17 00:00:00 2001 From: Edresson Date: Mon, 17 Nov 2025 15:32:30 +0000 Subject: [PATCH 037/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 19 +++++++++++++++---- .../speechlm2/parts/metrics/results_logger.py | 15 ++++++++++++--- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 228d0230349c..22b5b453ed4b 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -540,7 +540,6 @@ def cut_to_conversation( ) - @data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: """ @@ -563,11 +562,13 @@ def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: def filter_cuts_starting_with_agent_fn(cuts: CutSet, agent_roles: Tuple[str, ...]) -> CutSet: """Remove cuts where the first supervision belongs to an agent role.""" + def _filter_fn(cut: Cut) -> bool: if not cut.supervisions: return False cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) return cut.supervisions[0].speaker not in agent_roles + return cuts.filter(_filter_fn) def convert_overlap_cut_fn(cut: Cut) -> Cut: @@ -703,19 +704,29 @@ def convert_cut_fn(cut: Cut) -> Cut: # Filters def filter_cer_fn(cut: Cut) -> bool: - return len(cut.supervisions) == 0 or not cut.supervisions[0].has_custom("cer") or cut.supervisions[0].cer <= max_cer + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("cer") + or cut.supervisions[0].cer <= max_cer + ) def filter_val_flag_fn(cut: Cut) -> bool: return not cut.has_custom("validation_status") or cut.validation_status == keep_flag def filter_secs_fn(cut: Cut) -> bool: - return len(cut.supervisions) == 0 or not cut.supervisions[0].has_custom("context_speaker_similarity") or cut.supervisions[0].context_speaker_similarity >= min_context_speaker_similarity + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("context_speaker_similarity") + or cut.supervisions[0].context_speaker_similarity >= min_context_speaker_similarity + ) def filter_target_speaker_fn(cut: Cut) -> bool: return len(cut.supervisions) == 0 or target_speaker is None or target_speaker in cut.supervisions[0].speaker # Apply filters - cuts = cuts.filter(filter_cer_fn).filter(filter_val_flag_fn).filter(filter_secs_fn).filter(filter_target_speaker_fn) + cuts = ( + cuts.filter(filter_cer_fn).filter(filter_val_flag_fn).filter(filter_secs_fn).filter(filter_target_speaker_fn) + ) # Convert cuts cuts = cuts.map(convert_cut_fn) diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py index b49855a86dda..0caed8bb4601 100644 --- a/nemo/collections/speechlm2/parts/metrics/results_logger.py +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -21,6 +21,7 @@ from nemo.collections.audio.parts.utils.resampling import resample from nemo.utils import logging + def safe_remove_path(path): shutil.rmtree(path, ignore_errors=True) @@ -153,18 +154,26 @@ def update( ) # (B, T, repeat_factor) eou_pred_wav = eou_pred_wav.view(1, -1) # (B, T * repeat_factor) eou_pred_wav = eou_pred_wav.float() * 0.8 # make 1 audible and keep 0 as total silence - 
sf.write(out_audio_path_eou, eou_pred_wav.squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, pred_audio_sr) + sf.write( + out_audio_path_eou, + eou_pred_wav.squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, + pred_audio_sr, + ) if pre_audio_trimmed is not None: out_audio_path_trimmed = os.path.join(self.audio_save_path, f"{name}_{sample_id}_pred_trimmed.wav") sf.write( - out_audio_path_trimmed, pre_audio_trimmed[i].squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, pred_audio_sr + out_audio_path_trimmed, + pre_audio_trimmed[i].squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, + pred_audio_sr, ) if reference_audio is not None: out_audio_path_ref = os.path.join(self.audio_save_path, f"{name}_{sample_id}_spk_reference.wav") sf.write( - out_audio_path_ref, reference_audio[i].squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, pred_audio_sr + out_audio_path_ref, + reference_audio[i].squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T, + pred_audio_sr, ) # cache metadata From 7b71f419945c5d4a3862ecc548241c8d51db2517 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 07:37:08 -0800 Subject: [PATCH 038/102] Docs for set_model_dict_for_partial_init Signed-off-by: Edresson Casanova --- .../collections/speechlm2/parts/pretrained.py | 43 ++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index b8042718bd50..59ec80fdcdd8 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -25,6 +25,7 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.models import AudioCodecModel from nemo.utils import logging +from typing import Dict, Any def load_pretrained_nemo(cls, model_path_or_name: str): @@ -102,15 +103,45 @@ def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True model.perception = AudioPerceptionModule(model.cfg.perception).train() -def set_model_dict_for_partial_init(pretrained_dict, model_dict): - # 1. filter out different size layers +def set_model_dict_for_partial_init(pretrained_dict: Dict[str, torch.Tensor], + model_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Partially initialize a model's state dictionary with a pretrained state dictionary. + This function safely copies compatible layers from a pretrained model into a new model, + ignoring layers with mismatched shapes or missing keys. + + Steps: + 1. Remove layers from the pretrained dictionary if their shape does not match the target model. + 2. Keep only keys that exist in the target model. + 3. Update the model dictionary with the filtered pretrained weights. + + Args: + pretrained_dict (Dict[str, torch.Tensor]): + The state dictionary of the pretrained model. + model_dict (Dict[str, torch.Tensor]): + The state dictionary of the target model to be partially initialized. + + Returns: + Dict[str, torch.Tensor]: + The updated model state dictionary with compatible layers loaded from the pretrained dictionary. + + Example: + >>> model_dict = model.state_dict() + >>> pretrained_dict = torch.load("pretrained_model.pt") + >>> model_dict = set_model_dict_for_partial_init(pretrained_dict, model_dict) + >>> model.load_state_dict(model_dict) + """ + # 1. 
Remove layers where pretrained shape differs from model shape for k, v in list(pretrained_dict.items()): if k in model_dict and hasattr(model_dict[k], "numel") and v.numel() != model_dict[k].numel(): del pretrained_dict[k] - logging.info(" | > Layer with shape mismatach in the model definition: {}".format(k)) - # 2. filter out unnecessary keys + logging.info(f" | > Layer with shape mismatch in the model definition: {k}") + + # 2. Keep only keys that exist in the target model pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} - # 3. overwrite entries in the existing state dict + + # 3. Update model dictionary with filtered pretrained layers model_dict.update(pretrained_dict) - logging.info(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + logging.info(f" | > {len(pretrained_dict)} / {len(model_dict)} layers are restored.") + return model_dict From e6d597713a1caa0f90d62bbbf6edd5e479cfa111 Mon Sep 17 00:00:00 2001 From: Edresson Date: Mon, 17 Nov 2025 15:38:04 +0000 Subject: [PATCH 039/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/parts/pretrained.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 59ec80fdcdd8..ba2fa886b2e6 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -13,6 +13,7 @@ # limitations under the License. from contextlib import contextmanager from pathlib import Path +from typing import Any, Dict import torch from omegaconf import open_dict @@ -21,11 +22,9 @@ from nemo.collections.asr.models import ASRModel from nemo.collections.speechlm2.modules import AudioPerceptionModule - from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.models import AudioCodecModel from nemo.utils import logging -from typing import Dict, Any def load_pretrained_nemo(cls, model_path_or_name: str): @@ -103,10 +102,11 @@ def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True model.perception = AudioPerceptionModule(model.cfg.perception).train() -def set_model_dict_for_partial_init(pretrained_dict: Dict[str, torch.Tensor], - model_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: +def set_model_dict_for_partial_init( + pretrained_dict: Dict[str, torch.Tensor], model_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: """ - Partially initialize a model's state dictionary with a pretrained state dictionary. + Partially initialize a model's state dictionary with a pretrained state dictionary. This function safely copies compatible layers from a pretrained model into a new model, ignoring layers with mismatched shapes or missing keys. 
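Editor's note: for readers unfamiliar with the partial-initialization pattern documented in the patch above, the sketch below shows the same idea in isolation: keep only the pretrained tensors whose key exists in the target model and whose size is compatible, then load the merged dictionary. The helper name partial_init_demo and the use of shape (rather than numel) for the compatibility check are illustrative choices for this sketch, not part of the NeMo API.

import torch
import torch.nn as nn


def partial_init_demo(pretrained_dict, model_dict):
    # Drop pretrained tensors that are missing from the target or have an incompatible size.
    compatible = {
        k: v for k, v in pretrained_dict.items() if k in model_dict and v.shape == model_dict[k].shape
    }
    # Overwrite only the compatible entries; everything else keeps the target's fresh initialization.
    model_dict.update(compatible)
    return model_dict


# Toy example: the second Linear layer changes size between the two models.
src = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
dst = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))

merged = partial_init_demo(src.state_dict(), dst.state_dict())
dst.load_state_dict(merged)  # restores '0.weight'/'0.bias'; the mismatched layer keeps its own init

In effect this mirrors what set_model_dict_for_partial_init does before load_state_dict is called, which is why mismatched layers are silently skipped rather than raising a size-mismatch error.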
From 41ce58a3a61d88d4cb9b97d8f29584aa090a4eda Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 07:46:26 -0800 Subject: [PATCH 040/102] Make codec sil tokens a buffer Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 3 ++- nemo/collections/speechlm2/parts/pretrained.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 130892f7a99f..6a8401c9df5f 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -502,7 +502,8 @@ def __init__(self, cfg: dict) -> None: self.source_samples_per_frame = int(self.source_sample_rate // self.source_fps) # get codec silence tokens - self.codec_silence_tokens = self.get_codec_silence_frame() + codec_silence_tokens = self.get_codec_silence_frame() + self.register_buffer("codec_silence_tokens", codec_silence_tokens) # Load tokenizer self.tokenizer = AutoTokenizer( diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index ba2fa886b2e6..8861e9704e1f 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -25,6 +25,7 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.models import AudioCodecModel from nemo.utils import logging +from typing import Dict def load_pretrained_nemo(cls, model_path_or_name: str): From 2aa4e3b904ef15c92879de770a63e6f5deb2370b Mon Sep 17 00:00:00 2001 From: Edresson Date: Mon, 17 Nov 2025 15:48:02 +0000 Subject: [PATCH 041/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/parts/pretrained.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 8861e9704e1f..ba2fa886b2e6 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -25,7 +25,6 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.models import AudioCodecModel from nemo.utils import logging -from typing import Dict def load_pretrained_nemo(cls, model_path_or_name: str): From 1c3d1a6a6d6c0d09b356e1347eeaa46bd1f1b80a Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 07:51:35 -0800 Subject: [PATCH 042/102] Fix lint Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/parts/pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index ba2fa886b2e6..8aa3645fbeff 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -13,7 +13,7 @@ # limitations under the License. 
from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict +from typing import Dict import torch from omegaconf import open_dict From fbf853f667cd094224d0a2e318f1c9d829375246 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 09:22:45 -0800 Subject: [PATCH 043/102] Add CER/WER metrics unit test Signed-off-by: Edresson Casanova --- .../speechlm2/parts/metrics/__init__.py | 2 ++ .../speechlm2/parts/metrics/intelligibility.py | 2 +- .../collections/speechlm2/test_duplex_eartts.py | 1 - tests/collections/speechlm2/test_metrics.py | 16 +++++++++++----- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/nemo/collections/speechlm2/parts/metrics/__init__.py b/nemo/collections/speechlm2/parts/metrics/__init__.py index ca40c4cff5cd..73e6e050f731 100644 --- a/nemo/collections/speechlm2/parts/metrics/__init__.py +++ b/nemo/collections/speechlm2/parts/metrics/__init__.py @@ -15,10 +15,12 @@ from .bleu import BLEU from .results_logger import ResultsLogger from .token_accuracy import TokenAccuracy +from .intelligibility import Intelligibility __all__ = [ 'ASRBLEU', 'BLEU', 'TokenAccuracy', 'ResultsLogger', + 'Intelligibility', ] diff --git a/nemo/collections/speechlm2/parts/metrics/intelligibility.py b/nemo/collections/speechlm2/parts/metrics/intelligibility.py index 8e00f71b6625..15823dc86647 100644 --- a/nemo/collections/speechlm2/parts/metrics/intelligibility.py +++ b/nemo/collections/speechlm2/parts/metrics/intelligibility.py @@ -74,7 +74,7 @@ def update( if self.asr is None and not self.reuse_asr_hyps: self.reset() - if pred_audio_lens is None: + if pred_audio_lens is None and pred_audio is not None: pred_audio_lens = [pred_audio.shape[1]] * pred_audio.shape[0] with fp32_precision(): diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 511d7a021951..556ad8882d94 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -247,7 +247,6 @@ def training_cutset_batch(): def test_eartts_dataset(dataset, training_cutset_batch): - print(training_cutset_batch) batch = dataset[training_cutset_batch] # Keys that must be present in batch expected_keys = { diff --git a/tests/collections/speechlm2/test_metrics.py b/tests/collections/speechlm2/test_metrics.py index 4972a7418fa2..9ae67f1b43c5 100644 --- a/tests/collections/speechlm2/test_metrics.py +++ b/tests/collections/speechlm2/test_metrics.py @@ -13,7 +13,7 @@ # limitations under the License. 
import pytest -from nemo.collections.speechlm2.parts.metrics import BLEU, WER +from nemo.collections.speechlm2.parts.metrics import BLEU, Intelligibility def test_bleu(): @@ -34,19 +34,25 @@ def test_bleu(): assert ans["txt_bleu"] == 50.0 # average across datasets -def test_wer(): - metric = WER(verbose=False) +def test_intelligibility(): + metric = Intelligibility(pretrained_asr=None, verbose=False, reuse_asr_hyps=True) metric.update( name="dataset_1", refs=["a b c d e f g h i j k l", "m n o p r s t u v"], - hyps=["a b c d e f g h i j k l", "m n o p r s t u v"], + asr_hyps=["a b c d e f g h i j k l", "m n o p r s t u v"], + pred_audio=None, ) metric.update( name="dataset_2", refs=["a b c"], - hyps=["a b d"], + asr_hyps=["a b d"], + pred_audio=None, ) ans = metric.compute() + # wer assert ans["wer_dataset_1"] == 0.0 assert ans["wer_dataset_2"] == 1 / 3 assert ans["wer"] == 1 / 6 # average across datasets + # cer + assert ans["cer_dataset_1"] == 0.0 + assert ans["cer_dataset_2"] == 1 / 5 From 4fc268db45c73cb2a36b6ac117d8557a72f564bf Mon Sep 17 00:00:00 2001 From: Edresson Date: Mon, 17 Nov 2025 17:23:40 +0000 Subject: [PATCH 044/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/parts/metrics/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/parts/metrics/__init__.py b/nemo/collections/speechlm2/parts/metrics/__init__.py index 73e6e050f731..8758fe0f4dd4 100644 --- a/nemo/collections/speechlm2/parts/metrics/__init__.py +++ b/nemo/collections/speechlm2/parts/metrics/__init__.py @@ -13,9 +13,9 @@ # limitations under the License. from .asr_bleu import ASRBLEU from .bleu import BLEU +from .intelligibility import Intelligibility from .results_logger import ResultsLogger from .token_accuracy import TokenAccuracy -from .intelligibility import Intelligibility __all__ = [ 'ASRBLEU', From a124eab32f7c814ffeee470d6338e336ed483b10 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 10:38:42 -0800 Subject: [PATCH 045/102] Disable triton if cuda is not available Signed-off-by: Edresson Casanova --- .../speechlm2/modules/rvq_ear_tts_model.py | 97 ++++++++++++++++++- 1 file changed, 95 insertions(+), 2 deletions(-) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index a94f580fc2ac..71e37875db6c 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -95,11 +95,19 @@ def forward(self, x: Tensor) -> Tensor: # ============================================================================== # Triton-accelerated and Fallback Functions # ============================================================================== + +TRITON_IMPORTED = False try: - # Attempt to import Triton for optimized GPU kernels import triton import triton.language as tl + TRITON_IMPORTED = True +except ImportError: + TRITON_IMPORTED = False + +USE_TRITON = TRITON_IMPORTED and torch.cuda.is_available() +logging.info("Triton available & CUDA detected. Using Triton kernel for batch_matmul.") +if USE_TRITON: @triton.jit def batch_matmul_kernel( x_ptr, # Pointer to input tensor x: [batch_size, d_in] @@ -190,7 +198,93 @@ def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int batch_matmul = batch_matmul_triton logging.info("Triton is available. 
Using optimized Triton kernel for batch_matmul.") + +try: + import triton + import triton.language as tl + TRITON_IMPORTED = True except ImportError: + TRITON_IMPORTED = False + +USE_TRITON = TRITON_IMPORTED and torch.cuda.is_available() + +if USE_TRITON: + logging.info("Triton+CUDA detected. Using Triton kernel for batch_matmul.") + + @triton.jit + def batch_matmul_kernel( + x_ptr, + w_ptr, + y_ptr, + result_ptr, + b, + d_in, + d_out, + n, + BLOCK_SIZE_DIN: tl.constexpr, + BLOCK_SIZE_DOUT: tl.constexpr, + ): + batch_id = tl.program_id(axis=0) + dout_block_id = tl.program_id(axis=1) + + if batch_id >= b: + return + + idx = tl.load(y_ptr + batch_id) + + x_offset = x_ptr + batch_id * d_in + w_offset = w_ptr + idx * d_out * d_in + + dout_offsets = dout_block_id * BLOCK_SIZE_DOUT + tl.arange(0, BLOCK_SIZE_DOUT) + dout_mask = dout_offsets < d_out + + result_block = tl.zeros([BLOCK_SIZE_DOUT], dtype=tl.float32) + + for din_start in range(0, d_in, BLOCK_SIZE_DIN): + din_offsets = din_start + tl.arange(0, BLOCK_SIZE_DIN) + din_mask = din_offsets < d_in + + x_i = tl.load(x_offset + din_offsets, mask=din_mask, other=0.0) + + w_i_block = tl.load( + w_offset + dout_offsets[:, None] * d_in + din_offsets[None, :], + mask=(dout_mask[:, None] & din_mask[None, :]), + other=0.0, + ) + + result_block += tl.sum(w_i_block * x_i[None, :], axis=1) + + result_offset = result_ptr + batch_id * d_out + dout_offsets + tl.store(result_offset, result_block, mask=dout_mask) + + def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int = 64): + assert x.is_contiguous() and w.is_contiguous() and y.is_contiguous() + + b, d_in = x.shape + n, d_out, _ = w.shape + result = torch.empty(b, d_out, device=x.device, dtype=torch.float32) + + batch_matmul_kernel[ + lambda meta: (b, triton.cdiv(d_out, meta["BLOCK_SIZE_DOUT"])) + ]( + x.float(), + w.float(), + y, + result, + b, + d_in, + d_out, + n, + BLOCK_SIZE_DIN=BLOCK_SIZE_DIN, + BLOCK_SIZE_DOUT=BLOCK_SIZE_DOUT, + ) + + return result.to(dtype=x.dtype) + + batch_matmul = batch_matmul_triton + +else: + logging.info("Using PyTorch fallback (Triton unavailable or no CUDA).") # Fallback to PyTorch implementation if Triton is not available def batch_matmul_pytorch(x: Tensor, w: Tensor, y: Tensor, *args, **kwargs) -> Tensor: """ @@ -213,7 +307,6 @@ def batch_matmul_pytorch(x: Tensor, w: Tensor, y: Tensor, *args, **kwargs) -> Te return torch.bmm(w[y], x.unsqueeze(2)).squeeze(2) batch_matmul = batch_matmul_pytorch - logging.info("Triton is not available. 
Using PyTorch fallback for batch_matmul.") # ============================================================================== From 31f13f980323010207272efaad5b41ef83999113 Mon Sep 17 00:00:00 2001 From: Edresson Date: Mon, 17 Nov 2025 18:39:26 +0000 Subject: [PATCH 046/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/modules/rvq_ear_tts_model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py index 71e37875db6c..53265c75942a 100644 --- a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py @@ -100,6 +100,7 @@ def forward(self, x: Tensor) -> Tensor: try: import triton import triton.language as tl + TRITON_IMPORTED = True except ImportError: TRITON_IMPORTED = False @@ -108,6 +109,7 @@ def forward(self, x: Tensor) -> Tensor: logging.info("Triton available & CUDA detected. Using Triton kernel for batch_matmul.") if USE_TRITON: + @triton.jit def batch_matmul_kernel( x_ptr, # Pointer to input tensor x: [batch_size, d_in] @@ -202,6 +204,7 @@ def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int try: import triton import triton.language as tl + TRITON_IMPORTED = True except ImportError: TRITON_IMPORTED = False @@ -264,9 +267,7 @@ def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int n, d_out, _ = w.shape result = torch.empty(b, d_out, device=x.device, dtype=torch.float32) - batch_matmul_kernel[ - lambda meta: (b, triton.cdiv(d_out, meta["BLOCK_SIZE_DOUT"])) - ]( + batch_matmul_kernel[lambda meta: (b, triton.cdiv(d_out, meta["BLOCK_SIZE_DOUT"]))]( x.float(), w.float(), y, @@ -285,6 +286,7 @@ def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int else: logging.info("Using PyTorch fallback (Triton unavailable or no CUDA).") + # Fallback to PyTorch implementation if Triton is not available def batch_matmul_pytorch(x: Tensor, w: Tensor, y: Tensor, *args, **kwargs) -> Tensor: """ From d57812562ceced3435b2786d60963c17ff7e6c53 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 17 Nov 2025 13:35:14 -0800 Subject: [PATCH 047/102] Update codec run dtype Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 6a8401c9df5f..f37db91cbbd7 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -489,7 +489,7 @@ def __init__(self, cfg: dict) -> None: del self.language_model # get codec run precision - self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "bfloat16"), torch.float32) + self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "float32"), torch.float32) # instanciate eartts model and codec self._load_tts_model(self.cfg) From 853d9488921febd279d0e1499b4aed31429d05e8 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 21 Nov 2025 06:44:41 -0800 Subject: [PATCH 048/102] Add EOS dropout Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py 
b/nemo/collections/speechlm2/models/duplex_ear_tts.py index f37db91cbbd7..75fbdad9d94e 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -801,6 +801,27 @@ def pad_or_truncate(x, pad_value=0): # set the extra self.speech_pad_id at first 1 position in non_prompt_mask target_codes_aligned[row_idx, pos] = self.speech_pad_id + # EOS dropout to make the model more robust + if self.training and self.cfg.get("text_eos_dropout_prob", 0.0) > 0: + # Mask EOS positions + eos_mask = (input_text_tokens == self.text_eos_id) + + # Random dropout only on EOS positions + dropout_mask = ( + torch.rand(eos_mask.sum(), device=input_text_tokens.device) < self.cfg.text_eos_dropout_prob + ) + + # Scatter dropout decisions into [B, T] + full_dropout_mask = torch.zeros_like(input_text_tokens, dtype=torch.bool) + full_dropout_mask[eos_mask] = dropout_mask + + # Replace dropped EOS with PAD + input_text_tokens = torch.where( + full_dropout_mask, + torch.full_like(input_text_tokens, self.text_pad_id), + input_text_tokens + ) + # shift text tokens subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1]) # note that we are using a text mask where we are ignoring the desc + audio prompt but we are keeping 1 until the audio ends to support duplex From 53569adbf5c3b71d616efe1deb8ce898efd81dda Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 21 Nov 2025 14:45:38 +0000 Subject: [PATCH 049/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 75fbdad9d94e..053da8642331 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -804,12 +804,10 @@ def pad_or_truncate(x, pad_value=0): # EOS dropout to make the model more robust if self.training and self.cfg.get("text_eos_dropout_prob", 0.0) > 0: # Mask EOS positions - eos_mask = (input_text_tokens == self.text_eos_id) + eos_mask = input_text_tokens == self.text_eos_id # Random dropout only on EOS positions - dropout_mask = ( - torch.rand(eos_mask.sum(), device=input_text_tokens.device) < self.cfg.text_eos_dropout_prob - ) + dropout_mask = torch.rand(eos_mask.sum(), device=input_text_tokens.device) < self.cfg.text_eos_dropout_prob # Scatter dropout decisions into [B, T] full_dropout_mask = torch.zeros_like(input_text_tokens, dtype=torch.bool) @@ -817,9 +815,7 @@ def pad_or_truncate(x, pad_value=0): # Replace dropped EOS with PAD input_text_tokens = torch.where( - full_dropout_mask, - torch.full_like(input_text_tokens, self.text_pad_id), - input_text_tokens + full_dropout_mask, torch.full_like(input_text_tokens, self.text_pad_id), input_text_tokens ) # shift text tokens From 545217570fdb428b40629341ef40c5e750e49510 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 24 Nov 2025 06:29:50 -0800 Subject: [PATCH 050/102] Fix Bleu Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/parts/metrics/bleu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/speechlm2/parts/metrics/bleu.py b/nemo/collections/speechlm2/parts/metrics/bleu.py index 8479f3ec6346..34f6da088a4d 100644 --- a/nemo/collections/speechlm2/parts/metrics/bleu.py +++ b/nemo/collections/speechlm2/parts/metrics/bleu.py @@ -40,6 +40,8 @@ def 
__init__(self, normalize: bool = True, normalizer=None, verbose: bool = True self._hyps = defaultdict(list) def reset(self): + self._refs.clear() + self._hyps.clear() return self def update(self, name: str, refs: list[str], hyps: list[str]) -> None: @@ -56,8 +58,7 @@ def compute(self) -> dict[str, torch.Tensor]: metric = torch.tensor(sacrebleu.corpus_bleu(self._hyps[name], [self._refs[name]]).score) corpus_metric[f"txt_bleu_{name}"] = metric corpus_metric["txt_bleu"] = torch.stack(list(corpus_metric.values())).mean() - self._refs.clear() - self._hyps.clear() + self.reset() return corpus_metric From db2bdf90b9b5dadea340ee359fca10d8aa3c07f1 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 24 Nov 2025 06:51:19 -0800 Subject: [PATCH 051/102] Cleanup tests useless comments Signed-off-by: Edresson Casanova --- tests/collections/speechlm2/test_duplex_eartts.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 556ad8882d94..364ee0c6edb2 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -248,7 +248,6 @@ def training_cutset_batch(): def test_eartts_dataset(dataset, training_cutset_batch): batch = dataset[training_cutset_batch] - # Keys that must be present in batch expected_keys = { "sample_id", "non_prompt_mask", @@ -271,11 +270,9 @@ def test_eartts_dataset(dataset, training_cutset_batch): "formatter", } - # --- Presence + tensor sanity checks --- for key in expected_keys: assert key in batch, f"Missing key: {key}" - # Tensor-only keys tensor_keys = [ "non_prompt_mask", "desc_mask", @@ -296,16 +293,11 @@ def test_eartts_dataset(dataset, training_cutset_batch): for key in tensor_keys: assert torch.is_tensor(batch[key]), f"{key} must be a tensor" - # --- Shape/value checks similar to the original test --- - - # Audio shapes (you can adjust if needed) assert batch["source_audio"].shape == (1, 89082) assert batch["target_audio"].shape == (1, 89082) - # Target text consistency + # Check target text consistency assert batch["target_texts"] == ["hello okay"] - - # Token checks (same content as your old test) assert batch["source_tokens"].tolist() == [ [ 2, From a802460660a9a38c4a623268e5616967fdf4d1b3 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 24 Nov 2025 09:44:37 -0800 Subject: [PATCH 052/102] Rename EARTTS files Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 10 ++++++---- .../modules/{rvq_ear_tts_model.py => ear_tts_model.py} | 0 .../{rvq_ear_tts_vae.py => ear_tts_vae_codec.py} | 0 3 files changed, 6 insertions(+), 4 deletions(-) rename nemo/collections/speechlm2/modules/{rvq_ear_tts_model.py => ear_tts_model.py} (100%) rename nemo/collections/speechlm2/modules/{rvq_ear_tts_vae.py => ear_tts_vae_codec.py} (100%) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 053da8642331..9544ee03a4ca 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -42,8 +42,8 @@ from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.speechlm2.data.utils import get_pad_id from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER -from nemo.collections.speechlm2.modules.rvq_ear_tts_model import RVQEARTTSModel -from 
nemo.collections.speechlm2.modules.rvq_ear_tts_vae import RVQVAEModel +from nemo.collections.speechlm2.modules.ear_tts_model import RVQEARTTSModel +from nemo.collections.speechlm2.modules.ear_tts_vae_codec import RVQVAEModel from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU from nemo.collections.speechlm2.parts.metrics.intelligibility import Intelligibility @@ -1313,9 +1313,11 @@ def validation_step(self, batch: dict, batch_idx: int): for name, dataset_batch in batch.items(): if dataset_batch is None: continue # some dataset is exhausted + + B = len(dataset_batch['sample_id']) + # run inference for multiples references if self.cfg.get("inference_speaker_reference_path", None): - B = len(dataset_batch['sample_id']) for inference_speaker_reference in glob.glob( os.path.join(self.cfg.inference_speaker_reference_path, "**"), recursive=True ): @@ -1339,7 +1341,7 @@ def validation_step(self, batch: dict, batch_idx: int): # run inference for a custom speaker reference elif self.cfg.get("inference_speaker_reference", None): new_dataset_batch = copy.deepcopy(dataset_batch) - speaker_audio, sr = load_audio_librosa(inference_speaker_reference) + speaker_audio, sr = load_audio_librosa(self.cfg.inference_speaker_reference) speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) speaker_audio = speaker_audio.repeat(B, 1).to(self.device) # lengths -> [B] diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py similarity index 100% rename from nemo/collections/speechlm2/modules/rvq_ear_tts_model.py rename to nemo/collections/speechlm2/modules/ear_tts_model.py diff --git a/nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py b/nemo/collections/speechlm2/modules/ear_tts_vae_codec.py similarity index 100% rename from nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py rename to nemo/collections/speechlm2/modules/ear_tts_vae_codec.py From da2cb92849405c57a0f141ab8bff64df00f6c76c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 25 Nov 2025 06:11:14 -0800 Subject: [PATCH 053/102] Update EARTTS dataset docs Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 3 +- .../speechlm2/data/duplex_ear_tts_dataset.py | 53 ++++++++++++++++--- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 22b5b453ed4b..fee53b653762 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -611,7 +611,8 @@ def convert_overlap_cut_fn(cut: Cut) -> Cut: @data_type_parser(["lhotse_magpietts_data_as_continuation"]) def read_lhotse_magpietts_data_as_continuation(config) -> Tuple[CutSet, bool]: """ - Convert MagpieTTS dataset cuts to a continuation format suitable for S2S training. + Convert MagpieTTS dataset cuts into the Duplex S2S format, with optional + `context_audio` that can be used as a speaker reference. 
Args: config: Dictionary containing parser options: diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 27d9db748712..68ee93fc6206 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -140,20 +140,59 @@ class DuplexEARTTSDataset(torch.utils.data.Dataset): output_roles (list[str], optional): List of speaker roles (cut.supervisions[:].speaker) to consider as outputs. Defaults to ["agent"]. - Returns: + p_drop_description (float, optional): + Probability of dropping text descriptions. Default: `0.0`. + + add_text_bos_and_eos_in_each_turn (bool, optional): + If True, each conversational turn from any speaker is explicitly delimited + with BOS and EOS tokens in the text stream. + Default: `True`. + + add_audio_prompt_after_description (bool, optional): + If True, an optional audio/speaker prompt is appended after the description. + Default: `True`. + + audio_prompt_duration (float, optional): + Duration (in seconds) of the audio prompt appended when + `add_audio_prompt_after_description=True`. Default: `3.0`. + + num_delay_speech_tokens (int, optional): + Number of PAD tokens to insert before speech tokens to artificially + delay the start of speech. Default: `0`. + + Returns: A dictionary with the following keys: + - sample_id: List of sample IDs for each cut in the batch [B] + + - non_prompt_mask: Bool tensor [B, T] marking positions that are not part of the prompt + - desc_mask: Bool tensor [B, T] marking positions belonging to the description + - desc_lens: Tensor of description text lengths [B] + - desc_plus_audio_prompt_lens: Tensor of description + audio prompt lengths [B] + + - aligned_attention_mask: Bool tensor [B, T] used by alignment-aware transformer models + - aligned_position_ids: Tensor of position indices aligned to audio frames [B, T] + - source_audio: Tensor of source waveform samples [B, T] - source_audio_lens: Tensor of source audio lengths [B] + - target_audio: Tensor of target waveform samples [B, T] - target_audio_lens: Tensor of target audio lengths [B] - - input_text_tokens: Tensor of target text tokens [B, T], with special tokens (BOS/EOS/PAD) - at positions aligned with audio frames + + - input_text_tokens: Tensor of frame-aligned input text tokens [B, T], + including BOS/EOS/PAD when enabled - target_token_lens: Tensor of target token sequence lengths [B] - - source_tokens: Tensor of source text tokens [B, T], with special tokens (BOS/EOS/PAD) - at positions aligned with audio frames + + - source_tokens: Tensor of frame-aligned source text tokens [B, T], + including BOS/EOS/PAD - source_token_lens: Tensor of source token sequence lengths [B] + - target_texts: List of full target texts joined from output_roles supervisions [B] + - speaker_reference_audio: Tensor of optional speaker reference waveform samples [B, T] + - speaker_reference_audio_lens: Tensor of speaker reference audio lengths [B] + + - formatter: List indicating the formatter to use for each cut (default "s2s_duplex") [B] + Notes: - The dataset ensures frame-level alignment between audio and text by inserting tokens at specific frame positions based on the timing of supervision segments. 
@@ -174,8 +213,8 @@ def __init__( input_roles: list[str] = None, output_roles: list[str] = None, p_drop_description: float = 0.0, - add_text_bos_and_eos_in_each_turn: bool = False, - add_audio_prompt_after_description: bool = False, + add_text_bos_and_eos_in_each_turn: bool = True, + add_audio_prompt_after_description: bool = True, audio_prompt_duration: float = 3.0, num_delay_speech_tokens: int = 0, ): From 8ca4db4abd72110a699ffc4414622041af4faa28 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 25 Nov 2025 06:18:38 -0800 Subject: [PATCH 054/102] Remove mixed precision fns Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 231 ------------------ 1 file changed, 231 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 9544ee03a4ca..606ff74814ed 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -101,181 +101,6 @@ def ensures_target_precision(target_dtype): torch.set_default_dtype(default_dtype) -def collect_activation_stats(model: nn.Module, inputs: dict) -> dict: - """ - Collect per-layer activation statistics (min and max) for Linear, LayerNorm, and Embedding modules. - - This performs a forward pass in FP32 and registers hooks to record - the min and max values of each layer's output. These statistics are - used to decide which layers are safe for mixed precision. - - Args: - model (nn.Module): Model to analyze. - inputs (dict): Input arguments for the model forward pass. - - Returns: - dict: Mapping from layer names to activation stats: - {"layer_name": {"min": value, "max": value}} - """ - stats = {} - hooks = [] - - def _make_hook(name: str): - def hook(_, __, out): - if isinstance(out, tuple): - out = out[0] - if torch.is_tensor(out): - stats[name] = {"min": float(out.detach().min()), "max": float(out.detach().max())} - - return hook - - # Register hooks - for name, module in model.named_modules(): - if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)): - hooks.append(module.register_forward_hook(_make_hook(name))) - - # Forward pass - with torch.no_grad(): - _ = model( - code=inputs["code"], - audio_mask=maybe_to(inputs["audio_mask"], torch.float32), - attention_mask=maybe_to(inputs["attention_mask"], torch.float32), - position_ids=inputs["position_ids"], - context_hidden_state=maybe_to(inputs["context_hidden_state"], torch.float32), - subword_ids=inputs["subword_ids"], - subword_mask=maybe_to(inputs["subword_mask"], torch.float32), - non_prompt_mask=maybe_to(inputs["non_prompt_mask"], torch.float32), - ) - - # Remove hooks - for h in hooks: - h.remove() - - return stats - - -def classify_precision_layers(model: nn.Module, stats: dict, safe_min: float, safe_max: float) -> list: - """ - Determine which layers must remain FP32 for numerical stability. - - Sensitive layers (LayerNorm, Embedding, or Linear layers with out-of-range activations) - are forced to FP32. FP32 can propagate to the next safe layer to prevent instability. - - Args: - model (nn.Module): Model to classify. - stats (dict): Activation statistics from `collect_activation_stats`. - safe_min (float): Minimum threshold for safe activations. - safe_max (float): Maximum threshold for safe activations. - - Returns: - list: Names of layers that should remain FP32. 
- """ - fp32_layers = [] - propagate_fp32 = False - - for name, module in model.named_modules(): - if name not in stats: - continue - - mn, mx = stats[name]["min"], stats[name]["max"] - safe_range = abs(mn) < safe_max and abs(mx) < safe_max - not_tiny = not (abs(mn) < safe_min and abs(mx) < safe_min) - safe = safe_range and not_tiny - - # Determine if layer is FP32-sensitive - is_sensitive = isinstance(module, (nn.LayerNorm, nn.Embedding)) - if isinstance(module, nn.Linear) and not safe: - is_sensitive = True - - if is_sensitive: - fp32_layers.append(name) - propagate_fp32 = True - elif propagate_fp32: - # Propagate FP32 to next safe layer - fp32_layers.append(name) - propagate_fp32 = False - - return fp32_layers - - -def wrap_module_precision(module: nn.Module, force_fp32: bool, mixed_dtype=torch.bfloat16): - """ - Wrap a module's forward to enforce mixed precision or FP32. - - Args: - module (nn.Module): Module to wrap. - force_fp32 (bool): If True, module runs in FP32. - mixed_dtype (torch.dtype): Target dtype for mixed precision layers. - """ - if hasattr(module, "_original_forward"): - return - - module._original_forward = module.forward - - def new_forward(*args, **kwargs): - if force_fp32: - with fp32_precision(): - return module._original_forward(*args, **kwargs) - else: - new_args = tuple( - a.to(mixed_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a for a in args - ) - new_kwargs = { - k: v.to(mixed_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v - for k, v in kwargs.items() - } - with ensures_target_precision(mixed_dtype): - return module._original_forward(*new_args, **new_kwargs) - - module.forward = new_forward - - -def find_sensitive_layers( - model: nn.Module, - inputs: dict, - bf16_min: float = 1e-2, - bf16_max: float = 1e2, - safety_factor: float = 1.0, -) -> list: - """ - Identify FP32-sensitive layers for a TTS model. - - Steps: - 1. Run FP32 forward pass to collect activation stats. - 2. Classify layers that must remain FP32. - - Args: - model (nn.Module): TTS model. - inputs (dict): Inputs for forward pass. - bf16_min (float): Minimum safe activation for BF16. - bf16_max (float): Maximum safe activation for BF16. - safety_factor (float): Safety factor for thresholds. - - Returns: - list: Names of FP32-sensitive layers. - """ - safe_min = bf16_min * safety_factor - safe_max = bf16_max * safety_factor - - # FP32 reference forward - model_fp32 = copy.deepcopy(model).eval().to(torch.float32) - stats = collect_activation_stats(model_fp32, inputs) - - # Identify FP32 layers - model_patched = copy.deepcopy(model).eval() - fp32_layers = classify_precision_layers(model_patched, stats, safe_min, safe_max) - - # Count total relevant layers - total_layers = sum( - 1 for _, module in model.named_modules() if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)) - ) - half_precision_layers = total_layers - len(fp32_layers) - - print(f"Total sensitive layers (FP32): {len(fp32_layers)}, " f"Half precision layers: {half_precision_layers}") - - return fp32_layers - - def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int = 0, eos_token_id: int = 1): """ Efficient, batched speaking mask generator that marks 1 between and pairs. 
@@ -474,7 +299,6 @@ def __init__(self, cfg: dict) -> None: self.target_sample_rate = cfg.data.target_sample_rate self.source_sample_rate = cfg.data.source_sample_rate self.normalize_text = cfg.data.get("normalize_text", False) - self.model_16_precision_safe = None self.validation_save_path = os.path.join(cfg.exp_manager.explicit_log_dir, "validation_logs") @@ -742,23 +566,6 @@ def prepare_inputs(self, batch: dict): with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens) - # ToDo: consider use the source audio - """ - # resample source audio if needed - if self.source_sample_rate != self.target_sample_rate: - source_audio = resample(source_audio, self.source_sample_rate, self.target_sample_rate) - with fp32_precision(): - source_audio_lens = (source_audio_lens * (self.target_sample_rate/self.source_sample_rate)).to(lengths.dtype) - # ToDo: Add a transformer encoder to help the model to better extract contextual information, replace the code bellow with it - # extract embedding for context audios - source_audio, source_audio_lens = self.pad_audio_to_factor(source_audio, source_audio_lens, self.target_samples_per_frame, 1) - with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): - source_codes, source_codes_lens = self.audio_codec.encode( - source_audio.unsqueeze(1), source_audio_lens - ) - source_codes = source_codes.transpose(1, 2) # (B, K, T) -> (B, T, K) - """ - with fp32_precision(): target_len = target_codes.shape[1] @@ -1048,20 +855,6 @@ def offline_inference_with_custom_sentences( ) return audio, audio_len, speaker_audio, speaker_audio_lens - def apply_mixed_precision_wrapping_on_tts_model(self, fp32_layers: list, mixed_dtype=torch.bfloat16): - """ - Apply mixed precision to TTS model layers, keeping FP32 layers intact. - - Args: - fp32_layers (list): Names of layers to keep FP32. - mixed_dtype (torch.dtype): Target dtype for mixed precision layers. - """ - logging.info(f"Converting TTS model to mixed precision. FP32 layers: {fp32_layers}") - for name, module in self.tts_model.named_modules(): - if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Embedding)): - force_fp32 = name in fp32_layers - wrap_module_precision(module, force_fp32, mixed_dtype) - def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=False): """ Runs evaluation and scoring for a single data batch, logging metrics and updating result buffers. 
@@ -1077,30 +870,6 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals results = {} inputs = self.prepare_inputs(dataset_batch) - # first evaluation, make the model bf16 safe - if ( - not self.model_16_precision_safe - and self.cfg.get("ensures_16_safe", False) - and self.trainer_config is not None - and str(self.trainer_config.precision) != str(32) - ): - if self.cfg.get("sensitive_layers", None): - self.apply_mixed_precision_wrapping_on_tts_model( - self.cfg.sensitive_layers, - mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16, - ) - else: - sensitive_layers = find_sensitive_layers( - self.tts_model, - inputs, - safety_factor=1.0, - ) - self.apply_mixed_precision_wrapping_on_tts_model( - sensitive_layers, - mixed_dtype=torch.float16 if str(self.trainer_config.precision) == str(16) else torch.bfloat16, - ) - self.model_16_precision_safe = True - results["audio_tf"], results["audio_tf_len"] = self.get_teacher_force_inference_audio(dataset_batch) if use_dataloader_init: # cut it on prompt From 45659e4c22b7ad69299b07a3b8a12660956706c0 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 25 Nov 2025 08:52:44 -0800 Subject: [PATCH 055/102] Remove system prompt Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 81 +++++-------------- .../speechlm2/models/duplex_ear_tts.py | 78 ++---------------- 2 files changed, 29 insertions(+), 130 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 68ee93fc6206..76d4de5a17e3 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -233,40 +233,6 @@ def __init__( assert tokenizer.bos is not None, "BOS support in the tokenizer is required for S2S models." assert tokenizer.eos is not None, "EOS support in the tokenizer is required for S2S models." - def generate_prompt_description(self, device): - messages = [] - if random.random() > self.p_drop_description: - # ToDo: add extra system prompts - system_prompt = ( - "You engage in conversation with the user. When delivering your response as speech, " - "if the user provides a description such as emotions, scene details, " - "or speaker style, you adjust your speaking style accordingly when delivering the response. " - "However, this description should influence only the delivery of your response, not its content. " - "Your response should remain independent of any stylistic instructions." - ) - messages.append({"role": "system", "content": system_prompt}) - else: - messages.append({"role": "system", "content": ""}) - - # given that descriptions are currently not supported, only added the user prompt - # ToDo: add extra user prompts or completly remove it as it is not used in NanoV2 - user_prompt = "Can you tell me something interesting?" 
- messages.append({"role": "user", "content": user_prompt}) - messages.append({"role": "assistant", "content": SCRIPT_PLACEHOLDER}) - non_script_list = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ).split(SCRIPT_PLACEHOLDER + self.tokenizer.eos_token)[:-1] - - input_ids = [] - for i, non_script in enumerate(non_script_list): - desc_ids = self.tokenizer.text_to_ids(non_script) - input_ids.extend(desc_ids) - - input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).view(1, -1) - return input_ids - def __getitem__(self, cuts: CutSet) -> dict: cuts = cuts.transform_text(_strip_timestamps) source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) @@ -336,8 +302,9 @@ def __getitem__(self, cuts: CutSet) -> dict: desc_plus_audio_prompt_lens = [] # for each sample in the batch for i in range(input_text_tokens.size(0)): - # ToDo: Consider remove the prompt description, given that NanoV2 does not support it and curently it is only a single eos text token - desc_tokens_ids = self.generate_prompt_description(device=input_text_tokens[i].device).squeeze(0) + # create a eos token tensor + initial_text_frame_id = torch.tensor([self.tokenizer.eos], dtype=torch.long, device=input_text_tokens[i].device) + if self.add_audio_prompt_after_description: prompt_audio_size = int( ((self.audio_prompt_duration * self.target_sample_rate) // target_samples_per_frame) @@ -360,64 +327,58 @@ def __getitem__(self, cuts: CutSet) -> dict: # set last prompt frame with eos in text channel prompt_audio_text_pad[-1] = self.tokenizer.eos - # Add eos to simulate the end of a turn as in EAR-TTS inference - desc_tokens_ids = torch.cat( - [ - desc_tokens_ids, - torch.tensor([self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device), - ] - ) - # Add padding equivalent to the audio prompt size in number of tokens + # Prepend an initial text EOS token followed by padding tokens that match + # the number of audio-prompt frames (in text-token units) and input_text_tokens new_input_text_tokens = torch.cat( [ - desc_tokens_ids.to(input_text_tokens.dtype), + initial_text_frame_id.to(input_text_tokens.dtype), prompt_audio_text_pad.to(input_text_tokens.dtype), input_text_tokens[i], ] ) # append to list and update lens input_text_tokens_.append(new_input_text_tokens) - target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size + target_token_lens[i] = target_token_lens[i] + len(initial_text_frame_id) + prompt_audio_text_pad_size # add description to source text tokens - source_tokens_.append(torch.cat([desc_tokens_ids, prompt_audio_text_pad, source_tokens[i]])) - source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) + prompt_audio_text_pad_size + source_tokens_.append(torch.cat([initial_text_frame_id, prompt_audio_text_pad, source_tokens[i]])) + source_token_lens[i] = source_token_lens[i] + len(initial_text_frame_id) + prompt_audio_text_pad_size # add silence in the source audio while the prompt is being processed - pad_size = (len(desc_tokens_ids) * source_samples_per_frame) + prompt_audio.size(1) + pad_size = (len(initial_text_frame_id) * source_samples_per_frame) + prompt_audio.size(1) pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) source_audio_.append(torch.cat([pad_audio, source_audio[i]])) source_audio_lens[i] = source_audio_lens[i] + pad_size # add silence in the target audio while the prompt is being processed - pad_size = 
len(desc_tokens_ids) * target_samples_per_frame + pad_size = len(initial_text_frame_id) * target_samples_per_frame pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) target_audio_.append(torch.cat([pad_audio, prompt_audio[i], target_audio[i]])) target_audio_lens[i] = target_audio_lens[i] + pad_size + prompt_audio.size(1) # desc duration - desc_lens.append(len(desc_tokens_ids)) + desc_lens.append(len(initial_text_frame_id)) desc_plus_audio_prompt_lens.append( - len(desc_tokens_ids) + prompt_audio_text_pad_size - 1 + len(initial_text_frame_id) + prompt_audio_text_pad_size - 1 ) # -1 due the shift done in subword_ids else: # add description to target text tokens - input_text_tokens_.append(torch.cat([desc_tokens_ids, input_text_tokens[i]])) - target_token_lens[i] = target_token_lens[i] + len(desc_tokens_ids) + input_text_tokens_.append(torch.cat([initial_text_frame_id, input_text_tokens[i]])) + target_token_lens[i] = target_token_lens[i] + len(initial_text_frame_id) # add description to source text tokens - source_tokens_.append(torch.cat([desc_tokens_ids, source_tokens[i]])) - source_token_lens[i] = source_token_lens[i] + len(desc_tokens_ids) + source_tokens_.append(torch.cat([initial_text_frame_id, source_tokens[i]])) + source_token_lens[i] = source_token_lens[i] + len(initial_text_frame_id) # add silence in the source audio while the prompt is being processed - pad_size = len(desc_tokens_ids) * source_samples_per_frame + pad_size = len(initial_text_frame_id) * source_samples_per_frame pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) source_audio_.append(torch.cat([pad_audio, source_audio[i]])) source_audio_lens[i] = source_audio_lens[i] + pad_size # add silence in the target audio while the prompt is being processed - pad_size = len(desc_tokens_ids) * target_samples_per_frame + pad_size = len(initial_text_frame_id) * target_samples_per_frame pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) target_audio_.append(torch.cat([pad_audio, target_audio[i]])) target_audio_lens[i] = target_audio_lens[i] + pad_size # des duration - desc_lens.append(len(desc_tokens_ids)) - desc_plus_audio_prompt_lens.append(len(desc_tokens_ids)) + desc_lens.append(len(initial_text_frame_id)) + desc_plus_audio_prompt_lens.append(len(initial_text_frame_id)) # collate tensors input_text_tokens = collate_vectors(input_text_tokens_, padding_value=text_pad_id) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 606ff74814ed..5e6d743a463f 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -1132,61 +1132,6 @@ def on_test_epoch_end(self) -> None: def test_step(self, *args, **kwargs): return self.validation_step(*args, **kwargs) - def get_system_prompt(self, system_prompt=None, user_prompt=None): - """ - Constructs a prompt message pair (system, user, assistant) formatted for chat inference. - - Args: - system_prompt (str, optional): System message describing conversational policy. - user_prompt (str, optional): User message/content. - - Returns: - torch.Tensor: Tokenized prompt IDs, shape (1, T). - """ - messages = [] - if system_prompt is None: - system_prompt = ( - "You engage in conversation with the user. 
When delivering your response as speech, " - "if the user provides a description such as emotions, scene details, " - "or speaker style, you adjust your speaking style accordingly when delivering the response. " - "However, this description should influence only the delivery of your response, not its content. " - "Your response should remain independent of any stylistic instructions." - ) - messages.append({"role": "system", "content": system_prompt}) - - # ToDo: implement dataloading support for descriptions - """for desc in example["descriptions"]: - user_prompt = "" - if random.random() > self.p_drop_description and desc: - user_prompt += f"```\n{desc}\n```" - if random.random() > self.p_drop_description: - if user_prompt: - user_prompt += "\n\n" - user_prompt += self.rng.choice(self.user_prompts) - if user_prompt: - messages.append({"role": "user", "content": user_prompt}) - messages.append({"role": "assistant", "content": SCRIPT_PLACEHOLDER}) - """ - - # given that descriptions are currently not supported, only added the user prompt - if user_prompt is None: - user_prompt = "Can you tell me something interesting?" - messages.append({"role": "user", "content": user_prompt}) - messages.append({"role": "assistant", "content": SCRIPT_PLACEHOLDER}) - non_script_list = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ).split(SCRIPT_PLACEHOLDER + self.tokenizer.eos_token)[:-1] - - input_ids = [] - for i, non_script in enumerate(non_script_list): - desc_ids = self.tokenizer.text_to_ids(non_script) - input_ids.extend(desc_ids) - - input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).view(1, -1) - return input_ids - def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, user_prompt=None): """ Registers and prepares initial input buffers for text/audio prompt and context, to warm up AR inference. @@ -1242,30 +1187,23 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, with fp32_precision(): prompt_audio_text_pad_size = int(prompt_audio_size // self.target_samples_per_frame) - # get description tokens - # ToDo: Consider remove the prompt description, given that NanoV2 does not support it and curently it is only a single eos text token - desc_tokens_ids = self.get_system_prompt(system_prompt=system_prompt, user_prompt=user_prompt) + # create a eos token id + first_text_frame = torch.tensor([self.tokenizer.eos], dtype=torch.long, device=self.device) # create a padding tensor prompt_audio_text_pad = ( - torch.ones(prompt_audio_text_pad_size, device=self.device, dtype=desc_tokens_ids.dtype) * self.text_pad_id + torch.ones(prompt_audio_text_pad_size, device=self.device, dtype=first_text_frame.dtype) * self.text_pad_id ) prompt_audio_text_pad[-1] = self.tokenizer.eos - # Add eos to simulate the end of a turn as in EAR-TTS inference - desc_tokens_ids = torch.cat( - [ - desc_tokens_ids.squeeze(), - torch.tensor([self.tokenizer.eos], dtype=desc_tokens_ids.dtype, device=desc_tokens_ids.device), - ] - ) - # Add padding equivalent to the audio prompt size in number of tokens + # Prepend an initial text EOS token followed by padding tokens that match + # the number of audio-prompt frames (in text-token units). 
input_text_tokens = torch.cat( - [desc_tokens_ids.to(desc_tokens_ids.dtype), prompt_audio_text_pad.to(desc_tokens_ids.dtype)] + [first_text_frame, prompt_audio_text_pad.to(first_text_frame.dtype)] ) # create pad audio for the description - pad_size = desc_tokens_ids.size(-1) * self.target_samples_per_frame + pad_size = first_text_frame.size(-1) * self.target_samples_per_frame pad_audio = ( torch.zeros(pad_size, device=prompt_audio.device, dtype=prompt_audio.dtype) .unsqueeze(0) @@ -1301,7 +1239,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, ) # desc mask is all zeros except the description desc_mask = torch.zeros_like(input_text_tokens) - desc_mask[:, : desc_tokens_ids.size(-1)] = 1 + desc_mask[:, : first_text_frame.size(-1)] = 1 if not self.cfg.get("disable_speech_pad", False): # add special tokens on audio codes From 544eee0db95279b8fe0f4d0fa26ead6f0cbcd8eb Mon Sep 17 00:00:00 2001 From: Edresson Date: Tue, 25 Nov 2025 16:53:32 +0000 Subject: [PATCH 056/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py | 4 +++- nemo/collections/speechlm2/models/duplex_ear_tts.py | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 76d4de5a17e3..9f3f55499f9b 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -303,7 +303,9 @@ def __getitem__(self, cuts: CutSet) -> dict: # for each sample in the batch for i in range(input_text_tokens.size(0)): # create a eos token tensor - initial_text_frame_id = torch.tensor([self.tokenizer.eos], dtype=torch.long, device=input_text_tokens[i].device) + initial_text_frame_id = torch.tensor( + [self.tokenizer.eos], dtype=torch.long, device=input_text_tokens[i].device + ) if self.add_audio_prompt_after_description: prompt_audio_size = int( diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 5e6d743a463f..3a692215db64 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -1198,9 +1198,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, # Prepend an initial text EOS token followed by padding tokens that match # the number of audio-prompt frames (in text-token units). 
- input_text_tokens = torch.cat( - [first_text_frame, prompt_audio_text_pad.to(first_text_frame.dtype)] - ) + input_text_tokens = torch.cat([first_text_frame, prompt_audio_text_pad.to(first_text_frame.dtype)]) # create pad audio for the description pad_size = first_text_frame.size(-1) * self.target_samples_per_frame From d32f4ef1c4a64d28e34d353acc7c6cd9d57c1b4a Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 25 Nov 2025 09:05:11 -0800 Subject: [PATCH 057/102] Remove custom test sentence inference logic from the model Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 1 - .../speechlm2/models/duplex_ear_tts.py | 126 +++--------------- 2 files changed, 21 insertions(+), 106 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 9f3f55499f9b..6cc11271c6ca 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -24,7 +24,6 @@ from nemo.collections.common.tokenizers import TokenizerSpec from nemo.collections.speechlm2.data.utils import get_pad_id -from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER from nemo.utils import logging diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 3a692215db64..05efc342988a 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -import glob import os import tempfile import time @@ -41,7 +40,6 @@ from nemo.collections.common.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.speechlm2.data.utils import get_pad_id -from nemo.collections.speechlm2.modules.ear_tts_commons import SCRIPT_PLACEHOLDER from nemo.collections.speechlm2.modules.ear_tts_model import RVQEARTTSModel from nemo.collections.speechlm2.modules.ear_tts_vae_codec import RVQVAEModel from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin @@ -1019,109 +1017,27 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals ) def validation_step(self, batch: dict, batch_idx: int): - if self.cfg.get("test_sentences", None) and self.cfg.get("inference_speaker_reference", None): - for name in self.cfg.test_sentences.keys(): - logging.info(f"Generating {name} custom sentences.") - test_sentences = self.cfg.test_sentences[name] - results = {} - results["audio"], results["audio_len"], speaker_audio, speaker_audio_lens = ( - self.offline_inference_with_custom_sentences(test_sentences, self.cfg.inference_speaker_reference) - ) - with fp32_precision(): # resample is fragile to bfloat16 default dtype - metric_audio_pred = results["audio"] - metric_audio_pred_lens = results["audio_len"] - - # resample audio to the asr sampling rate - metric_audio_pred = resample(metric_audio_pred, self.target_sample_rate, 16000) - metric_audio_pred_lens = (metric_audio_pred_lens / self.target_sample_rate * 16000).to(torch.long) - - asr_hyps = self.asr_bleu.update( - name=name, - refs=test_sentences, - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - ) - - self.intelligibility.update( - name=name, - refs=test_sentences, - pred_audio=metric_audio_pred, - 
pred_audio_lens=metric_audio_pred_lens, - asr_hyps=asr_hyps, - ) - - self.secs.update( - name=name, - target_audio=resample(speaker_audio, self.target_sample_rate, 16000), - target_audio_lens=(speaker_audio_lens / self.target_sample_rate * 16000).to(torch.long), - pred_audio=resample(results["audio"], self.target_sample_rate, 16000), - pred_audio_lens=(results["audio_len"] / self.target_sample_rate * 16000).to(torch.long), - ) - - self.results_logger.update( - name=name, - refs=test_sentences, - hyps=test_sentences, - asr_hyps=asr_hyps, - samples_id=[str(i) for i in range(len(test_sentences))], - pred_audio=results["audio"].float(), - pred_audio_tf=None, - pre_audio_trimmed=None, - reference_audio=speaker_audio.float(), - target_audio=None, - pred_audio_sr=self.target_sample_rate, - user_audio=None, - user_audio_sr=None, - eou_pred=None, - fps=self.target_fps, - results=None, - tokenizer=self.tokenizer, - ) - - else: - for name, dataset_batch in batch.items(): - if dataset_batch is None: - continue # some dataset is exhausted - - B = len(dataset_batch['sample_id']) - - # run inference for multiples references - if self.cfg.get("inference_speaker_reference_path", None): - for inference_speaker_reference in glob.glob( - os.path.join(self.cfg.inference_speaker_reference_path, "**"), recursive=True - ): - if not os.path.isfile(inference_speaker_reference): - continue - - new_dataset_batch = copy.deepcopy(dataset_batch) - # Get only the file name - ref_name = os.path.basename(inference_speaker_reference).split(".")[0] - # Append to each sample_id - new_dataset_batch['sample_id'] = [f"{sid}_{ref_name}" for sid in dataset_batch['sample_id']] - speaker_audio, sr = load_audio_librosa(inference_speaker_reference) - speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) - speaker_audio = speaker_audio.repeat(B, 1).to(self.device) - # lengths -> [B] - speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) - new_dataset_batch["speaker_reference_audio"] = speaker_audio - new_dataset_batch["speaker_reference_audio_lens"] = speaker_audio_lens - self.run_evaluation_one_batch(name, new_dataset_batch, use_dataloader_init=False) - - # run inference for a custom speaker reference - elif self.cfg.get("inference_speaker_reference", None): - new_dataset_batch = copy.deepcopy(dataset_batch) - speaker_audio, sr = load_audio_librosa(self.cfg.inference_speaker_reference) - speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) - speaker_audio = speaker_audio.repeat(B, 1).to(self.device) - # lengths -> [B] - speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) - new_dataset_batch["speaker_reference_audio"] = speaker_audio - new_dataset_batch["speaker_reference_audio_lens"] = speaker_audio_lens - self.run_evaluation_one_batch(name, new_dataset_batch, use_dataloader_init=False) - - # run inference using dataloader speaker references - else: - self.run_evaluation_one_batch(name, dataset_batch, use_dataloader_init=False) + for name, dataset_batch in batch.items(): + if dataset_batch is None: + continue # some dataset is exhausted + + B = len(dataset_batch['sample_id']) + + # run inference for a custom speaker reference + if self.cfg.get("inference_speaker_reference", None): + new_dataset_batch = copy.deepcopy(dataset_batch) + speaker_audio, sr = load_audio_librosa(self.cfg.inference_speaker_reference) + speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) + speaker_audio = speaker_audio.repeat(B, 
1).to(self.device) + # lengths -> [B] + speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) + new_dataset_batch["speaker_reference_audio"] = speaker_audio + new_dataset_batch["speaker_reference_audio_lens"] = speaker_audio_lens + self.run_evaluation_one_batch(name, new_dataset_batch) + + # run inference using dataloader speaker references + else: + self.run_evaluation_one_batch(name, dataset_batch, use_dataloader_init=False) def on_test_epoch_start(self) -> None: return self.on_validation_epoch_start() From e12b192c91816a26608b314d3c2ee88a46abb89b Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 25 Nov 2025 09:06:19 -0800 Subject: [PATCH 058/102] Remove .nemo file loading support Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 05efc342988a..f4dc158cce1a 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -13,7 +13,6 @@ # limitations under the License. import copy import os -import tempfile import time from collections import Counter from contextlib import contextmanager @@ -37,7 +36,6 @@ ) from nemo.collections.audio.parts.utils.resampling import resample -from nemo.collections.common.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.speechlm2.data.utils import get_pad_id from nemo.collections.speechlm2.modules.ear_tts_model import RVQEARTTSModel @@ -437,14 +435,7 @@ def restore_from_pretrained_checkpoint(self, checkpoint_path): None. The model is updated in-place. 
""" if checkpoint_path is not None: - if '.nemo' in checkpoint_path: - with tempfile.TemporaryDirectory() as tmpdir: - NLPSaveRestoreConnector._unpack_nemo_file(checkpoint_path, tmpdir) - checkpoint_path = f"{tmpdir}/model_weights.ckpt" - checkpoint_state = torch.load(checkpoint_path, map_location='cpu') - else: - checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location='cpu')['state_dict'] - + checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location='cpu')['state_dict'] checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict()) if self.cfg.get("rescale_pretrained_weights", None): From 1cfd4f3c1fc78b31a4ccb72b30c1b177d507f431 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 25 Nov 2025 09:31:56 -0800 Subject: [PATCH 059/102] Rename speaker_reference with audio_prompt Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 58 ++++++------- .../speechlm2/models/duplex_ear_tts.py | 85 +++---------------- .../speechlm2/test_duplex_eartts.py | 8 +- 3 files changed, 43 insertions(+), 108 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 6cc11271c6ca..e29021d1345b 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -187,8 +187,8 @@ class DuplexEARTTSDataset(torch.utils.data.Dataset): - target_texts: List of full target texts joined from output_roles supervisions [B] - - speaker_reference_audio: Tensor of optional speaker reference waveform samples [B, T] - - speaker_reference_audio_lens: Tensor of speaker reference audio lengths [B] + - audio_prompt: Tensor of optional speaker reference waveform samples [B, T] + - audio_prompt_lens: Tensor of speaker reference audio lengths [B] - formatter: List indicating the formatter to use for each cut (default "s2s_duplex") [B] @@ -255,19 +255,19 @@ def __getitem__(self, cuts: CutSet) -> dict: # if context audio is available use it, otherwise use a random agent turn as speaker reference if hasattr(cuts[0], "context_audio"): - speaker_reference_audio = [] - speaker_reference_audio_lens = [] + audio_prompt = [] + audio_prompt_lens = [] for cut in cuts: ref_audio = torch.tensor(cut.context_audio.resample(self.target_sample_rate).load_audio()).float() ref_audio_len = torch.tensor(ref_audio.shape[1]).long() - speaker_reference_audio.append(ref_audio.squeeze(0)) - speaker_reference_audio_lens.append(ref_audio_len) + audio_prompt.append(ref_audio.squeeze(0)) + audio_prompt_lens.append(ref_audio_len) - speaker_reference_audio = collate_vectors(speaker_reference_audio, padding_value=0).float() - speaker_reference_audio_lens = torch.tensor(speaker_reference_audio_lens).long() + audio_prompt = collate_vectors(audio_prompt, padding_value=0).float() + audio_prompt_lens = torch.tensor(audio_prompt_lens).long() else: # extract target speaker reference from a random audio audio - speaker_reference_audio, speaker_reference_audio_lens = collate_random_turn_audio( + audio_prompt, audio_prompt_lens = collate_random_turn_audio( cuts.resample(self.target_sample_rate), roles=self.output_roles, recording_field="target_audio" ) @@ -302,7 +302,7 @@ def __getitem__(self, cuts: CutSet) -> dict: # for each sample in the batch for i in range(input_text_tokens.size(0)): # create a eos token tensor - initial_text_frame_id = torch.tensor( + first_text_frame = torch.tensor( [self.tokenizer.eos], 
dtype=torch.long, device=input_text_tokens[i].device ) @@ -312,7 +312,7 @@ def __getitem__(self, cuts: CutSet) -> dict: * target_samples_per_frame ) prompt_audio = sample_audio_segments_repeat( - speaker_reference_audio, speaker_reference_audio_lens, prompt_audio_size, sample=True + audio_prompt, audio_prompt_lens, prompt_audio_size, sample=True ) # add a silence in the end to smooth the transition between prompt and audio tokens, keep one extra pad token due shift on subword_ids prompt_audio[:, -int(target_samples_per_frame * 2) :] = 0 @@ -332,54 +332,54 @@ def __getitem__(self, cuts: CutSet) -> dict: # the number of audio-prompt frames (in text-token units) and input_text_tokens new_input_text_tokens = torch.cat( [ - initial_text_frame_id.to(input_text_tokens.dtype), + first_text_frame.to(input_text_tokens.dtype), prompt_audio_text_pad.to(input_text_tokens.dtype), input_text_tokens[i], ] ) # append to list and update lens input_text_tokens_.append(new_input_text_tokens) - target_token_lens[i] = target_token_lens[i] + len(initial_text_frame_id) + prompt_audio_text_pad_size + target_token_lens[i] = target_token_lens[i] + len(first_text_frame) + prompt_audio_text_pad_size # add description to source text tokens - source_tokens_.append(torch.cat([initial_text_frame_id, prompt_audio_text_pad, source_tokens[i]])) - source_token_lens[i] = source_token_lens[i] + len(initial_text_frame_id) + prompt_audio_text_pad_size + source_tokens_.append(torch.cat([first_text_frame, prompt_audio_text_pad, source_tokens[i]])) + source_token_lens[i] = source_token_lens[i] + len(first_text_frame) + prompt_audio_text_pad_size # add silence in the source audio while the prompt is being processed - pad_size = (len(initial_text_frame_id) * source_samples_per_frame) + prompt_audio.size(1) + pad_size = (len(first_text_frame) * source_samples_per_frame) + prompt_audio.size(1) pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) source_audio_.append(torch.cat([pad_audio, source_audio[i]])) source_audio_lens[i] = source_audio_lens[i] + pad_size # add silence in the target audio while the prompt is being processed - pad_size = len(initial_text_frame_id) * target_samples_per_frame + pad_size = len(first_text_frame) * target_samples_per_frame pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) target_audio_.append(torch.cat([pad_audio, prompt_audio[i], target_audio[i]])) target_audio_lens[i] = target_audio_lens[i] + pad_size + prompt_audio.size(1) # desc duration - desc_lens.append(len(initial_text_frame_id)) + desc_lens.append(len(first_text_frame)) desc_plus_audio_prompt_lens.append( - len(initial_text_frame_id) + prompt_audio_text_pad_size - 1 + len(first_text_frame) + prompt_audio_text_pad_size - 1 ) # -1 due the shift done in subword_ids else: # add description to target text tokens - input_text_tokens_.append(torch.cat([initial_text_frame_id, input_text_tokens[i]])) - target_token_lens[i] = target_token_lens[i] + len(initial_text_frame_id) + input_text_tokens_.append(torch.cat([first_text_frame, input_text_tokens[i]])) + target_token_lens[i] = target_token_lens[i] + len(first_text_frame) # add description to source text tokens - source_tokens_.append(torch.cat([initial_text_frame_id, source_tokens[i]])) - source_token_lens[i] = source_token_lens[i] + len(initial_text_frame_id) + source_tokens_.append(torch.cat([first_text_frame, source_tokens[i]])) + source_token_lens[i] = source_token_lens[i] + len(first_text_frame) # add silence in the source 
audio while the prompt is being processed - pad_size = len(initial_text_frame_id) * source_samples_per_frame + pad_size = len(first_text_frame) * source_samples_per_frame pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) source_audio_.append(torch.cat([pad_audio, source_audio[i]])) source_audio_lens[i] = source_audio_lens[i] + pad_size # add silence in the target audio while the prompt is being processed - pad_size = len(initial_text_frame_id) * target_samples_per_frame + pad_size = len(first_text_frame) * target_samples_per_frame pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) target_audio_.append(torch.cat([pad_audio, target_audio[i]])) target_audio_lens[i] = target_audio_lens[i] + pad_size # des duration - desc_lens.append(len(initial_text_frame_id)) - desc_plus_audio_prompt_lens.append(len(initial_text_frame_id)) + desc_lens.append(len(first_text_frame)) + desc_plus_audio_prompt_lens.append(len(first_text_frame)) # collate tensors input_text_tokens = collate_vectors(input_text_tokens_, padding_value=text_pad_id) @@ -451,8 +451,8 @@ def __getitem__(self, cuts: CutSet) -> dict: "target_texts": [ " ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts ], - "speaker_reference_audio": speaker_reference_audio, - "speaker_reference_audio_lens": speaker_reference_audio_lens, + "audio_prompt": audio_prompt, + "audio_prompt_lens": audio_prompt_lens, "formatter": [getattr(cut, "formatter", "s2s_duplex") for cut in cuts], } diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index f4dc158cce1a..2f1e3281bddb 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -330,18 +330,10 @@ def __init__(self, cfg: dict) -> None: self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True ) # Note that we are using fast tokenizer - if 'Qwen2.5' in self.cfg.pretrained_lm_name: - # For Qwen, '<|im_start|>' is a common choice for a BOS token. - # You can check your tokenizer's vocabulary for the best candidate. - logging.warning("Tokenizer does not have a `bos_token`. 
Setting it to '<|im_start|>'.") - self.tokenizer.bos_token = '<|im_start|>' - self.tokenizer.eos_token = '<|im_end|>' - - elif 'Nemotron' in self.cfg.pretrained_lm_name: - # ====== NEMOTRON-SPECIFIC HANDLING ====== - self.tokenizer.bos_token = '' - self.tokenizer.eos_token = '' - self.tokenizer.pad_token = '' + # set tokenizer special tokens + self.tokenizer.bos_token = self.cfg.get("bos_token", '') + self.tokenizer.eos_token = self.cfg.get("eos_token", '') + self.tokenizer.pad_token = self.cfg.get("pad_token", '') # cached for quicker audio decoding self.register_buffer( @@ -538,7 +530,7 @@ def prepare_inputs(self, batch: dict): """ # check if audios has the same batch size assert batch["source_audio"].size(0) == batch["target_audio"].size(0) - assert batch["speaker_reference_audio"].size(0) == batch["target_audio"].size(0) + assert batch["audio_prompt"].size(0) == batch["target_audio"].size(0) target_audio = batch["target_audio"] target_audio_lens = batch["target_audio_lens"] @@ -787,63 +779,6 @@ def _get_generation_config(self, guidance_enabled: bool = False): "eos_threshold": -3.0, } - def offline_inference_with_custom_sentences( - self, test_sentences: torch.Tensor, inference_speaker_reference: torch.Tensor, speech_text_ratio: float = 3.5 - ): - # ToDo: split it in multiples batches to support long list of sentences - B = len(test_sentences) - # load and get speaker reference - speaker_audio, sr = load_audio_librosa(inference_speaker_reference) - speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) - speaker_audio = speaker_audio.repeat(B, 1).to(self.device) - # lengths -> [B] - speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) - - # Tokenize sentences - tokenized = [ - torch.as_tensor( - [self.tokenizer.bos] + self.tokenizer.text_to_ids(text), dtype=torch.long, device=self.device - ) - for text in test_sentences - ] - - # Get max length and target length - max_len = max(len(t) for t in tokenized) - # Pad each to double length - target_len = int( - speech_text_ratio * max_len - ) # make text longer to ensures that we have enough steps for speech gen - next_subword_ids = torch.stack( - [ - torch.cat( - [ - torch.tensor( - [self.text_pad_id], dtype=torch.long, device=self.device - ), # shift right adding one padding token - t, - torch.full( - (target_len - len(t) - 1,), self.text_pad_id, dtype=torch.long, device=self.device - ), # remaining padding - ] - ) - for t in tokenized - ] - ) - - # set init inputs and get it - self.set_init_inputs( - speaker_audio=speaker_audio, - speaker_audio_lens=speaker_audio_lens, - ) - init_inputs = self.get_init_inputs(B=next_subword_ids.size(0)) - - audio, audio_len = self.offline_inference( - next_subword_ids=next_subword_ids, - guidance_enabled=self.cfg.get("inference_guidance_enabled", True), - init_inputs=init_inputs, - ) - return audio, audio_len, speaker_audio, speaker_audio_lens - def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=False): """ Runs evaluation and scoring for a single data batch, logging metrics and updating result buffers. 
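The audio-prompt handling touched by this patch keeps the text and audio streams frame-aligned by inserting one text PAD token per prompt audio frame. As a purely illustrative walk-through of that arithmetic (the 3 s prompt, 22.05 kHz rate, and 80 ms frame length are assumed values, not figures taken from the configs):

# Assumed values for illustration only; the real figures come from the data/model configs.
audio_prompt_duration = 3.0
target_sample_rate = 22050
target_samples_per_frame = int(target_sample_rate * 0.08)  # 1764 samples per 80 ms frame

# Round the prompt down to a whole number of frames, as the dataset code does.
prompt_audio_size = int(
    ((audio_prompt_duration * target_sample_rate) // target_samples_per_frame) * target_samples_per_frame
)
# One text PAD token per prompt frame keeps the text stream aligned with the prompt audio.
prompt_audio_text_pad_size = prompt_audio_size // target_samples_per_frame

print(prompt_audio_size, prompt_audio_text_pad_size)  # 65268 samples -> 37 pad tokens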
@@ -882,8 +817,8 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals else: # set init inputs and get it self.set_init_inputs( - speaker_audio=dataset_batch["speaker_reference_audio"], - speaker_audio_lens=dataset_batch["speaker_reference_audio_lens"], + speaker_audio=dataset_batch["audio_prompt"], + speaker_audio_lens=dataset_batch["audio_prompt_lens"], ) init_inputs = self.get_init_inputs(B=inputs["subword_ids"].size(0)) @@ -996,7 +931,7 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals pred_audio=results["audio"].float(), pred_audio_tf=results["audio_tf"].float(), pre_audio_trimmed=None, - reference_audio=dataset_batch["speaker_reference_audio"].float(), + reference_audio=dataset_batch["audio_prompt"].float(), target_audio=target_audio_no_prompt.float(), pred_audio_sr=self.target_sample_rate, user_audio=dataset_batch["source_audio"].float(), @@ -1022,8 +957,8 @@ def validation_step(self, batch: dict, batch_idx: int): speaker_audio = speaker_audio.repeat(B, 1).to(self.device) # lengths -> [B] speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) - new_dataset_batch["speaker_reference_audio"] = speaker_audio - new_dataset_batch["speaker_reference_audio_lens"] = speaker_audio_lens + new_dataset_batch["audio_prompt"] = speaker_audio + new_dataset_batch["audio_prompt_lens"] = speaker_audio_lens self.run_evaluation_one_batch(name, new_dataset_batch) # run inference using dataloader speaker references diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 364ee0c6edb2..e1b9fd1b7fa6 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -265,8 +265,8 @@ def test_eartts_dataset(dataset, training_cutset_batch): "source_tokens", "source_token_lens", "target_texts", - "speaker_reference_audio", - "speaker_reference_audio_lens", + "audio_prompt", + "audio_prompt_lens", "formatter", } @@ -286,8 +286,8 @@ def test_eartts_dataset(dataset, training_cutset_batch): "target_token_lens", "source_tokens", "source_token_lens", - "speaker_reference_audio", - "speaker_reference_audio_lens", + "audio_prompt", + "audio_prompt_lens", ] for key in tensor_keys: From 2caec09ad201092b07b89f2fa1a52954b586f9e8 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 25 Nov 2025 12:58:07 -0800 Subject: [PATCH 060/102] Modularize dataloader get_item Signed-off-by: Edresson Casanova --- examples/speechlm2/duplex_eartts_infer.py | 2 +- examples/speechlm2/duplex_eartts_train.py | 2 +- .../speechlm2/data/duplex_ear_tts_dataset.py | 524 +++++++++++++----- .../speechlm2/models/duplex_ear_tts.py | 32 +- .../speechlm2/test_duplex_eartts.py | 9 +- 5 files changed, 392 insertions(+), 177 deletions(-) diff --git a/examples/speechlm2/duplex_eartts_infer.py b/examples/speechlm2/duplex_eartts_infer.py index 8118e086950c..255336849817 100644 --- a/examples/speechlm2/duplex_eartts_infer.py +++ b/examples/speechlm2/duplex_eartts_infer.py @@ -50,7 +50,7 @@ def inference(cfg): input_roles=cfg.data.input_roles, output_roles=cfg.data.output_roles, add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), - add_audio_prompt_after_description=cfg.data.add_audio_prompt_after_description, + add_audio_prompt=cfg.data.get("add_audio_prompt", True), audio_prompt_duration=cfg.data.audio_prompt_duration, num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), ) diff 
--git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py index c3ddcc072589..aaccfe6dcb72 100644 --- a/examples/speechlm2/duplex_eartts_train.py +++ b/examples/speechlm2/duplex_eartts_train.py @@ -48,7 +48,7 @@ def train(cfg): input_roles=cfg.data.input_roles, output_roles=cfg.data.output_roles, add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), - add_audio_prompt_after_description=cfg.data.add_audio_prompt_after_description, + add_audio_prompt=cfg.data.get("add_audio_prompt", True), audio_prompt_duration=cfg.data.audio_prompt_duration, num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), ) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index e29021d1345b..1e8aa6819090 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -147,13 +147,13 @@ class DuplexEARTTSDataset(torch.utils.data.Dataset): with BOS and EOS tokens in the text stream. Default: `True`. - add_audio_prompt_after_description (bool, optional): + add_audio_prompt (bool, optional): If True, an optional audio/speaker prompt is appended after the description. Default: `True`. audio_prompt_duration (float, optional): Duration (in seconds) of the audio prompt appended when - `add_audio_prompt_after_description=True`. Default: `3.0`. + `add_audio_prompt=True`. Default: `3.0`. num_delay_speech_tokens (int, optional): Number of PAD tokens to insert before speech tokens to artificially @@ -164,9 +164,7 @@ class DuplexEARTTSDataset(torch.utils.data.Dataset): - sample_id: List of sample IDs for each cut in the batch [B] - non_prompt_mask: Bool tensor [B, T] marking positions that are not part of the prompt - - desc_mask: Bool tensor [B, T] marking positions belonging to the description - - desc_lens: Tensor of description text lengths [B] - - desc_plus_audio_prompt_lens: Tensor of description + audio prompt lengths [B] + - prompt_lens: Tensor of description + audio prompt lengths [B] - aligned_attention_mask: Bool tensor [B, T] used by alignment-aware transformer models - aligned_position_ids: Tensor of position indices aligned to audio frames [B, T] @@ -213,7 +211,7 @@ def __init__( output_roles: list[str] = None, p_drop_description: float = 0.0, add_text_bos_and_eos_in_each_turn: bool = True, - add_audio_prompt_after_description: bool = True, + add_audio_prompt: bool = True, audio_prompt_duration: float = 3.0, num_delay_speech_tokens: int = 0, ): @@ -225,10 +223,16 @@ def __init__( self.output_roles = set(ifnone(output_roles, ["agent"])) self.p_drop_description = p_drop_description self.add_text_bos_and_eos_in_each_turn = add_text_bos_and_eos_in_each_turn - self.add_audio_prompt_after_description = add_audio_prompt_after_description + self.add_audio_prompt = add_audio_prompt self.audio_prompt_duration = audio_prompt_duration self.num_delay_speech_tokens = num_delay_speech_tokens + # compute source and target samples_per_frame + source_fps = self.source_sample_rate / (self.source_sample_rate * self.frame_length) + self.source_samples_per_frame = int(self.source_sample_rate // source_fps) + target_fps = self.target_sample_rate / (self.target_sample_rate * self.frame_length) + self.target_samples_per_frame = int(self.target_sample_rate // target_fps) + assert tokenizer.bos is not None, "BOS support in the tokenizer is required for S2S models." 
assert tokenizer.eos is not None, "EOS support in the tokenizer is required for S2S models." @@ -253,152 +257,49 @@ def __getitem__(self, cuts: CutSet) -> dict: add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, ) - # if context audio is available use it, otherwise use a random agent turn as speaker reference - if hasattr(cuts[0], "context_audio"): - audio_prompt = [] - audio_prompt_lens = [] - for cut in cuts: - ref_audio = torch.tensor(cut.context_audio.resample(self.target_sample_rate).load_audio()).float() - ref_audio_len = torch.tensor(ref_audio.shape[1]).long() - audio_prompt.append(ref_audio.squeeze(0)) - audio_prompt_lens.append(ref_audio_len) - - audio_prompt = collate_vectors(audio_prompt, padding_value=0).float() - audio_prompt_lens = torch.tensor(audio_prompt_lens).long() - else: - # extract target speaker reference from a random audio audio - audio_prompt, audio_prompt_lens = collate_random_turn_audio( - cuts.resample(self.target_sample_rate), roles=self.output_roles, recording_field="target_audio" - ) + audio_prompt, audio_prompt_lens = get_audio_prompt(cuts, self.target_sample_rate, roles=self.output_roles, recording_field="target_audio") # ensures that input_text_tokens is not longer than its duration input_text_tokens = input_text_tokens[:, : target_token_lens.max()] - source_fps = self.source_sample_rate / (self.source_sample_rate * self.frame_length) - source_samples_per_frame = int(self.source_sample_rate // source_fps) - target_fps = self.target_sample_rate / (self.target_sample_rate * self.frame_length) - target_samples_per_frame = int(self.target_sample_rate // target_fps) - - # one is default and we add BOS on speech channel to ensures it, inside of the model class, so if we want bigger than that we can add padding in the audio here + # add speech channel delay if needed if self.num_delay_speech_tokens: - # compute the padding need in target audio for the number of delay tokens - extra_frames = int(self.num_delay_speech_tokens * target_samples_per_frame) - # left pad target audio to create the delay and make the model to predict silence while consuming self.num_delay_speech_tokens text tokens - target_audio = F.pad(target_audio, (extra_frames, 0)) - target_audio_lens = target_audio_lens + extra_frames - - # right pad the source audio to avoid size mismatch - extra_frames = int(self.num_delay_speech_tokens * source_samples_per_frame) - source_audio = F.pad(source_audio, (0, extra_frames)) - source_audio_lens = source_audio_lens + extra_frames - - text_pad_id = get_pad_id(self.tokenizer) - input_text_tokens_ = [] - source_tokens_ = [] - source_audio_ = [] - target_audio_ = [] - desc_lens = [] - desc_plus_audio_prompt_lens = [] - # for each sample in the batch - for i in range(input_text_tokens.size(0)): - # create a eos token tensor - first_text_frame = torch.tensor( - [self.tokenizer.eos], dtype=torch.long, device=input_text_tokens[i].device + source_audio, source_audio_lens, target_audio, target_audio_lens = add_speech_delay( + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + self.num_delay_speech_tokens, + self.target_samples_per_frame, + self.source_samples_per_frame, ) - if self.add_audio_prompt_after_description: - prompt_audio_size = int( - ((self.audio_prompt_duration * self.target_sample_rate) // target_samples_per_frame) - * target_samples_per_frame - ) - prompt_audio = sample_audio_segments_repeat( - audio_prompt, audio_prompt_lens, prompt_audio_size, sample=True - ) - # add a silence in the end to smooth the 
transition between prompt and audio tokens, keep one extra pad token due shift on subword_ids - prompt_audio[:, -int(target_samples_per_frame * 2) :] = 0 - - # create tensor to pad text channels with the same amount of frames added in audio channel (audio prompt) - prompt_audio_text_pad_size = prompt_audio_size // target_samples_per_frame - prompt_audio_text_pad = ( - torch.ones( - prompt_audio_text_pad_size, device=input_text_tokens.device, dtype=input_text_tokens.dtype - ) - * text_pad_id - ) - # set last prompt frame with eos in text channel - prompt_audio_text_pad[-1] = self.tokenizer.eos - - # Prepend an initial text EOS token followed by padding tokens that match - # the number of audio-prompt frames (in text-token units) and input_text_tokens - new_input_text_tokens = torch.cat( - [ - first_text_frame.to(input_text_tokens.dtype), - prompt_audio_text_pad.to(input_text_tokens.dtype), - input_text_tokens[i], - ] - ) - # append to list and update lens - input_text_tokens_.append(new_input_text_tokens) - target_token_lens[i] = target_token_lens[i] + len(first_text_frame) + prompt_audio_text_pad_size - - # add description to source text tokens - source_tokens_.append(torch.cat([first_text_frame, prompt_audio_text_pad, source_tokens[i]])) - source_token_lens[i] = source_token_lens[i] + len(first_text_frame) + prompt_audio_text_pad_size - # add silence in the source audio while the prompt is being processed - pad_size = (len(first_text_frame) * source_samples_per_frame) + prompt_audio.size(1) - pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) - source_audio_.append(torch.cat([pad_audio, source_audio[i]])) - source_audio_lens[i] = source_audio_lens[i] + pad_size - # add silence in the target audio while the prompt is being processed - pad_size = len(first_text_frame) * target_samples_per_frame - pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) - target_audio_.append(torch.cat([pad_audio, prompt_audio[i], target_audio[i]])) - target_audio_lens[i] = target_audio_lens[i] + pad_size + prompt_audio.size(1) - # desc duration - desc_lens.append(len(first_text_frame)) - desc_plus_audio_prompt_lens.append( - len(first_text_frame) + prompt_audio_text_pad_size - 1 - ) # -1 due the shift done in subword_ids - else: - # add description to target text tokens - input_text_tokens_.append(torch.cat([first_text_frame, input_text_tokens[i]])) - target_token_lens[i] = target_token_lens[i] + len(first_text_frame) - # add description to source text tokens - source_tokens_.append(torch.cat([first_text_frame, source_tokens[i]])) - source_token_lens[i] = source_token_lens[i] + len(first_text_frame) - # add silence in the source audio while the prompt is being processed - pad_size = len(first_text_frame) * source_samples_per_frame - pad_audio = torch.zeros(pad_size, device=source_audio.device, dtype=source_audio.dtype) - source_audio_.append(torch.cat([pad_audio, source_audio[i]])) - source_audio_lens[i] = source_audio_lens[i] + pad_size - # add silence in the target audio while the prompt is being processed - pad_size = len(first_text_frame) * target_samples_per_frame - pad_audio = torch.zeros(pad_size, device=target_audio.device, dtype=target_audio.dtype) - target_audio_.append(torch.cat([pad_audio, target_audio[i]])) - target_audio_lens[i] = target_audio_lens[i] + pad_size - - # des duration - desc_lens.append(len(first_text_frame)) - desc_plus_audio_prompt_lens.append(len(first_text_frame)) - - # collate tensors - input_text_tokens = 
collate_vectors(input_text_tokens_, padding_value=text_pad_id) - source_tokens = collate_vectors(source_tokens_, padding_value=text_pad_id) - source_audio = collate_vectors(source_audio_, padding_value=0) - target_audio = collate_vectors(target_audio_, padding_value=0) - - # recreate audio mask - non_desc_mask = get_mask_from_lengths(target_token_lens) - # ignore desc len in audio mask - for i, frame in enumerate(desc_lens): - non_desc_mask[i, :frame] = 0.0 - - # desc mask is totally the oposite of audio mask - desc_mask = ~non_desc_mask + ( + input_text_tokens, + target_token_lens, + source_tokens, + source_token_lens, + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + prompt_lens, + ) = self.maybe_add_audio_prompt( + input_text_tokens, + target_token_lens, + source_tokens, + source_token_lens, + target_audio, + target_audio_lens, + source_audio, + source_audio_lens, + audio_prompt, + audio_prompt_lens, + ) # create non_prompt_mask that should mask desc plus audio prompt if used non_prompt_mask = get_mask_from_lengths(target_token_lens) - for i, frame in enumerate(desc_plus_audio_prompt_lens): + for i, frame in enumerate(prompt_lens): non_prompt_mask[i, : frame - 1] = 0.0 max_len = max(target_token_lens) @@ -435,9 +336,7 @@ def __getitem__(self, cuts: CutSet) -> dict: return { "sample_id": [str(cut.id) for cut in cuts], "non_prompt_mask": non_prompt_mask.bool(), - "desc_mask": desc_mask.bool(), - "desc_lens": desc_lens, - "desc_plus_audio_prompt_lens": desc_plus_audio_prompt_lens, + "prompt_lens": prompt_lens, "aligned_attention_mask": aligned_attention_mask.bool(), "aligned_position_ids": aligned_position_ids, "source_audio": source_audio, @@ -456,6 +355,339 @@ def __getitem__(self, cuts: CutSet) -> dict: "formatter": [getattr(cut, "formatter", "s2s_duplex") for cut in cuts], } + def maybe_add_audio_prompt( + self, + input_text_tokens: torch.Tensor, + target_token_lens: torch.Tensor, + source_tokens: torch.Tensor, + source_token_lens: torch.Tensor, + target_audio: torch.Tensor, + target_audio_lens: torch.Tensor, + source_audio: torch.Tensor, + source_audio_lens: torch.Tensor, + audio_prompt: torch.Tensor, + audio_prompt_lens: torch.Tensor, + ): + """ + Prepend an audio-based speaker prompt and aligned text tokens to the Duplex S2S inputs. + + This method optionally injects a speaker-reference audio prompt at the beginning of each + sample in the batch. The prompt is inserted in the target-audio channel and aligned text + padding is inserted into the text-token streams (input text tokens and source tokens). + The corresponding silence padding is inserted into the source/target audio streams to + preserve frame-level alignment. + + Args: + input_text_tokens (torch.Tensor): + Tensor of input text tokens with shape [B, T_text]. + dtype: torch.long. + + target_token_lens (torch.Tensor): + Lengths of input_text_tokens per batch element (before padding). shape [B]. + + source_tokens (torch.Tensor): + Source-side text tokens, shape [B, T_src_text], dtype torch.long. + + source_token_lens (torch.Tensor): + Source text token lengths per batch element, shape [B]. + + target_audio (torch.Tensor): + Target-side audio waveforms, shape [B, T_audio], dtype torch.float32/float16. + + target_audio_lens (torch.Tensor): + Target audio lengths per batch element, shape [B]. + + source_audio (torch.Tensor): + Source-side audio waveforms, shape [B, T_audio], dtype torch.float32/float16. + + source_audio_lens (torch.Tensor): + Source audio lengths per batch element, shape [B]. 
+ + audio_prompt (torch.Tensor): + Audio prompt waveforms to sample from, shape [B, T_prompt_audio]. + + audio_prompt_lens (torch.Tensor): + Valid lengths for audio_prompt per batch element. + + Returns: + Tuple containing: + input_text_tokens (torch.Tensor): + Updated text tokens with prepended prompt-aligned tokens. Shape [B, T']. + + target_token_lens (torch.Tensor): + Updated token lengths per batch element. + + source_tokens (torch.Tensor): + Updated source text tokens with prompt padding included. Shape [B, T']. + + source_token_lens (torch.Tensor): + Updated source token lengths per batch element. + + source_audio (torch.Tensor): + Updated source audio with silence padding. Shape [B, T_audio']. + + source_audio_lens (torch.Tensor): + Updated source audio lengths. + + target_audio (torch.Tensor): + Updated target audio with prompt audio and silence padding. Shape [B, T_audio']. + + target_audio_lens (torch.Tensor): + Updated target audio lengths. + + prompt_lens (list[int]): + Length (in text-token units) of the prompt region per batch item. + """ + + text_pad_id = get_pad_id(self.tokenizer) + + input_text_tokens_ = [] + source_tokens_ = [] + source_audio_ = [] + target_audio_ = [] + prompt_lens = [] + + for i in range(input_text_tokens.size(0)): + first_text_frame = torch.tensor( + [self.tokenizer.eos], + dtype=torch.long, + device=input_text_tokens.device, + ) + + if self.add_audio_prompt: + # Compute audio prompt duration in samples (rounded to frame boundaries) + prompt_audio_size = int( + ((self.audio_prompt_duration * self.target_sample_rate) // + self.target_samples_per_frame) * self.target_samples_per_frame + ) + + prompt_audio = sample_audio_segments_repeat( + audio_prompt, audio_prompt_lens, prompt_audio_size, sample=True + ) + + # Fade transition: add silence at the end of the prompt + prompt_audio[:, -int(self.target_samples_per_frame * 2):] = 0 + + # Number of text tokens to insert to align with prompt_audio frames + prompt_audio_text_pad_size = prompt_audio_size // self.target_samples_per_frame + + prompt_audio_text_pad = ( + torch.ones( + prompt_audio_text_pad_size, + device=input_text_tokens.device, + dtype=input_text_tokens.dtype, + ) * text_pad_id + ) + prompt_audio_text_pad[-1] = self.tokenizer.eos + + new_input_text_tokens = torch.cat( + [ + first_text_frame.to(input_text_tokens.dtype), + prompt_audio_text_pad, + input_text_tokens[i], + ] + ) + input_text_tokens_.append(new_input_text_tokens) + target_token_lens[i] += len(first_text_frame) + prompt_audio_text_pad_size + + new_source_tokens = torch.cat([first_text_frame, prompt_audio_text_pad, source_tokens[i]]) + source_tokens_.append(new_source_tokens) + source_token_lens[i] += len(first_text_frame) + prompt_audio_text_pad_size + + # Silence in source audio during prompt processing + pad_size_src = (len(first_text_frame) * self.source_samples_per_frame) + prompt_audio.size(1) + pad_audio_src = torch.zeros( + pad_size_src, + device=source_audio.device, + dtype=source_audio.dtype, + ) + source_audio_.append(torch.cat([pad_audio_src, source_audio[i]])) + source_audio_lens[i] += pad_size_src + + # Insert prompt audio in the target channel + pad_size_tgt = len(first_text_frame) * self.target_samples_per_frame + pad_audio_tgt = torch.zeros( + pad_size_tgt, + device=target_audio.device, + dtype=target_audio.dtype, + ) + target_audio_.append(torch.cat([pad_audio_tgt, prompt_audio[i], target_audio[i]])) + target_audio_lens[i] += pad_size_tgt + prompt_audio.size(1) + + prompt_lens.append(len(first_text_frame) + 
prompt_audio_text_pad_size - 1) + + else: + # Add only a single text-frame (EOS) as prompt + input_text_tokens_.append(torch.cat([first_text_frame, input_text_tokens[i]])) + target_token_lens[i] += len(first_text_frame) + + source_tokens_.append(torch.cat([first_text_frame, source_tokens[i]])) + source_token_lens[i] += len(first_text_frame) + + pad_size_src = len(first_text_frame) * self.source_samples_per_frame + pad_audio_src = torch.zeros( + pad_size_src, + device=source_audio.device, + dtype=source_audio.dtype, + ) + source_audio_.append(torch.cat([pad_audio_src, source_audio[i]])) + source_audio_lens[i] += pad_size_src + + pad_size_tgt = len(first_text_frame) * self.target_samples_per_frame + pad_audio_tgt = torch.zeros( + pad_size_tgt, + device=target_audio.device, + dtype=target_audio.dtype, + ) + target_audio_.append(torch.cat([pad_audio_tgt, target_audio[i]])) + target_audio_lens[i] += pad_size_tgt + + prompt_lens.append(len(first_text_frame)) + + input_text_tokens = collate_vectors(input_text_tokens_, padding_value=text_pad_id) + source_tokens = collate_vectors(source_tokens_, padding_value=text_pad_id) + source_audio = collate_vectors(source_audio_, padding_value=0) + target_audio = collate_vectors(target_audio_, padding_value=0) + + return ( + input_text_tokens, + target_token_lens, + source_tokens, + source_token_lens, + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + prompt_lens, + ) + + +def add_speech_delay( + source_audio: torch.Tensor, + source_audio_lens: torch.Tensor, + target_audio: torch.Tensor, + target_audio_lens: torch.Tensor, + num_delay_speech_tokens: int, + target_samples_per_frame: int, + source_samples_per_frame: int, +): + """ + Apply a speech delay by padding audio waveforms based on the number of delay speech tokens. + + Behavior: + - Target audio is *left padded* to force the model to predict initial silence. + - Source audio is *right padded* to maintain size consistency for attention alignment. + + Args: + source_audio (FloatTensor [B, T_src]): + Source/input audio waveforms. + + source_audio_lens (LongTensor [B]): + Lengths of source audio in samples. + + target_audio (FloatTensor [B, T_tgt]): + Target/output audio waveforms. + + target_audio_lens (LongTensor [B]): + Lengths of target audio in samples. + + num_delay_speech_tokens (int): + Number of delay tokens inserted on the text side. + + target_samples_per_frame (int): + Number of audio samples per frame for target audio. + + source_samples_per_frame (int): + Number of audio samples per frame for source audio. 
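+
+    Worked example (illustrative values; the 22.05 kHz rate and 0.08 s frame length are taken
+    from the example configs, not fixed by this function): with
+    ``target_samples_per_frame = int(0.08 * 22050) = 1764`` and ``num_delay_speech_tokens = 2``,
+    the target waveform is left-padded by ``2 * 1764 = 3528`` samples, while the source waveform
+    is right-padded by ``2 * source_samples_per_frame`` samples so both frame grids stay aligned.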
+ + Returns: + Tuple containing: + - source_audio (FloatTensor [B, T_src + pad]) + - source_audio_lens (LongTensor [B]) + - target_audio (FloatTensor [B, T_tgt + pad]) + - target_audio_lens (LongTensor [B]) + """ + # Compute target-side left padding for the delay + extra_target_samples = int(num_delay_speech_tokens * target_samples_per_frame) + + # Left-pad target audio: forces model to generate silence initially + target_audio = F.pad(target_audio, (extra_target_samples, 0)) + target_audio_lens = target_audio_lens + extra_target_samples + + # Compute source-side right padding to maintain alignment + extra_source_samples = int(num_delay_speech_tokens * source_samples_per_frame) + + # Right-pad source audio: avoids mismatch when aligning source/target frames + source_audio = F.pad(source_audio, (0, extra_source_samples)) + source_audio_lens = source_audio_lens + extra_source_samples + + return source_audio, source_audio_lens, target_audio, target_audio_lens + +def get_audio_prompt( + cuts: CutSet, + target_sample_rate: int, + roles: set[str], + recording_field: str = "target_audio", +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Retrieve an audio prompt for speaker conditioning. + + This function returns: + - context audio if available (per-cut), + - otherwise a random turn belonging to the desired speaker roles. + + Behavior: + 1. If `cut.context_audio` exists, use it as the reference. + 2. Otherwise, select a random audio turn from the same target-role speakers. + + Args: + cuts (CutSet): + Batch of cuts from which to extract reference audio. + + target_sample_rate (int): + Sample rate to which reference audio is resampled. + + roles (set[str]): + Set of speaker roles to sample from when selecting random turns. + + recording_field (str, optional): + Name of the audio field in the cut ("source_audio", "target_audio", etc.). + Used when sampling random reference turns. + + Returns: + (audio_prompt, audio_prompt_lens): + + - audio_prompt (FloatTensor [B, T]): + Padded batch of reference waveforms. + + - audio_prompt_lens (LongTensor [B]): + Lengths of each reference waveform before padding. 
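+
+    Example (illustrative call; the role name and sample rate are placeholders, not defaults):
+        >>> audio_prompt, audio_prompt_lens = get_audio_prompt(
+        ...     cuts, target_sample_rate=22050, roles={"agent"}, recording_field="target_audio"
+        ... )
+        >>> audio_prompt.shape   # (batch_size, max_prompt_samples)
+        >>> audio_prompt_lens    # per-item lengths before padding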
+ """ + # use provided context audio directly + if hasattr(cuts[0], "context_audio"): + audio_prompt = [] + audio_prompt_lens = [] + + for cut in cuts: + ref_audio = cut.context_audio.resample(target_sample_rate).load_audio() + ref_audio = torch.tensor(ref_audio).float() # shape: [1, T] + ref_audio_len = ref_audio.shape[1] + + audio_prompt.append(ref_audio.squeeze(0)) # [T] + audio_prompt_lens.append(ref_audio_len) + + audio_prompt = collate_vectors(audio_prompt, padding_value=0).float() + audio_prompt_lens = torch.tensor(audio_prompt_lens).long() + + else: + # sample a reference turn from the target-role speakers + audio_prompt, audio_prompt_lens = collate_random_turn_audio( + cuts.resample(target_sample_rate), + roles=roles, + recording_field=recording_field, + ) + + return audio_prompt, audio_prompt_lens def collate_random_turn_audio( cuts: CutSet, diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 2f1e3281bddb..096335c1ab6c 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -535,7 +535,6 @@ def prepare_inputs(self, batch: dict): target_audio = batch["target_audio"] target_audio_lens = batch["target_audio_lens"] input_text_tokens = batch["input_text_tokens"] - desc_mask = batch["desc_mask"] non_prompt_mask = batch["non_prompt_mask"] aligned_attention_mask = batch["aligned_attention_mask"] aligned_position_ids = batch["aligned_position_ids"] @@ -561,7 +560,6 @@ def pad_or_truncate(x, pad_value=0): return x # leave others for now input_text_tokens = pad_or_truncate(input_text_tokens, pad_value=self.text_pad_id) - desc_mask = pad_or_truncate(desc_mask, pad_value=0) non_prompt_mask = pad_or_truncate(non_prompt_mask, pad_value=0) aligned_position_ids = pad_or_truncate(aligned_position_ids, pad_value=0) @@ -575,13 +573,9 @@ def pad_or_truncate(x, pad_value=0): elif L1 > new_len or L2 > new_len: aligned_attention_mask = aligned_attention_mask[:, :, :new_len, :new_len] - # ToDo: desc_mask is one for the end of the sequence, this is what cause the artifact issue in the end, fix it. 
- # set the pad token when there is desc - target_codes_aligned = torch.where( - desc_mask.unsqueeze(-1), # (B, T, 1) for broadcasting - torch.full_like(target_codes, self.speech_pad_id), # fill with pad id - target_codes, - ) + # set the pad token for the first BOS frame + target_codes_aligned = target_codes.clone() + target_codes_aligned[:, 0] = self.speech_pad_id # set special token in the last audio prompt (it will works as a BOS token) pos = non_prompt_mask.float().argmax(dim=1) # shape: [B] @@ -811,7 +805,7 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals init_inputs[key] = torch.stack( [ init_inputs[key][i, :plen] - for i, plen in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + for i, plen in enumerate(dataset_batch["prompt_lens"]) ] ) else: @@ -826,7 +820,7 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals next_subword_ids = torch.stack( [ inputs["subword_ids"][i, plen:] # slice each element - for i, plen in enumerate(dataset_batch["desc_plus_audio_prompt_lens"]) + for i, plen in enumerate(dataset_batch["prompt_lens"]) ] ) @@ -849,7 +843,7 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals ] target_audio_no_prompt_lens = dataset_batch["target_audio_lens"] - ( torch.tensor( - dataset_batch["desc_plus_audio_prompt_lens"], + dataset_batch["prompt_lens"], dtype=torch.long, device=dataset_batch["target_audio_lens"].device, ) @@ -1077,17 +1071,9 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, subword_mask[:, -3:] = ( 1 # -3 because of the it start right after the first valid prompt token and it is shifted by 1 ) - # desc mask is all zeros except the description - desc_mask = torch.zeros_like(input_text_tokens) - desc_mask[:, : first_text_frame.size(-1)] = 1 - - if not self.cfg.get("disable_speech_pad", False): - # add special tokens on audio codes - code = torch.where( - desc_mask.unsqueeze(-1).bool(), # (B, T, 1) for broadcasting - torch.full_like(code, self.speech_pad_id), # fill with pad id - code, - ) + + # set the pad token for the first BOS frame + code[:, 0] = self.speech_pad_id # shift subword_ids subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1], value=0.0) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index e1b9fd1b7fa6..77056a899893 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -153,7 +153,7 @@ }, "data": { "add_text_bos_and_eos_in_each_turn": True, - "add_audio_prompt_after_description": True, + "add_audio_prompt": True, "audio_prompt_duration": 3.0, "frame_length": 0.08, "source_sample_rate": 22050, @@ -195,7 +195,7 @@ def dataset(model): return DuplexEARTTSDataset( model.tokenizer, add_text_bos_and_eos_in_each_turn=True, - add_audio_prompt_after_description=True, + add_audio_prompt=True, audio_prompt_duration=3.0, frame_length=0.08, source_sample_rate=22050, @@ -251,9 +251,7 @@ def test_eartts_dataset(dataset, training_cutset_batch): expected_keys = { "sample_id", "non_prompt_mask", - "desc_mask", - "desc_lens", - "desc_plus_audio_prompt_lens", + "prompt_lens", "aligned_attention_mask", "aligned_position_ids", "source_audio", @@ -275,7 +273,6 @@ def test_eartts_dataset(dataset, training_cutset_batch): tensor_keys = [ "non_prompt_mask", - "desc_mask", "aligned_attention_mask", "aligned_position_ids", "source_audio", From c7d7440c6d4d77d1143c99966f22ee4e70d990ea Mon Sep 17 00:00:00 
2001 From: Edresson Date: Tue, 25 Nov 2025 20:58:49 +0000 Subject: [PATCH 061/102] Apply isort and black reformatting Signed-off-by: Edresson Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 15 ++++++++++----- .../speechlm2/models/duplex_ear_tts.py | 5 +---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 1e8aa6819090..c5790eab1405 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -257,7 +257,9 @@ def __getitem__(self, cuts: CutSet) -> dict: add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, ) - audio_prompt, audio_prompt_lens = get_audio_prompt(cuts, self.target_sample_rate, roles=self.output_roles, recording_field="target_audio") + audio_prompt, audio_prompt_lens = get_audio_prompt( + cuts, self.target_sample_rate, roles=self.output_roles, recording_field="target_audio" + ) # ensures that input_text_tokens is not longer than its duration input_text_tokens = input_text_tokens[:, : target_token_lens.max()] @@ -457,8 +459,8 @@ def maybe_add_audio_prompt( if self.add_audio_prompt: # Compute audio prompt duration in samples (rounded to frame boundaries) prompt_audio_size = int( - ((self.audio_prompt_duration * self.target_sample_rate) // - self.target_samples_per_frame) * self.target_samples_per_frame + ((self.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) + * self.target_samples_per_frame ) prompt_audio = sample_audio_segments_repeat( @@ -466,7 +468,7 @@ def maybe_add_audio_prompt( ) # Fade transition: add silence at the end of the prompt - prompt_audio[:, -int(self.target_samples_per_frame * 2):] = 0 + prompt_audio[:, -int(self.target_samples_per_frame * 2) :] = 0 # Number of text tokens to insert to align with prompt_audio frames prompt_audio_text_pad_size = prompt_audio_size // self.target_samples_per_frame @@ -476,7 +478,8 @@ def maybe_add_audio_prompt( prompt_audio_text_pad_size, device=input_text_tokens.device, dtype=input_text_tokens.dtype, - ) * text_pad_id + ) + * text_pad_id ) prompt_audio_text_pad[-1] = self.tokenizer.eos @@ -623,6 +626,7 @@ def add_speech_delay( return source_audio, source_audio_lens, target_audio, target_audio_lens + def get_audio_prompt( cuts: CutSet, target_sample_rate: int, @@ -689,6 +693,7 @@ def get_audio_prompt( return audio_prompt, audio_prompt_lens + def collate_random_turn_audio( cuts: CutSet, roles: set[str], diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 096335c1ab6c..fea9b5a91b6d 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -803,10 +803,7 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals for key in init_inputs: if init_inputs[key] is not None: init_inputs[key] = torch.stack( - [ - init_inputs[key][i, :plen] - for i, plen in enumerate(dataset_batch["prompt_lens"]) - ] + [init_inputs[key][i, :plen] for i, plen in enumerate(dataset_batch["prompt_lens"])] ) else: # set init inputs and get it From 96c3ddbbce02fac23c4f43857c9096bc26784bb4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 26 Nov 2025 04:34:54 -0800 Subject: [PATCH 062/102] Update docs Signed-off-by: Edresson Casanova --- docs/source/speechlm2/datasets.rst | 1 + 
docs/source/speechlm2/models.rst | 24 +++++++++++++++++++ nemo/collections/common/data/lhotse/cutset.py | 2 +- .../speechlm2/data/duplex_ear_tts_dataset.py | 7 +++--- .../speechlm2/modules/ear_tts_model.py | 19 +++++++++++---- 5 files changed, 44 insertions(+), 9 deletions(-) diff --git a/docs/source/speechlm2/datasets.rst b/docs/source/speechlm2/datasets.rst index c31a51a28006..01feec9ec9eb 100644 --- a/docs/source/speechlm2/datasets.rst +++ b/docs/source/speechlm2/datasets.rst @@ -11,6 +11,7 @@ Duplex S2S models use the Lhotse framework for audio data management. The primar 1. **DuplexS2SDataset**: For general duplex speech-to-speech models 2. **SALMDataset**: Specifically for the Speech-Augmented Language Model (SALM), which processes speech+text and outputs text. +3. **DuplexEARTTSDataset**: Dataset for Duplex EARTTS model, extending DuplexS2SDataset with additional output fields for TTS, including audio prompting. It optionally prepends an audio prompt (speaker reference) to target_audio, which is used to initialize speaker conditioning in the EARTTS model. The dataset provides audio_prompt, audio_prompt_lens, non_prompt_mask, aligned_attention_mask, and aligned_position_ids, and supports custom speaker reference audio through the context_audio field, while preserving full compatibility with the original data format. DuplexS2S Dataset Structure ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/speechlm2/models.rst b/docs/source/speechlm2/models.rst index 3e970c83a760..44ebaaf69520 100644 --- a/docs/source/speechlm2/models.rst +++ b/docs/source/speechlm2/models.rst @@ -8,6 +8,30 @@ Core Model Architectures The collection includes the following core model architectures: + +DuplexEARTTS +^^^^^^^^^^^^ + +DuplexEARTTS is a streaming text-to-speech model designed for duplex speech-to-speech systems. It focuses on low-latency, fully streamable speech generation by converting text tokens into audio representations in real time. + +The architecture is based on the Streaming TTS model proposed in `Audio Flamingo 3`_, with several extensions for duplex interaction: + +* **Gated fusion of text and audio representations**: (`GatedProjectedSumRMSNorm`), enabling better multimodal integration. +* **Subword-aware embeddings**: (`SubwordFlagEmbedding`) to improve pronunciation for words composed of multiple text tokens. +* **Custom BOS/EOS embeddings**: (`BOSEOSEmbedding`) for interruption-aware, multi-turn duplex generation. + + +Key components: + +* **RVQVAEModel**: An RVQ-based neural audio codec that compresses speech into discrete acoustic tokens using a convolutional encoder and reconstructs high-quality audio via a convolutional decoder. +* **RVQEARTTSModel**: A streaming speech generation model that predicts multiple RVQ codebooks in parallel using a Mixture-of-Gaussians (MoG) prediction head. It produces audio tokens autoregressively from text representations with minimal latency. + +DuplexEARTTS is particularly useful for: +* Duplex speech-to-speech systems requiring interruption-aware synthesis. +* Low-latency text-to-speech generation. +* Real-time conversational agents with streamed audio output. 
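+
+A minimal usage sketch (the checkpoint path is a placeholder and the attribute names mirror the
+model definition in ``nemo/collections/speechlm2/models/duplex_ear_tts.py``; see the scripts under
+``examples/speechlm2`` for the full training and evaluation entry points):
+
+.. code-block:: python
+
+    import nemo.collections.speechlm2 as slm
+
+    # Load a trained DuplexEARTTS checkpoint (placeholder path).
+    model = slm.models.DuplexEARTTS.from_pretrained("path/to/checkpoint").eval()
+
+    # The model bundles an RVQ-based neural audio codec and the streaming TTS backbone.
+    codec = model.audio_codec       # RVQVAEModel
+    tts_backbone = model.tts_model  # RVQEARTTSModel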
+ + SALM (Speech-Augmented Language Model) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index fee53b653762..b8929c82b366 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -609,7 +609,7 @@ def convert_overlap_cut_fn(cut: Cut) -> Cut: @data_type_parser(["lhotse_magpietts_data_as_continuation"]) -def read_lhotse_magpietts_data_as_continuation(config) -> Tuple[CutSet, bool]: +def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: """ Convert MagpieTTS dataset cuts into the Duplex S2S format, with optional `context_audio` that can be used as a speaker reference. diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index c5790eab1405..845910aa6892 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -276,6 +276,7 @@ def __getitem__(self, cuts: CutSet) -> dict: self.source_samples_per_frame, ) + # add audio prompt if needed ( input_text_tokens, target_token_lens, @@ -376,8 +377,6 @@ def maybe_add_audio_prompt( This method optionally injects a speaker-reference audio prompt at the beginning of each sample in the batch. The prompt is inserted in the target-audio channel and aligned text padding is inserted into the text-token streams (input text tokens and source tokens). - The corresponding silence padding is inserted into the source/target audio streams to - preserve frame-level alignment. Args: input_text_tokens (torch.Tensor): @@ -467,7 +466,7 @@ def maybe_add_audio_prompt( audio_prompt, audio_prompt_lens, prompt_audio_size, sample=True ) - # Fade transition: add silence at the end of the prompt + # add silence at the end of the prompt prompt_audio[:, -int(self.target_samples_per_frame * 2) :] = 0 # Number of text tokens to insert to align with prompt_audio frames @@ -655,7 +654,7 @@ def get_audio_prompt( Set of speaker roles to sample from when selecting random turns. recording_field (str, optional): - Name of the audio field in the cut ("source_audio", "target_audio", etc.). + Name of the audio field in the cut ("recording", "target_audio", etc.). Used when sampling random reference turns. Returns: diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 53265c75942a..d366e0dee45c 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -1132,11 +1132,22 @@ def forward(self, audio_emb, text_emb): class RVQEARTTSModel(PreTrainedModel): """ - The main RVQEARTTS model, which can be used for both training and inference. + Main RVQEARTTS model for training and inference. - This model integrates a character-aware text encoder and a MoG head with a - transformer backbone. It can be trained to predict audio codes or used for - autoregressive inference. + This model integrates a character-aware text encoder with a transformer backbone + and a Mixture-of-Gaussians (MoG) prediction head. + + The architecture is based on the Streaming TTS model proposed in + "Audio Flamingo 3" (https://arxiv.org/abs/2507.08128), with several improvements: + + 1. Gated fusion of text and audio representations + (`GatedProjectedSumRMSNorm`). + + 2. 
Subword-aware embeddings for improved pronunciation of multi-token words + (`SubwordFlagEmbedding`). + + 3. Custom BOS and EOS embeddings for duplex interaction support, + enabling interruption-aware generation (`BOSEOSEmbedding`). Args: config (DictConfig | dict[str, Any]): The configuration object for the model. From f9e106693215e61ab03fd191bc6c3a60df2b87e1 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 26 Nov 2025 04:44:15 -0800 Subject: [PATCH 063/102] Update EARTTS documentation Signed-off-by: Edresson Casanova --- docs/source/speechlm2/models.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/speechlm2/models.rst b/docs/source/speechlm2/models.rst index 44ebaaf69520..5439220b8520 100644 --- a/docs/source/speechlm2/models.rst +++ b/docs/source/speechlm2/models.rst @@ -102,6 +102,7 @@ Speech generation components convert text or token representations back into spe 1. **TransformerARSpeechDecoder**: An autoregressive transformer-based speech decoder 2. **Audio Codec Integration**: Works with audio codecs to generate natural speech from discrete tokens +3. **DuplexEARTTS**: A ready-to-use duplex text-to-speech model that supports user interruption via a special text interruption token. The model integrates an RVQ-based audio codec with a streaming speech generation module to enable low-latency, real-time synthesis. Implementation Details -------------------- @@ -224,6 +225,9 @@ All models in the speechlm2 collection can be instantiated from pretrained check # Load DuplexS2SSpeechDecoderModel decoder_model = slm.models.DuplexS2SSpeechDecoderModel.from_pretrained("path/to/checkpoint") + # Load DuplexEARTTS + decoder_model = slm.models.DuplexEARTTS.from_pretrained("path/to/checkpoint") + Model Configuration ----------------- From 12625bc20c2a885c2a3fa736e385e750af60e8f0 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 26 Nov 2025 06:13:13 -0800 Subject: [PATCH 064/102] Rename eval script Signed-off-by: Edresson Casanova --- ..._eartts_infer.py => duplex_eartts_eval.py} | 2 +- examples/speechlm2/duplex_eartts_train.py | 2 +- .../speechlm2/models/duplex_ear_tts.py | 18 +++--- .../speechlm2/modules/ear_tts_commons.py | 60 ------------------- 4 files changed, 13 insertions(+), 69 deletions(-) rename examples/speechlm2/{duplex_eartts_infer.py => duplex_eartts_eval.py} (97%) diff --git a/examples/speechlm2/duplex_eartts_infer.py b/examples/speechlm2/duplex_eartts_eval.py similarity index 97% rename from examples/speechlm2/duplex_eartts_infer.py rename to examples/speechlm2/duplex_eartts_eval.py index 255336849817..8ad5ba06acc4 100644 --- a/examples/speechlm2/duplex_eartts_infer.py +++ b/examples/speechlm2/duplex_eartts_eval.py @@ -51,7 +51,7 @@ def inference(cfg): output_roles=cfg.data.output_roles, add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), add_audio_prompt=cfg.data.get("add_audio_prompt", True), - audio_prompt_duration=cfg.data.audio_prompt_duration, + audio_prompt_duration=cfg.data.get("audio_prompt_duration", 3), num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), ) datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset) diff --git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py index aaccfe6dcb72..8fd4da3bd5fd 100644 --- a/examples/speechlm2/duplex_eartts_train.py +++ b/examples/speechlm2/duplex_eartts_train.py @@ -49,7 +49,7 @@ def train(cfg): output_roles=cfg.data.output_roles, 
add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), add_audio_prompt=cfg.data.get("add_audio_prompt", True), - audio_prompt_duration=cfg.data.audio_prompt_duration, + audio_prompt_duration=cfg.data.get("audio_prompt_duration", 3), num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), ) datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index fea9b5a91b6d..733be035ad1b 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -633,13 +633,7 @@ def pad_or_truncate(x, pad_value=0): "input_text_tokens": input_text_tokens, } - def training_step(self, batch: dict, batch_idx: int): - for m in (self.tts_model,): - if is_frozen(m): - m.eval() - - inputs = self.prepare_inputs(batch) - + def forward(self, inputs): tts_output = self.tts_model( code=inputs["code"], audio_mask=inputs["audio_mask"], @@ -650,6 +644,16 @@ def training_step(self, batch: dict, batch_idx: int): subword_mask=inputs["subword_mask"], non_prompt_mask=inputs["non_prompt_mask"], ) + return tts_output + + def training_step(self, batch: dict, batch_idx: int): + for m in (self.tts_model,): + if is_frozen(m): + m.eval() + + inputs = self.prepare_inputs(batch) + tts_output = self.forward(inputs) + loss_dict = {"lm_loss": tts_output.lm_loss, "c_loss": tts_output.c_loss, "k_loss": tts_output.k_loss} loss = sum(loss_dict.values()) diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index 25331d583481..59490cdfdcad 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -19,8 +19,6 @@ import os import re import shutil -import subprocess -import sys from collections.abc import Mapping from typing import Any @@ -279,64 +277,6 @@ def _get_weight_names(module): # IO and Checkpointing Utilities # ============================================================================== - -def check_git_hash() -> str | None: - """ - Retrieves the current git commit hash of the repository containing this file. - - This is useful for reproducibility, allowing you to track the exact version - of the code used for a particular experiment. - - Returns: - str | None: The git commit hash as a string if successful, otherwise None. - """ - - try: - # Get the directory where this script is located - source_sub_dir = os.path.dirname(os.path.realpath(__file__)) - # Execute the git command to get the current HEAD commit hash - git_hash = ( - subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=source_sub_dir, stderr=subprocess.DEVNULL) - .decode(sys.stdout.encoding) - .strip() - ) - except (subprocess.CalledProcessError, FileNotFoundError): - # Handle cases where git is not installed or the directory is not a git repo - logging.warning( - "Could not retrieve git hash. This may be because the code is not in a git repository " - "or git is not installed. Git hash checking will be ignored." - ) - return None - return git_hash - - -def write_git_hash(workdir_path: str) -> None: - """ - Writes the current git hash to a file in a specified directory. - - If a hash file already exists, it compares the current hash with the saved one - and logs a warning if they differ. - - Args: - workdir_path (str): The path to the directory where the git hash file will be saved. 
- """ - git_hash = check_git_hash() - if git_hash is None: - return - - saved_git_hash_path = os.path.join(workdir_path, GIT_HASH_NAME) - if os.path.exists(saved_git_hash_path): - # If hash file exists, compare it with the current hash - with open(saved_git_hash_path) as f: - saved_git_hash = f.read().strip() - if saved_git_hash != git_hash: - logging.warning(f"Git hash has changed. Saved: {saved_git_hash[:8]}, Current: {git_hash[:8]}") - else: - # If no hash file exists, write the current hash - with open(saved_git_hash_path, "w") as f: - f.write(git_hash) - - def latest_checkpoint_path(dir_path: str, regex: str | None = None) -> str: """ Finds the path of the latest checkpoint file or directory in a directory. From 4c89963d46a6eec653497181d69bebf40ec11211 Mon Sep 17 00:00:00 2001 From: Edresson Date: Wed, 26 Nov 2025 14:14:07 +0000 Subject: [PATCH 065/102] Apply isort and black reformatting Signed-off-by: Edresson --- nemo/collections/speechlm2/modules/ear_tts_commons.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py index 59490cdfdcad..3507f9fd8b5a 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ b/nemo/collections/speechlm2/modules/ear_tts_commons.py @@ -277,6 +277,7 @@ def _get_weight_names(module): # IO and Checkpointing Utilities # ============================================================================== + def latest_checkpoint_path(dir_path: str, regex: str | None = None) -> str: """ Finds the path of the latest checkpoint file or directory in a directory. From 0d67a585405d87b5b2a7500b9403e1810a86fbe4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 26 Nov 2025 12:15:53 -0800 Subject: [PATCH 066/102] Rename intellibility metric Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 20 ++++++++----------- .../speechlm2/parts/metrics/__init__.py | 2 +- .../{intelligibility.py => asr_cer_wer.py} | 0 3 files changed, 9 insertions(+), 13 deletions(-) rename nemo/collections/speechlm2/parts/metrics/{intelligibility.py => asr_cer_wer.py} (100%) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 733be035ad1b..80adf1140951 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -42,7 +42,7 @@ from nemo.collections.speechlm2.modules.ear_tts_vae_codec import RVQVAEModel from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU -from nemo.collections.speechlm2.parts.metrics.intelligibility import Intelligibility +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility from nemo.collections.speechlm2.parts.metrics.results_logger import ResultsLogger from nemo.collections.speechlm2.parts.metrics.secs import SECS from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen @@ -633,7 +633,13 @@ def pad_or_truncate(x, pad_value=0): "input_text_tokens": input_text_tokens, } - def forward(self, inputs): + def training_step(self, batch: dict, batch_idx: int): + for m in (self.tts_model,): + if is_frozen(m): + m.eval() + + inputs = self.prepare_inputs(batch) + tts_output = self.tts_model( code=inputs["code"], audio_mask=inputs["audio_mask"], @@ -644,16 +650,6 @@ def forward(self, inputs): subword_mask=inputs["subword_mask"], 
non_prompt_mask=inputs["non_prompt_mask"], ) - return tts_output - - def training_step(self, batch: dict, batch_idx: int): - for m in (self.tts_model,): - if is_frozen(m): - m.eval() - - inputs = self.prepare_inputs(batch) - tts_output = self.forward(inputs) - loss_dict = {"lm_loss": tts_output.lm_loss, "c_loss": tts_output.c_loss, "k_loss": tts_output.k_loss} loss = sum(loss_dict.values()) diff --git a/nemo/collections/speechlm2/parts/metrics/__init__.py b/nemo/collections/speechlm2/parts/metrics/__init__.py index 8758fe0f4dd4..68312d3ca230 100644 --- a/nemo/collections/speechlm2/parts/metrics/__init__.py +++ b/nemo/collections/speechlm2/parts/metrics/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from .asr_bleu import ASRBLEU from .bleu import BLEU -from .intelligibility import Intelligibility +from .asr_cer_wer import Intelligibility from .results_logger import ResultsLogger from .token_accuracy import TokenAccuracy diff --git a/nemo/collections/speechlm2/parts/metrics/intelligibility.py b/nemo/collections/speechlm2/parts/metrics/asr_cer_wer.py similarity index 100% rename from nemo/collections/speechlm2/parts/metrics/intelligibility.py rename to nemo/collections/speechlm2/parts/metrics/asr_cer_wer.py From ddda743ec66dca56e9e40a470178d1ab0e4cd99e Mon Sep 17 00:00:00 2001 From: Edresson Date: Wed, 26 Nov 2025 20:16:48 +0000 Subject: [PATCH 067/102] Apply isort and black reformatting Signed-off-by: Edresson --- nemo/collections/speechlm2/parts/metrics/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/parts/metrics/__init__.py b/nemo/collections/speechlm2/parts/metrics/__init__.py index 68312d3ca230..5a825a86d316 100644 --- a/nemo/collections/speechlm2/parts/metrics/__init__.py +++ b/nemo/collections/speechlm2/parts/metrics/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from .asr_bleu import ASRBLEU -from .bleu import BLEU from .asr_cer_wer import Intelligibility +from .bleu import BLEU from .results_logger import ResultsLogger from .token_accuracy import TokenAccuracy From 9e2b376afdda4d07afd279da693927785f4a4c90 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 26 Nov 2025 12:26:22 -0800 Subject: [PATCH 068/102] Add WER metric back Signed-off-by: Edresson Casanova --- .../speechlm2/parts/metrics/__init__.py | 2 + .../speechlm2/parts/metrics/wer.py | 65 +++++++++++++++++++ tests/collections/speechlm2/test_metrics.py | 20 +++++- 3 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 nemo/collections/speechlm2/parts/metrics/wer.py diff --git a/nemo/collections/speechlm2/parts/metrics/__init__.py b/nemo/collections/speechlm2/parts/metrics/__init__.py index 5a825a86d316..07cff31b5dd7 100644 --- a/nemo/collections/speechlm2/parts/metrics/__init__.py +++ b/nemo/collections/speechlm2/parts/metrics/__init__.py @@ -16,10 +16,12 @@ from .bleu import BLEU from .results_logger import ResultsLogger from .token_accuracy import TokenAccuracy +from .wer import WER __all__ = [ 'ASRBLEU', 'BLEU', + 'WER', 'TokenAccuracy', 'ResultsLogger', 'Intelligibility', diff --git a/nemo/collections/speechlm2/parts/metrics/wer.py b/nemo/collections/speechlm2/parts/metrics/wer.py new file mode 100644 index 000000000000..c5ea1fb3f5ca --- /dev/null +++ b/nemo/collections/speechlm2/parts/metrics/wer.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict + +import torch +from whisper_normalizer.english import EnglishTextNormalizer + +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.utils import logging + + +class WER: + """ + Computes WER on text predictions. + By default, uses Whisper's EnglishTextNormalizer on hypotheses and references. + """ + + def __init__(self, normalize: bool = True, normalizer=None, verbose: bool = True): + self.verbose = verbose + if normalize: + if normalizer is None: + self.normalizer = EnglishTextNormalizer() + else: + self.normalizer = normalizer + else: + self.normalizer = _identity + + self._refs = defaultdict(list) + self._hyps = defaultdict(list) + + def reset(self): + self._refs.clear() + self._hyps.clear() + return self + + def update(self, name: str, refs: list[str], hyps: list[str]) -> None: + for ref, hyp in zip(refs, hyps): + self._refs[name].append(self.normalizer(ref)) + self._hyps[name].append(self.normalizer(hyp)) + if self.verbose and refs and hyps: + logging.info(f"[REF]\t{refs[0]}\n[HYP]\t{hyps[0]}") + + def compute(self) -> dict[str, torch.Tensor]: + corpus_metric = {} + for name in self._refs.keys(): + metric = torch.tensor(word_error_rate(self._hyps[name], self._refs[name])) + corpus_metric[f"wer_{name}"] = metric + corpus_metric["wer"] = torch.stack(list(corpus_metric.values())).mean() + self.reset() + return corpus_metric + + +def _identity(x): + return x \ No newline at end of file diff --git a/tests/collections/speechlm2/test_metrics.py b/tests/collections/speechlm2/test_metrics.py index 9ae67f1b43c5..9be859545587 100644 --- a/tests/collections/speechlm2/test_metrics.py +++ b/tests/collections/speechlm2/test_metrics.py @@ -13,7 +13,7 @@ # limitations under the License. 
import pytest -from nemo.collections.speechlm2.parts.metrics import BLEU, Intelligibility +from nemo.collections.speechlm2.parts.metrics import BLEU, WER, Intelligibility def test_bleu(): @@ -34,6 +34,24 @@ def test_bleu(): assert ans["txt_bleu"] == 50.0 # average across datasets +def test_wer(): + metric = WER(verbose=False) + metric.update( + name="dataset_1", + refs=["a b c d e f g h i j k l", "m n o p r s t u v"], + hyps=["a b c d e f g h i j k l", "m n o p r s t u v"], + ) + metric.update( + name="dataset_2", + refs=["a b c"], + hyps=["a b d"], + ) + ans = metric.compute() + assert ans["wer_dataset_1"] == 0.0 + assert ans["wer_dataset_2"] == 1 / 3 + assert ans["wer"] == 1 / 6 # average across datasets + + def test_intelligibility(): metric = Intelligibility(pretrained_asr=None, verbose=False, reuse_asr_hyps=True) metric.update( From 6951357d2cb7013f8548ed24000a6964b6a03769 Mon Sep 17 00:00:00 2001 From: Edresson Date: Wed, 26 Nov 2025 20:29:35 +0000 Subject: [PATCH 069/102] Apply isort and black reformatting Signed-off-by: Edresson --- nemo/collections/speechlm2/parts/metrics/wer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/parts/metrics/wer.py b/nemo/collections/speechlm2/parts/metrics/wer.py index c5ea1fb3f5ca..93fe1da6a34c 100644 --- a/nemo/collections/speechlm2/parts/metrics/wer.py +++ b/nemo/collections/speechlm2/parts/metrics/wer.py @@ -62,4 +62,4 @@ def compute(self) -> dict[str, torch.Tensor]: def _identity(x): - return x \ No newline at end of file + return x From fa33c5b449164e4de89834fc7047f8d7a32a2f63 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 26 Nov 2025 12:46:28 -0800 Subject: [PATCH 070/102] Do not share small embeddings Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 80adf1140951..6e2d62f22198 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -1413,21 +1413,7 @@ def oomptimizer_schema(self) -> dict: Return a typing schema for optimal batch size calibration for various sequence lengths using OOMptimizer. 
""" - return { - "cls": dict, - "inputs": [ - {"name": "source_audio", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, - {"name": "source_audio_lens", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, - {"name": "target_audio", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, - {"name": "target_audio_lens", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, - { - "name": "input_text_tokens", - "type": NeuralType(("B", "T"), LabelsType()), - "seq_length": "output", - "vocab_size": self.tokenizer.vocab_size, - }, - ], - } + raise NotImplementedError def configure_model(self) -> None: # TODO(pzelasko): refactor into separate module re-usable across models @@ -1522,13 +1508,10 @@ def configure_model(self) -> None: ) else: self.embed_text_tokens = fully_shard(self.embed_text_tokens, **fsdp_config) - # self.tts_model = fully_shard(self.tts_model, **fsdp_config) self.tts_model.mog_head = fully_shard(self.tts_model.mog_head, **fsdp_config) self.tts_model.embed_subword = fully_shard(self.tts_model.embed_subword, **fsdp_config) self.tts_model.embed_context = fully_shard(self.tts_model.embed_context, **fsdp_config) self.tts_model.embed_code = fully_shard(self.tts_model.embed_code, **fsdp_config) - self.tts_model.null_emb = fully_shard(self.tts_model.null_emb, **fsdp_config) - self.tts_model.bos_emb = fully_shard(self.tts_model.bos_emb, **fsdp_config) self.tts_model.lm_head = fully_shard(self.tts_model.lm_head, **fsdp_config) def load_state_dict(self, state_dict, strict: bool = True): From 9142d70159eb93597ef4275da4b9cb5bb9802d3f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 26 Nov 2025 12:48:43 -0800 Subject: [PATCH 071/102] Remove unused imports Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 6e2d62f22198..0410df4ff0bd 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -48,7 +48,6 @@ from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf, set_model_dict_for_partial_init -from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging From fd517a01f43c63e8f2b9bda3ac4c215f86b6a41e Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 27 Nov 2025 04:25:03 -0800 Subject: [PATCH 072/102] Reuse TTS get_mask_from_lengths Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 26 +---------------- .../speechlm2/models/duplex_ear_tts.py | 28 ------------------- 2 files changed, 1 insertion(+), 53 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 845910aa6892..ce4c878b1ea8 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -25,6 +25,7 @@ from nemo.collections.common.tokenizers import TokenizerSpec from nemo.collections.speechlm2.data.utils import get_pad_id from nemo.utils import logging +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths def sample_audio_segments_repeat( @@ -84,31 
+85,6 @@ def sample_audio_segments_repeat( return out -def get_mask_from_lengths( - lengths: torch.Tensor = None, - x: torch.Tensor = None, -) -> torch.Tensor: - """Constructs binary mask from a 1D torch tensor of input lengths - Args: - lengths: torch.tensor (torch.tensor): 1D tensor with lengths - x: torch.tensor = tensor to be used on, last dimension is for mask - Returns: - mask (torch.tensor): num_sequences x max_length binary tensor - """ - if lengths is None: - assert x is not None - return torch.ones(x.shape[-1], dtype=torch.bool, device=x.device) - else: - if x is None: - max_len = torch.max(lengths) - else: - max_len = x.shape[-1] - - ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) - mask = ids < lengths.unsqueeze(1) - return mask - - class DuplexEARTTSDataset(torch.utils.data.Dataset): """ A dataset for duplex speech-to-speech models that handles bidirectional conversations. diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 0410df4ff0bd..c88277c6a1f9 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -140,34 +140,6 @@ def replace_control_speech_codes( return torch.where(torch.isin(speech_codes, control_codes), speech_codes[:, :1], speech_codes) -def get_mask_from_lengths( - lengths: torch.Tensor = None, x: torch.Tensor = None, pad_to_factor: int = None -) -> torch.Tensor: - """Constructs binary mask from a 1D torch tensor of input lengths - Args: - lengths: torch.tensor (torch.tensor): 1D tensor with lengths - x: torch.tensor = tensor to be used on, last dimension is for mask - Returns: - mask (torch.tensor): num_sequences x max_length binary tensor - """ - if lengths is None: - assert x is not None - return torch.ones(x.shape[-1], dtype=torch.bool, device=x.device) - else: - if x is None: - max_len = torch.max(lengths) - else: - max_len = x.shape[-1] - - if pad_to_factor is not None: - with fp32_precision(): - max_len = torch.ceil(max_len / pad_to_factor) * pad_to_factor - - ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) - mask = ids < lengths.unsqueeze(1) - return mask - - def setup_rvq_audio_codec(model): """ Sets up an ``AudioCodecModel``, initializing it from pretrained weights. 
From dc1901e5d4cb6759aecb71327ed152486e14deea Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 27 Nov 2025 12:25:51 +0000 Subject: [PATCH 073/102] Apply isort and black reformatting Signed-off-by: Edresson --- nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index ce4c878b1ea8..8d5d6ff78346 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -24,8 +24,8 @@ from nemo.collections.common.tokenizers import TokenizerSpec from nemo.collections.speechlm2.data.utils import get_pad_id -from nemo.utils import logging from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.utils import logging def sample_audio_segments_repeat( From 10df97f0ba5a2392c25f1e8861da20bb20b56dcd Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 27 Nov 2025 04:29:50 -0800 Subject: [PATCH 074/102] Update rescale_state_dict Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index c88277c6a1f9..1141870d1e84 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -181,7 +181,7 @@ def setup_audio_codec(self): def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_prefix="tts_model.backbone.layers."): """ - Rescale trainable weights in a state_dict for BF16 stability. + Rescale trainable weights in a state_dict for BF16/FP16 stability. 
Args: state_dict: PyTorch state_dict @@ -217,9 +217,9 @@ def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_p if not weight_tensors: if first_n_layers is not None: - print(f"⚠️ No weights found for first {first_n_layers} layers with prefix '{layer_prefix}'.") + logging.info(f"No weights found for first {first_n_layers} layers with prefix '{layer_prefix}'.") else: - print("⚠️ No weights found to rescale in state_dict.") + logging.info("No weights found to rescale in state_dict.") return state_dict # Compute global std across selected weights (on CPU) @@ -228,8 +228,8 @@ def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_p current_std = float(torch.std(flat)) scale = target_std / (current_std + 1e-8) - print( - f"📦 Rescaling state_dict " + logging.info( + f"Rescaling state_dict " f"{'(first N layers)' if first_n_layers else '(all layers)'}: " f"current std = {current_std:.6f}, target = {target_std}, scale = {scale:.6f}" ) @@ -246,7 +246,7 @@ def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_p else: new_state_dict[name] = param - print("✅ Done: weights rescaled.") + logging.info("Done: weights rescaled.") return new_state_dict From 8498ea7111f70d7da742471ebaa051e4a2f54533 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 27 Nov 2025 04:45:25 -0800 Subject: [PATCH 075/102] Add missing dataset docs Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 95 ++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 8d5d6ff78346..ffc51b2a5edf 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -634,11 +634,9 @@ def get_audio_prompt( Used when sampling random reference turns. Returns: - (audio_prompt, audio_prompt_lens): - + Tuple containing: - audio_prompt (FloatTensor [B, T]): Padded batch of reference waveforms. - - audio_prompt_lens (LongTensor [B]): Lengths of each reference waveform before padding. """ @@ -674,6 +672,32 @@ def collate_random_turn_audio( roles: set[str], recording_field: str = "target_audio", ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sample and collate reference audio from random speaker turns. + + For each cut in the batch, this function: + - selects a random supervision belonging to one of the specified roles, + - extracts the corresponding audio segment, + - collates all segments into a padded batch. + + Args: + cuts (CutSet): + Batch of cuts to sample from. + + roles (set[str]): + Set of speaker roles to consider when selecting random turns. + + recording_field (str, optional): + Name of the audio field to load from the cut + (e.g., "recording", "target_audio"). + + Returns: + Tuple containing: + - audio (FloatTensor [B, T]): + Padded batch of sampled reference waveforms. + - audio_lens (LongTensor [B]): + Lengths of each waveform before padding. + """ selected_turn_audios = [] selected_turn_audios_lens = [] for cut in cuts: @@ -701,6 +725,38 @@ def collate_token_channel( roles: set[str], add_text_bos_and_eos_in_each_turn: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Build and collate token channels aligned to the audio frame grid. 
+ + This function converts text supervisions into frame-level token + representations for each cut, pads them to a uniform length, and + returns both the padded tokens and their true lengths. + + Args: + cuts (CutSet): + Batch of cuts from which to extract token channels. + + tokenizer (TokenizerSpec): + Tokenizer used to convert text into token IDs. + + frame_length (Seconds): + Duration of a single audio frame, used to align text tokens + to audio frames. + + roles (set[str]): + Speaker roles whose text will be included in the token channel. + + add_text_bos_and_eos_in_each_turn (bool, optional): + Whether to insert BOS at the beginning and EOS at the end of + each speaking turn. + + Returns: + Tuple containing: + - tokens (LongTensor [B, T]): + Padded batch of frame-aligned token sequences. + - token_lens (LongTensor [B]): + Length of each sequence before padding. + """ pad_id = get_pad_id(tokenizer) tokens = [ build_token_channel( @@ -726,6 +782,39 @@ def build_token_channel( pad_id: int = -1, add_text_bos_and_eos_in_each_turn: bool = True, ) -> torch.Tensor: + """ + Build a frame-aligned token sequence for a single cut. + + This function maps speaking turns into a token channel aligned to the + audio frame grid. Tokens are inserted at frame positions corresponding + to supervision start times, with optional BOS and EOS insertion. + + Any region not covered by text is filled with `pad_id`. + + Args: + cut (Cut): + Input cut containing audio and supervisions. + + tokenizer (TokenizerSpec): + Tokenizer used to encode text into tokens. + + frame_length (Seconds): + Duration of one frame, used to align text to the audio grid. + + roles (set[str]): + Speaker roles whose text should be included. + + pad_id (int, optional): + Token ID used for padding empty frames. + + add_text_bos_and_eos_in_each_turn (bool, optional): + Whether to insert BOS before and EOS after each supervision. + + Returns: + tokens (LongTensor [T]): + Frame-aligned token sequence for the cut. 
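+
+    Example (illustrative; ``cut``, ``tokenizer`` and the role names are assumed to come from the
+    surrounding dataset pipeline):
+        >>> tokens = build_token_channel(
+        ...     cut, tokenizer, frame_length=0.08, roles={"user", "agent"}
+        ... )
+        >>> tokens.shape   # (num_frames,), aligned to the cut's audio frame grid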
+ """ + diagnostic = f"Extra info: {cut.id=}" if getattr(cut, "shard_origin", None) is not None: diagnostic = f"{diagnostic} {cut.shard_origin=}" From b887a776652f2c409605e0c07cfeffe72b0d8492 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 27 Nov 2025 10:04:51 -0800 Subject: [PATCH 076/102] Remove Pretrained class and ear_tts_commons.py Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 38 +- .../speechlm2/modules/ear_tts_commons.py | 331 ------------------ .../speechlm2/modules/ear_tts_model.py | 8 +- .../speechlm2/modules/ear_tts_vae_codec.py | 10 +- .../collections/speechlm2/parts/pretrained.py | 29 +- 5 files changed, 51 insertions(+), 365 deletions(-) delete mode 100644 nemo/collections/speechlm2/modules/ear_tts_commons.py diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 1141870d1e84..0e2e81694471 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -47,7 +47,7 @@ from nemo.collections.speechlm2.parts.metrics.secs import SECS from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf, set_model_dict_for_partial_init +from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf, set_model_dict_for_partial_init, load_checkpoint from nemo.utils import logging @@ -151,19 +151,12 @@ def setup_rvq_audio_codec(model): return # skip if already set up and has the right dtype with ensures_target_precision(model.audio_codec_run_dtype): - if model.cfg.get("pretrained_ae_dir", None): - model.audio_codec = ( - RVQVAEModel.from_pretrained( - model.cfg.pretrained_ae_dir, - cfg=DictConfig(model.cfg.codec_config) if model.cfg.get("codec_config", None) else None, - strict=False, - ) - .eval() - .to(model.device) - ) - else: - # init codec from config - model.audio_codec = RVQVAEModel(DictConfig(model.cfg.codec_config)) + model.audio_codec = RVQVAEModel(DictConfig(model.cfg.codec_config)) + # load pretrained codec checkpoint + if model.cfg.get("pretrained_codec_model", None): + checkpoint_state = load_checkpoint(model.cfg.pretrained_codec_model) + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, model.audio_codec.state_dict()) + model.audio_codec.load_state_dict(checkpoint_state, strict=True) for p in model.audio_codec.parameters(): p.requires_grad = False @@ -367,13 +360,14 @@ def _load_embed_tokens(self, cfg) -> nn.Embedding: def _load_tts_model(self, cfg) -> nn.Module: """Load TTS model for RVQ-EAR-TTS.""" + # instanciate tts model + self.tts_model = RVQEARTTSModel(DictConfig(cfg.tts_config)) + + # load pretrained tts checkpoint if self.cfg.get("pretrained_tts_model", None): - self.tts_model = RVQEARTTSModel.from_pretrained( - cfg.pretrained_tts_model, DictConfig(cfg.tts_config), strict=False - ) - else: - # start the model from scratch - self.tts_model = RVQEARTTSModel(DictConfig(cfg.tts_config)) + checkpoint_state = load_checkpoint(self.cfg.pretrained_tts_model) + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.tts_model.state_dict()) + self.tts_model.load_state_dict(checkpoint_state, strict=True) setup_audio_codec(self) @@ -389,7 +383,7 @@ def _load_language_model(self, cfg): def restore_from_pretrained_checkpoint(self, checkpoint_path): """ - Loads model weights a pretrained 
checkpoint file, supporting partial loading from .nemo and PyTorch formats. + Loads model weights a pretrained checkpoint file, supporting partial loading from safetensor and PyTorch formats. Args: checkpoint_path (str): Path to checkpoint file. @@ -398,7 +392,7 @@ def restore_from_pretrained_checkpoint(self, checkpoint_path): None. The model is updated in-place. """ if checkpoint_path is not None: - checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location='cpu')['state_dict'] + checkpoint_state = load_checkpoint(checkpoint_path) checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict()) if self.cfg.get("rescale_pretrained_weights", None): diff --git a/nemo/collections/speechlm2/modules/ear_tts_commons.py b/nemo/collections/speechlm2/modules/ear_tts_commons.py deleted file mode 100644 index 3507f9fd8b5a..000000000000 --- a/nemo/collections/speechlm2/modules/ear_tts_commons.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import glob -import importlib.machinery -import json -import os -import re -import shutil -from collections.abc import Mapping -from typing import Any - -from omegaconf import DictConfig -from safetensors import safe_open -from torch import nn - -from nemo.utils import logging - -# ============================================================================== -# Contants -# ============================================================================== -PYTHON_CONFIG_GETTER_NAME = "get_config" -CHECKPOINT_FORMAT = "checkpoint_{}/ema.safetensors" -CONFIG_NAME = "config.json" -GIT_HASH_NAME = "githash" -SCRIPT_PLACEHOLDER = "[[[<<>>]]]" - - -# ============================================================================== -# Configuration Class and Utilities -# ============================================================================== -def get_config_from_file(config_path: str) -> DictConfig: - """ - Loads a configuration from a JSON or Python file. - - - For JSON files (`*.json`), it parses the file directly. - - For Python files (`*.py`), it imports the file as a module and calls a - `get_config()` function within it. - - It also supports a special syntax `path/to/config.py:config_name` to select - a specific configuration from a Python file that returns a dictionary of configs. - - Args: - config_path (str): The path to the configuration file. - - Returns: - Config: The loaded configuration object. - - Raises: - AssertionError: If the file path is invalid, does not exist, or is not in - the expected format. - """ - - match = re.search(r".+\.((json)|(py)|(py:.+))$", config_path) - assert match, f"Only Python (*.py) or JSON (*.json) files are supported, but got {config_path}." 
- - py_config_name: str | None = None - if not (config_path.endswith(".py") or config_path.endswith(".json")): - config_path_split = config_path.split(":") - config_path = ":".join(config_path_split[:-1]) - py_config_name = config_path_split[-1] - - assert os.path.isfile(config_path), f"Configuration file not found at: {config_path}" - - if config_path.endswith(".json"): - with open(config_path) as f: - config = json.load(f) - else: - config_module = importlib.machinery.SourceFileLoader("_config", config_path).load_module() - assert hasattr( - config_module, PYTHON_CONFIG_GETTER_NAME - ), f"Python config file must define a `{PYTHON_CONFIG_GETTER_NAME}` function." - config = getattr(config_module, PYTHON_CONFIG_GETTER_NAME)(py_config_name) - assert isinstance(config, Mapping), f"`{PYTHON_CONFIG_GETTER_NAME}` must return a dictionary-like object." - cfg = DictConfig(config) - return cfg - - -def get_config() -> DictConfig: - """ - Parses command-line arguments to load the main configuration for a training run. - - This function implements a hierarchical configuration loading strategy: - 1. It checks if a `config.json` exists in the specified `--workdir`. If so, it loads it. - 2. If a `--config` argument is also provided, it uses that file to update the - configuration loaded from the work directory. - 3. If no config exists in the work directory, it requires the `--config` argument - to be provided as the base configuration. - - This allows for resuming training from a work directory while also being able to - override specific parameters for a new run. - - Returns: - Config: The final, consolidated configuration object. - """ - parser = argparse.ArgumentParser(description="Load training configuration.") - parser.add_argument("-c", "--config", type=str, default=None, help="Path to a Python or JSON configuration file.") - parser.add_argument( - "-w", "--workdir", type=str, required=True, help="Work directory to save logs and checkpoints." - ) - - args = parser.parse_args() - workdir_path = args.workdir - config_save_path = os.path.join(workdir_path, CONFIG_NAME) - - if os.path.exists(config_save_path): - logging.info(f"Resuming from work directory. Loading configuration from {config_save_path}.") - cfg = get_config_from_file(config_save_path) - if args.config and args.config != config_save_path: - logging.info(f"Updating loaded configuration with parameters from {args.config}.") - override_cfg = get_config_from_file(args.config) - cfg.update(override_cfg) - else: - assert args.config is not None, "A configuration file must be specified via `-c` or `--config` for a new run." - logging.info(f"Starting a new run. Loading configuration from {args.config}.") - cfg = get_config_from_file(args.config) - cfg.workdir_path = workdir_path - - return cfg - - -def get_config_from_dir(workdir_path: str) -> DictConfig: - """ - A simple utility to load the configuration directly from a work directory. - - Args: - workdir_path (str): The path to the work directory containing a `config.json`. - - Returns: - DictConfig: The loaded configuration object. 
- """ - config_save_path = os.path.join(workdir_path, CONFIG_NAME) - cfg = get_config_from_file(config_save_path) - cfg.workdir_path = workdir_path - return cfg - - -# ============================================================================== -# Base Model Classes -# ============================================================================== - - -class PreTrainedModel(nn.Module): - config_class = DictConfig - - """ - A base class for models to handle loading from pretrained checkpoints. - - This class provides a common interface for initializing a model and loading - weights from a saved checkpoint, following a pattern similar to libraries - like Hugging Face's Transformers. - - Args: - config (DictConfig | dict[str, Any]): A configuration object containing model hyperparameters. - """ - - def __init__(self, config: DictConfig | dict[str, Any], *args, **kwargs): - super().__init__() - self.config = config if isinstance(config, self.config_class) else self.config_class(config) - - @classmethod - def from_pretrained( - cls, - pretrained_dir: str, - cfg: DictConfig | dict[str, Any] | None = None, - checkpoint_regex: str = "checkpoint_*/ema.safetensors", - strict: bool = False, - **model_kwargs, - ) -> "PreTrainedModel": - """ - Loads a pretrained model from a directory. - - This method first loads the configuration file from the specified directory, - initializes the model with this configuration, and then loads the weights - from the latest checkpoint file found in that directory. - - Args: - cls (type): The model class to instantiate. - pretrained_dir (str): The directory containing the pretrained model - config and checkpoint files. - cfg (DictConfig | dict[str, Any] | None, optional): An optional config object to override - the loaded config. Defaults to None. - checkpoint_regex (str, optional): A regex pattern to find the checkpoint - file. Defaults to "checkpoint_*/ema.safetensors". - strict (bool, optional): Whether to strictly enforce that the keys in - the checkpoint match the keys of the model. - Defaults to False. - **model_kwargs: Additional keyword arguments to pass to the model's - constructor. - - Returns: - PreTrainedModel: An instance of the model with loaded weights. - """ - pretrained_cfg = get_config_from_dir(pretrained_dir).model - if cfg is not None: - pretrained_cfg.update(cfg) - logging.info(f"The loaded config of the pretrained model is updated to: {pretrained_cfg}") - model = cls( - pretrained_cfg, - **model_kwargs, - ) - model_state_dict = {} - with safe_open(latest_checkpoint_path(pretrained_dir, checkpoint_regex), framework="pt", device="cpu") as f: - for key in f.keys(): - model_state_dict[key] = f.get_tensor(key) - model.load_state_dict(model_state_dict, strict=strict) - return model - - def get_optimizer_param_groups(self, weight_decay: float = 0.0) -> list[dict]: - """ - Separates model parameters into two groups: one with weight decay and one without. - - This is a common practice in training deep learning models, where weight decay - is typically applied to the weights of linear and convolutional layers, but not - to biases or normalization layer parameters. - - Args: - weight_decay (float, optional): The weight decay value to apply to the - first group of parameters. Defaults to 0.0. - - Returns: - list[dict]: A list of two dictionaries, each suitable for an optimizer's - parameter groups. The first group has weight decay, and the - second does not. 
- """ - - def _get_weight_names(module): - """Recursively finds the names of all 'weight' parameters in conv/linear layers.""" - result = [] - is_weight_layer = isinstance( - module, - ( - nn.Linear - | nn.Conv1d - | nn.Conv2d - | nn.Conv3d - | nn.ConvTranspose1d - | nn.ConvTranspose2d - | nn.ConvTranspose3d - ), - ) - if is_weight_layer: - result.append("weight") - else: - for name, child in module.named_children(): - result += [f"{name}.{n}" for n in _get_weight_names(child)] - return result - - # Separate parameters - params_w_decay, params_wo_decay = [], [] - param_names_w_decay = set(_get_weight_names(self)) - - for n, p in self.named_parameters(): - if p.requires_grad: - if n in param_names_w_decay: - params_w_decay.append(p) - else: - params_wo_decay.append(p) - return [ - {"params": params_w_decay, "weight_decay": weight_decay}, - {"params": params_wo_decay, "weight_decay": 0.0}, - ] - - -# ============================================================================== -# IO and Checkpointing Utilities -# ============================================================================== - - -def latest_checkpoint_path(dir_path: str, regex: str | None = None) -> str: - """ - Finds the path of the latest checkpoint file or directory in a directory. - - The latest checkpoint is determined by sorting the filenames alphanumerically - and picking the last one. This assumes a naming convention like `checkpoint_1000.pt`, - `checkpoint_2000.pt`, etc. - - Args: - dir_path (str): The directory to search for checkpoints. - regex (str | None, optional): A glob pattern to match checkpoint files. If None, - a default pattern is used. Defaults to None. - - Returns: - str: The full path to the latest checkpoint file. - - Raises: - AssertionError: If no files matching the regex are found in the directory. 
- """ - if regex is None: - regex = CHECKPOINT_FORMAT.format("*") - - f_list = glob.glob(os.path.join(dir_path, regex)) - if not f_list: - raise FileNotFoundError(f"No checkpoint files or directories found in {dir_path} matching '{regex}'") - - # Sort files based on the integer values in their names - f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) - - latest_path = f_list[-1] - logging.info(f"Latest checkpoint '{os.path.relpath(latest_path, start=dir_path)}' found in '{dir_path}'.") - return latest_path - - -def manage_checkpoints(dir_path: str, max_checkpoints: int, regex: str | None = None): - """Keeps the most recent checkpoints and deletes older ones.""" - if regex is None: - regex = CHECKPOINT_FORMAT.format("*") - - checkpoints = glob.glob(os.path.join(dir_path, regex)) - - if len(checkpoints) > max_checkpoints: - # Sort files based on the integer values in their names - checkpoints.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) - num_to_delete = len(checkpoints) - max_checkpoints - for old_checkpoint in checkpoints[:num_to_delete]: - logging.info(f"Deleting old checkpoint: {old_checkpoint}") - if os.path.isfile(old_checkpoint): - os.remove(old_checkpoint) - else: - shutil.rmtree(old_checkpoint) diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index d366e0dee45c..071316b073ad 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -26,7 +26,6 @@ from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, AutoTokenizer, Cache from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper -from nemo.collections.speechlm2.modules.ear_tts_commons import PreTrainedModel from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.utils import logging @@ -1130,7 +1129,7 @@ def forward(self, audio_emb, text_emb): return h -class RVQEARTTSModel(PreTrainedModel): +class RVQEARTTSModel(nn.Module): """ Main RVQEARTTS model for training and inference. @@ -1152,11 +1151,12 @@ class RVQEARTTSModel(PreTrainedModel): Args: config (DictConfig | dict[str, Any]): The configuration object for the model. """ - + config_class = DictConfig rvq_embs: Tensor def __init__(self, config: DictConfig | dict[str, Any]): - super().__init__(config) + super().__init__() + self.config = config # Backbone module if self.config.get("pretrained_text_name", None): diff --git a/nemo/collections/speechlm2/modules/ear_tts_vae_codec.py b/nemo/collections/speechlm2/modules/ear_tts_vae_codec.py index df0a6356b172..1d2aaa8ad44c 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_vae_codec.py +++ b/nemo/collections/speechlm2/modules/ear_tts_vae_codec.py @@ -19,16 +19,11 @@ from typing import Any, Concatenate import librosa - -# Third-party import torch from omegaconf import DictConfig from torch import Tensor, nn from torch.nn import functional as F -# Project -from nemo.collections.speechlm2.modules.ear_tts_commons import PreTrainedModel - @contextmanager def disable_tf32(): @@ -875,7 +870,7 @@ def forward(self, x: Tensor, cache=None, flush: bool = False, constrain_value_ra return x -class RVQVAEModel(PreTrainedModel): +class RVQVAEModel(nn.Module): """ Residual Vector-Quantized Variational Autoencoder (RVQ-VAE) model. 
@@ -890,7 +885,8 @@ class RVQVAEModel(PreTrainedModel): config_class: type[DictConfig] = DictConfig def __init__(self, config: DictConfig | dict[str, Any]): - super().__init__(config) + super().__init__() + self.config = config self.encoder = Wav2Latent( latent_size=self.config.latent_size, diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 8aa3645fbeff..de172265a657 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -25,6 +25,7 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.models import AudioCodecModel from nemo.utils import logging +from safetensors.torch import load_file def load_pretrained_nemo(cls, model_path_or_name: str): @@ -127,7 +128,7 @@ def set_model_dict_for_partial_init( Example: >>> model_dict = model.state_dict() - >>> pretrained_dict = torch.load("pretrained_model.pt") + >>> pretrained_dict = load_checkpoint("pretrained_model.ckpt") >>> model_dict = set_model_dict_for_partial_init(pretrained_dict, model_dict) >>> model.load_state_dict(model_dict) """ @@ -145,3 +146,29 @@ def set_model_dict_for_partial_init( logging.info(f" | > {len(pretrained_dict)} / {len(model_dict)} layers are restored.") return model_dict + + +def load_checkpoint(checkpoint_path): + """ + Load a model checkpoint from disk. + + Supports loading checkpoints stored in either PyTorch (`.ckpt`, `.pt`) or + SafeTensors (`.safetensors`) formats. All parameters are loaded onto CPU + regardless of the original device. + + Args: + checkpoint_path (str): + Path to the checkpoint file. If the filename contains `.safetensors`, + it is loaded using the SafeTensors backend; otherwise, it is assumed + to be a PyTorch checkpoint containing a `state_dict` field. + + Returns: + dict: + A state dictionary mapping parameter names to tensors. 
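# A minimal sketch of how the `load_checkpoint` helper added in this patch composes
# with `set_model_dict_for_partial_init` for partial initialization. The filtering
# rule shown here (keep only keys that exist in the target model with matching
# shapes) is an assumption for illustration; the real helper may behave differently.
import torch
from torch import nn


def partial_init_sketch(model: nn.Module, checkpoint_path: str) -> None:
    # Flat state dict from either a SafeTensors or a PyTorch checkpoint,
    # mirroring the extension-based dispatch used by `load_checkpoint`.
    if ".safetensors" in checkpoint_path:
        from safetensors.torch import load_file

        pretrained = load_file(checkpoint_path, device="cpu")
    else:
        pretrained = torch.load(checkpoint_path, weights_only=False, map_location="cpu")["state_dict"]

    model_dict = model.state_dict()
    # Assumed filtering: drop keys that are missing or shape-mismatched in the target.
    compatible = {k: v for k, v in pretrained.items() if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(compatible)
    model.load_state_dict(model_dict)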
+ """ + if ".safetensors" in checkpoint_path: + checkpoint_state = load_file(checkpoint_path, device="cpu") + else: + checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location="cpu")["state_dict"] + return checkpoint_state + From 926e17911fb78ee2eb8c613971e63646e3eded36 Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 27 Nov 2025 18:05:39 +0000 Subject: [PATCH 077/102] Apply isort and black reformatting Signed-off-by: Edresson --- nemo/collections/speechlm2/models/duplex_ear_tts.py | 6 +++++- nemo/collections/speechlm2/modules/ear_tts_model.py | 1 + nemo/collections/speechlm2/parts/pretrained.py | 3 +-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 0e2e81694471..4b23e9809232 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -47,7 +47,11 @@ from nemo.collections.speechlm2.parts.metrics.secs import SECS from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf, set_model_dict_for_partial_init, load_checkpoint +from nemo.collections.speechlm2.parts.pretrained import ( + load_checkpoint, + load_pretrained_hf, + set_model_dict_for_partial_init, +) from nemo.utils import logging diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 071316b073ad..1e00b9a77cff 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -1151,6 +1151,7 @@ class RVQEARTTSModel(nn.Module): Args: config (DictConfig | dict[str, Any]): The configuration object for the model. 
""" + config_class = DictConfig rvq_embs: Tensor diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index de172265a657..c43da2a9747a 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -18,6 +18,7 @@ import torch from omegaconf import open_dict from peft import PeftModel +from safetensors.torch import load_file from transformers import AutoConfig, AutoModelForCausalLM from nemo.collections.asr.models import ASRModel @@ -25,7 +26,6 @@ from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.models import AudioCodecModel from nemo.utils import logging -from safetensors.torch import load_file def load_pretrained_nemo(cls, model_path_or_name: str): @@ -171,4 +171,3 @@ def load_checkpoint(checkpoint_path): else: checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location="cpu")["state_dict"] return checkpoint_state - From be7f6897d0033a9498795aee962a571ccf1fa568 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 27 Nov 2025 10:11:11 -0800 Subject: [PATCH 078/102] Add top-level comment on Logger Signed-off-by: Edresson Casanova --- .../speechlm2/parts/metrics/results_logger.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py index 0caed8bb4601..5bcdacf9e512 100644 --- a/nemo/collections/speechlm2/parts/metrics/results_logger.py +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -21,6 +21,30 @@ from nemo.collections.audio.parts.utils.resampling import resample from nemo.utils import logging +""" +Utilities for logging evaluation results of SpeechLM2 collection models. + +This file provides helper functionality for saving audio outputs and structured +metadata during evaluation or inference of Duplex speech-to-speech / TTS models. +It is primarily responsible for: + + - Writing predicted waveforms to disk. + - Merging user and model audio into multi-channel WAV files for analysis. + - Exporting metadata (reference text, predictions, ASR output) into JSONL format. + - Saving auxiliary debug artifacts such as: + * teacher-forced predictions, + * reference audio, + * trimmed outputs, + * end-of-utterance (EOU) probability signals. + +Unlike other files in this directory, which focus on metric evaluation, this module +is dedicated to persisting model outputs — including predicted audio samples and +their associated metadata — for later inspection and analysis. + +Key abstraction: + - `ResultsLogger`: A lightweight utility class that manages audio dumping + and metadata bookkeeping across inference batches. +""" def safe_remove_path(path): shutil.rmtree(path, ignore_errors=True) From 912ef24ee400ca712237a187566969428acad15f Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 27 Nov 2025 18:12:08 +0000 Subject: [PATCH 079/102] Apply isort and black reformatting Signed-off-by: Edresson --- nemo/collections/speechlm2/parts/metrics/results_logger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py index 5bcdacf9e512..c9f66212c137 100644 --- a/nemo/collections/speechlm2/parts/metrics/results_logger.py +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -46,6 +46,7 @@ and metadata bookkeeping across inference batches. 
""" + def safe_remove_path(path): shutil.rmtree(path, ignore_errors=True) From 7ef91dbd534c06ac69ae98357b246b3a82344f89 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 28 Nov 2025 10:46:38 -0800 Subject: [PATCH 080/102] Refactor checkpoint loading Signed-off-by: Edresson Casanova --- examples/speechlm2/duplex_eartts_eval.py | 10 ++- examples/speechlm2/duplex_eartts_train.py | 15 ++++ .../speechlm2/models/duplex_ear_tts.py | 76 +++++++++++-------- .../speechlm2/test_duplex_eartts.py | 1 + 4 files changed, 68 insertions(+), 34 deletions(-) diff --git a/examples/speechlm2/duplex_eartts_eval.py b/examples/speechlm2/duplex_eartts_eval.py index 8ad5ba06acc4..b55ad5219f26 100644 --- a/examples/speechlm2/duplex_eartts_eval.py +++ b/examples/speechlm2/duplex_eartts_eval.py @@ -38,9 +38,13 @@ def inference(cfg): OmegaConf.save(cfg, log_dir / "exp_config.yaml") with trainer.init_module(): - model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True)) - - model.eval() + if cfg.get("checkpoint_path", None): + model = DuplexEARTTS.load_from_checkpoint( + cfg.checkpoint_path, + cfg=OmegaConf.to_container(cfg, resolve=True), + ) + else: + raise ValueError("For evaluation, you must provide `cfg.checkpoint_path`.") dataset = DuplexEARTTSDataset( tokenizer=model.tokenizer, diff --git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py index 8fd4da3bd5fd..f9ee99af492e 100644 --- a/examples/speechlm2/duplex_eartts_train.py +++ b/examples/speechlm2/duplex_eartts_train.py @@ -24,6 +24,11 @@ from nemo.utils.exp_manager import exp_manager from nemo.utils.trainer_utils import resolve_trainer_cfg +from nemo.collections.speechlm2.parts.pretrained import ( + load_checkpoint, + set_model_dict_for_partial_init, +) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) @@ -40,6 +45,16 @@ def train(cfg): with trainer.init_module(): model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True)) + # load pretrained tts checkpoint if available + if model.cfg.get("pretrained_tts_model", None): + checkpoint_state = load_checkpoint(model.cfg.pretrained_tts_model) + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, model.tts_model.state_dict()) + model.tts_model.load_state_dict(checkpoint_state, strict=True) + + # load pretrained checkpoint and rescale the weights if needed + if model.cfg.get("pretrained_model", None): + model.restore_from_pretrained_checkpoint(model.cfg.pretrained_model) + dataset = DuplexEARTTSDataset( tokenizer=model.tokenizer, frame_length=cfg.data.frame_length, diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 4b23e9809232..5bff5668c64e 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -144,16 +144,42 @@ def replace_control_speech_codes( return torch.where(torch.isin(speech_codes, control_codes), speech_codes[:, :1], speech_codes) -def setup_rvq_audio_codec(model): +def ensures_codec_target_dtype(model): """ - Sets up an ``AudioCodecModel``, initializing it from pretrained weights. - The result is assigned to ``model.audio_codec`` attribute. + Ensures the audio codec is instantiated with the target dtype. + + This function checks whether `model.audio_codec` exists and whether its + parameters match `model.audio_codec_run_dtype`. If the codec is missing + or is running with the wrong dtype (e.g., due to PTL auto-downcasting), + the codec is reloaded by calling `setup_audio_codec()`. 
+ + Intended to be called at runtime boundaries such as: + - `on_train_epoch_start` + - `on_validation_epoch_start` + + Args: + model: Model instance of DuplexEARTTS - Includes a workaround for PTL auto-downcasting the codec model to bf16 with bf16-true precision. """ if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype: - return # skip if already set up and has the right dtype + return # already correct precision → no-op + + setup_audio_codec(model) + +def setup_audio_codec(model): + """ + Instantiates the RVQ audio codec and injects codec embeddings into the TTS model. + + This function is responsible only for: + - Instantiating the codec model (`RVQVAEModel`). + - Loading pretrained codec weights (if configured). + - Freezing codec parameters. + - Registering RVQ embeddings inside the TTS model via `set_rvq_embs`. + + Args: + model: Model instance of DuplexEARTTS + """ with ensures_target_precision(model.audio_codec_run_dtype): model.audio_codec = RVQVAEModel(DictConfig(model.cfg.codec_config)) # load pretrained codec checkpoint @@ -165,15 +191,13 @@ def setup_rvq_audio_codec(model): for p in model.audio_codec.parameters(): p.requires_grad = False + assert callable(model.tts_model.set_rvq_embs) -def setup_audio_codec(self): - setup_rvq_audio_codec(self) - assert callable(self.tts_model.set_rvq_embs) - self.tts_model.set_rvq_embs(torch.stack([x.detach() for x in self.audio_codec.prvq.mus_list], 0)) - self.tts_model.rvq_embs = self.tts_model.rvq_embs.to(next(self.tts_model.parameters()).dtype) + model.tts_model.set_rvq_embs(torch.stack([x.detach() for x in model.audio_codec.prvq.mus_list], 0)) + model.tts_model.rvq_embs = model.tts_model.rvq_embs.to(next(model.tts_model.parameters()).dtype) # compute target fps - self.target_fps = self.target_sample_rate / self.audio_codec.config.wav_to_token_ratio - self.target_samples_per_frame = self.audio_codec.config.wav_to_token_ratio + model.target_fps = model.target_sample_rate / model.audio_codec.config.wav_to_token_ratio + model.target_samples_per_frame = model.audio_codec.config.wav_to_token_ratio def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_prefix="tts_model.backbone.layers."): @@ -279,8 +303,11 @@ def __init__(self, cfg: dict) -> None: # get codec run precision self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "float32"), torch.float32) - # instanciate eartts model and codec - self._load_tts_model(self.cfg) + # Instantiate TTS model + self.tts_model = RVQEARTTSModel(DictConfig(self.cfg.tts_config)) + # Load and initialize audio codec, and bind RVQ embeddings to the TTS model + setup_audio_codec(self) + self._codebook_size = self.tts_model.config.codebook_size # compute source fps @@ -311,8 +338,7 @@ def __init__(self, cfg: dict) -> None: self._use_fsdp = False self._use_tp = False - if self.cfg.get("pretrained_model", None): - self.restore_from_pretrained_checkpoint(self.cfg.pretrained_model) + def get_codec_silence_frame_last_one(self): audio = torch.zeros(1, 10 * self.target_sample_rate).float().to(self.device) @@ -362,19 +388,6 @@ def _load_embed_tokens(self, cfg) -> nn.Embedding: embed_tokens.load_state_dict(embed_tokens_state_dict) return embed_tokens - def _load_tts_model(self, cfg) -> nn.Module: - """Load TTS model for RVQ-EAR-TTS.""" - # instanciate tts model - self.tts_model = RVQEARTTSModel(DictConfig(cfg.tts_config)) - - # load pretrained tts checkpoint - if self.cfg.get("pretrained_tts_model", None): - 
checkpoint_state = load_checkpoint(self.cfg.pretrained_tts_model) - checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.tts_model.state_dict()) - self.tts_model.load_state_dict(checkpoint_state, strict=True) - - setup_audio_codec(self) - def _load_language_model(self, cfg): """Load language model for RVQ-EAR-TTS.""" if cfg.pretrained_lm_name: @@ -640,7 +653,7 @@ def training_step(self, batch: dict, batch_idx: int): return ans def on_train_epoch_start(self) -> None: - setup_audio_codec(self) # potentially reloads the audio codec to make sure it's in fp32 + ensures_codec_target_dtype(self) # potentially reloads the audio codec to make sure it's in target codec precision def on_train_epoch_end(self) -> None: # log model stats to debug gradient weights issues @@ -688,7 +701,8 @@ def log_model_stats(self): self.log("weights/mean", weight_mean, on_epoch=True, sync_dist=True) def on_validation_epoch_start(self) -> None: - setup_audio_codec(self) + ensures_codec_target_dtype(self) # potentially reloads the audio codec to make sure it's in target codec precision + self.results_logger = ResultsLogger(self.validation_save_path).reset() self.asr_bleu = ASRBLEU(self.cfg.scoring_asr).reset() self.intelligibility = Intelligibility(self.cfg.scoring_asr, reuse_asr_hyps=True).reset() diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 77056a899893..8381ce89de40 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -21,6 +21,7 @@ from nemo.collections.speechlm2.data import DuplexEARTTSDataset from nemo.collections.speechlm2.models import DuplexEARTTS + if torch.cuda.is_available(): torch.set_default_device('cuda') From be12a72cb1f010b4dcbe8b7ac424787f8120a27d Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 28 Nov 2025 18:47:39 +0000 Subject: [PATCH 081/102] Apply isort and black reformatting Signed-off-by: Edresson --- examples/speechlm2/duplex_eartts_train.py | 9 ++------- nemo/collections/speechlm2/models/duplex_ear_tts.py | 9 ++++++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py index f9ee99af492e..2229be46e144 100644 --- a/examples/speechlm2/duplex_eartts_train.py +++ b/examples/speechlm2/duplex_eartts_train.py @@ -18,17 +18,12 @@ from omegaconf import OmegaConf from nemo.collections.speechlm2 import DataModule, DuplexEARTTSDataset - from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS +from nemo.collections.speechlm2.parts.pretrained import load_checkpoint, set_model_dict_for_partial_init from nemo.core.config import hydra_runner from nemo.utils.exp_manager import exp_manager from nemo.utils.trainer_utils import resolve_trainer_cfg -from nemo.collections.speechlm2.parts.pretrained import ( - load_checkpoint, - set_model_dict_for_partial_init, -) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) @@ -50,7 +45,7 @@ def train(cfg): checkpoint_state = load_checkpoint(model.cfg.pretrained_tts_model) checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, model.tts_model.state_dict()) model.tts_model.load_state_dict(checkpoint_state, strict=True) - + # load pretrained checkpoint and rescale the weights if needed if model.cfg.get("pretrained_model", None): model.restore_from_pretrained_checkpoint(model.cfg.pretrained_model) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py 
b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 5bff5668c64e..a3a6fc2022fd 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -339,7 +339,6 @@ def __init__(self, cfg: dict) -> None: self._use_fsdp = False self._use_tp = False - def get_codec_silence_frame_last_one(self): audio = torch.zeros(1, 10 * self.target_sample_rate).float().to(self.device) audio_len = torch.tensor([audio.size(-1)]).long() @@ -653,7 +652,9 @@ def training_step(self, batch: dict, batch_idx: int): return ans def on_train_epoch_start(self) -> None: - ensures_codec_target_dtype(self) # potentially reloads the audio codec to make sure it's in target codec precision + ensures_codec_target_dtype( + self + ) # potentially reloads the audio codec to make sure it's in target codec precision def on_train_epoch_end(self) -> None: # log model stats to debug gradient weights issues @@ -701,7 +702,9 @@ def log_model_stats(self): self.log("weights/mean", weight_mean, on_epoch=True, sync_dist=True) def on_validation_epoch_start(self) -> None: - ensures_codec_target_dtype(self) # potentially reloads the audio codec to make sure it's in target codec precision + ensures_codec_target_dtype( + self + ) # potentially reloads the audio codec to make sure it's in target codec precision self.results_logger = ResultsLogger(self.validation_save_path).reset() self.asr_bleu = ASRBLEU(self.cfg.scoring_asr).reset() From 7e7e77bab3548758665c534deca4a52749135e81 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 12 Dec 2025 08:48:11 -0800 Subject: [PATCH 082/102] Add EOS dropout and duplication Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index a3a6fc2022fd..acede20c4d20 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -520,6 +520,25 @@ def prepare_inputs(self, batch: dict): aligned_attention_mask = batch["aligned_attention_mask"] aligned_position_ids = batch["aligned_position_ids"] + if self.training and (self.cfg.get("empty_turn_probability", 0.0) > 0): + # Randomly decide whether this batch gets emptied + if torch.rand(1).item() < self.cfg.empty_turn_probability: + # Zero out audio + target_audio = torch.zeros_like(target_audio) + + # Create mask for tokens we want to drop + # Keep BOS and EOS, drop the rest. 
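# A toy check (illustrative token ids: 1 = BOS, 2 = EOS, 0 = PAD) of the masking
# used in this "empty turn" augmentation: every token that is not BOS/EOS is
# replaced by PAD, so an emptied turn keeps only its turn boundaries.
import torch

tokens = torch.tensor([[1, 11, 12, 13, 2, 1, 14, 2]])
bos_id, eos_id, pad_id = 1, 2, 0
keep = (tokens == bos_id) | (tokens == eos_id)
emptied = torch.where(~keep, torch.full_like(tokens, pad_id), tokens)
assert emptied.tolist() == [[1, 0, 0, 0, 2, 1, 0, 2]]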
+ keep_mask = (input_text_tokens == self.text_bos_id) | (input_text_tokens == self.text_eos_id) + full_dropout_mask = ~keep_mask # True = positions to replace with PAD + + # Replace all non-BOS/EOS with PAD + input_text_tokens = torch.where( + full_dropout_mask, + torch.full_like(input_text_tokens, self.text_pad_id), + input_text_tokens + ) + + # extract target audio codes target_audio, target_audio_lens = self.pad_audio_to_factor( target_audio, target_audio_lens, self.target_samples_per_frame, 1 @@ -581,6 +600,56 @@ def pad_or_truncate(x, pad_value=0): full_dropout_mask, torch.full_like(input_text_tokens, self.text_pad_id), input_text_tokens ) + if self.training and self.cfg.get("text_eos_duplicate_prob", 0.0) > 0: + p = self.cfg.text_eos_duplicate_prob + + # [B, T] mask of EOS positions + eos_mask = input_text_tokens == self.text_eos_id + + # Flatten EOS positions: tensor of shape [N, 2] where each row = (batch_idx, time_idx) + eos_positions = eos_mask.nonzero(as_tuple=False) # [N, 2] + + if eos_positions.numel() > 0: + N = eos_positions.shape[0] + + # One random decision per EOS occurrence + duplicate_decision = ( + torch.rand(N, device=input_text_tokens.device) < p + ) # [N] + + # Filter only EOS tokens that will be duplicated and are not at position t=0 + valid = (eos_positions[:, 1] > 0) & duplicate_decision # [N] + + if valid.any(): + # Select only valid EOS positions + valid_positions = eos_positions[valid] # [M, 2] + + # Indices for the token BEFORE the EOS (t-1) + b_idx = valid_positions[:, 0] + t_idx = valid_positions[:, 1] - 1 + + # Replace token before EOS with an EOS + input_text_tokens[b_idx, t_idx] = self.text_eos_id + + # BOS dropout to make the model more robust + if self.training and self.cfg.get("text_bos_dropout_prob", 0.0) > 0: + # Mask BOS positions + bos_mask = input_text_tokens == self.text_bos_id + + # Random dropout only on BOS positions + dropout_mask = torch.rand(bos_mask.sum(), device=input_text_tokens.device) < self.cfg.text_bos_dropout_prob + + # Scatter dropout decisions into [B, T] + full_dropout_mask = torch.zeros_like(input_text_tokens, dtype=torch.bool) + full_dropout_mask[bos_mask] = dropout_mask + + # Replace dropped BOS with PAD + input_text_tokens = torch.where( + full_dropout_mask, + torch.full_like(input_text_tokens, self.text_pad_id), + input_text_tokens, + ) + # shift text tokens subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1]) # note that we are using a text mask where we are ignoring the desc + audio prompt but we are keeping 1 until the audio ends to support duplex @@ -601,6 +670,146 @@ def pad_or_truncate(x, pad_value=0): subword_ids = subword_ids[:, :-remainder] subword_mask = subword_mask[:, :-remainder] + # debug samples: + if ( + self.cfg.get("debug_dataloader_audios_path", None) + and self.training and batch["formatter"][0] == "s2s_duplex_silence_augmented" + ): + from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str + def count_leading_silence_tokens(tensor: torch.Tensor, silence_token: int = 0) -> int: + """ + Count the number of consecutive silence tokens at the beginning of a 1D tensor. + + Args: + tensor (torch.Tensor): 1D tensor of tokens. + silence_token (int): The token considered as silence (default: 0). + + Returns: + int: Number of consecutive silence tokens at the beginning. 
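# A toy check of the EOS-duplication step above (illustrative id: 2 = EOS; the
# random per-occurrence decision is forced to "always duplicate" here): for every
# selected EOS at time t > 0, the token at t - 1 is overwritten with EOS, so the
# model occasionally sees doubled EOS during training.
import torch

tokens = torch.tensor([[5, 6, 2, 7, 8, 2]])
eos_id = 2
pos = (tokens == eos_id).nonzero(as_tuple=False)  # [[0, 2], [0, 5]]
valid = pos[:, 1] > 0  # never touch t = 0
b_idx, t_idx = pos[valid, 0], pos[valid, 1] - 1
tokens[b_idx, t_idx] = eos_id
assert tokens.tolist() == [[5, 2, 2, 7, 2, 2]]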
+ """ + if tensor.ndim != 1: + raise ValueError("Input tensor must be 1D.") + + count = 0 + for token in tensor: + if token.item() == silence_token: + count += 1 + else: + break + return count + + def write_wave(one_audio_signal, file_name, sr=None): + import numpy as np + import soundfile as sf + + one_audio_signal = one_audio_signal.cpu().numpy() + one_audio_signal = one_audio_signal.astype(np.float32) + if sr is None: + sr = self.target_sample_rate + # one_audio_signal = np.clip(one_audio_signal, -1.0, 1.0) + sf.write(file_name, one_audio_signal, sr) + + # encode and decode the audio + with fp32_precision(), torch.no_grad(): + print(batch["target_audio"].shape) + lengths = torch.tensor([batch["target_audio"].shape[1]] * batch["target_audio"].shape[0]).to( + self.device + ) + # reconstruct wav + print("target_codes_aligned:", target_codes_aligned.shape) + target_codes_aligned_ = replace_control_speech_codes(target_codes_aligned, self._control_codes, self.codec_silence_tokens) + print(self._control_codes, target_codes_aligned_.shape) + with fp32_precision(), torch.no_grad(): + lengths = torch.tensor([target_codes_aligned_.shape[1]] * target_codes_aligned_.shape[0]).to( + self.device + ) + print(target_codes_aligned_.max(), target_codes.max()) + reconstructed_audio_from_tokens, _ = self.audio_codec.decode( + target_codes_aligned_, lengths + ) + reconstructed_audio_from_tokens = reconstructed_audio_from_tokens.squeeze(1) + print(reconstructed_audio_from_tokens.shape, batch["target_audio"].shape) + + + # Uses batch["input_text_tokens"] instead of subword_ids, because in subword_ids the first prompt BOS is replace, so it will breaks the generate_multiturn_speaking_mask + eou_labels = generate_multiturn_speaking_mask( + batch["input_text_tokens"], bos_token_id=self.text_bos_id, eos_token_id=self.text_eos_id + ) + + for i in range(target_codes_aligned_.shape[0]): + write_wave( + batch["target_audio"][i], + os.path.join(self.cfg.get("debug_dataloader_audios_path"), f"target_audio_{i}.wav"), + sr=self.target_sample_rate, + ) + write_wave( + batch["audio_prompt"][i], + os.path.join(self.cfg.get("debug_dataloader_audios_path"), f"speaker_ref_{i}.wav"), + sr=self.target_sample_rate, + ) + write_wave( + batch["source_audio"][i], + os.path.join(self.cfg.get("debug_dataloader_audios_path"), f"source_audio_{i}.wav"), + sr=self.source_sample_rate, + ) + + write_wave( + reconstructed_audio_from_tokens[i], + os.path.join( + self.cfg.get("debug_dataloader_audios_path"), f"target_audio_reconstructed_from_tokens_{i}.wav" + ), + sr=self.target_sample_rate, + ) + + repeat_factor = int(self.target_sample_rate / self.target_fps) + eou_wav = ( + eou_labels[i].unsqueeze(0).unsqueeze(-1).repeat(1, 1, repeat_factor) + ) # (B, T, repeat_factor) + eou_wav = eou_wav.view(1, -1) # (B, T * repeat_factor) + eou_wav = eou_wav.float() * 0.8 # make 1 audible and keep 0 as total silence + write_wave( + eou_wav.squeeze(), + os.path.join(self.cfg.get("debug_dataloader_audios_path"), f"eou_{i}.wav"), + sr=self.target_sample_rate, + ) + + print( + "target labels from dataloader decoded:", + tokens_to_str( + batch["input_text_tokens"][-1:], + target_codes_lens, + tokenizer=self.tokenizer, + pad_id=self.text_pad_id, + ), + ) + text_labels = batch["input_text_tokens"] + num_bos_tokens = (text_labels.unsqueeze(-1) == self.text_bos_id).flatten(1, 2).sum(-1) + # Count how many EOS tokens are present per sequence + # Shape: [B] + num_eos_tokens = (text_labels.unsqueeze(-1) == self.text_eos_id).flatten(1, 2).sum(-1) + print("Num eos:", 
num_eos_tokens, "num bos:", num_bos_tokens) + + batch_idx = -1 + positions = (text_labels[batch_idx] == self.text_bos_id).nonzero(as_tuple=True)[0] + first_pos = positions[0].item() if len(positions) > 0 else None + + print( + "First BOS is in:", first_pos, "for batch idx:", batch_idx + ) + audio_mask = non_prompt_mask + batch_idx = -1 + positions = (audio_mask[batch_idx] == 1).nonzero(as_tuple=True)[0] + first_audio_mask = positions[0].item() if len(positions) > 0 else None + + print("Last EOS in text (for TTS data it is the end of prompt):", (text_labels[batch_idx] == self.text_eos_id).nonzero(as_tuple=True)[0][-1].item()) + print("First one in audio mask:", first_audio_mask) + print("First one in subword_mask:", (subword_mask[batch_idx] == 1).nonzero(as_tuple=True)[0][0].item()) + + + print(batch["formatter"]) + if target_codes_aligned_.shape[0] > 1: + exit() + return { "code": target_codes_aligned, "audio_mask": non_prompt_mask, # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation From 8beaf8bd44b2341ef8c206930370b9788bc447c8 Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 12 Dec 2025 16:48:57 +0000 Subject: [PATCH 083/102] Apply isort and black reformatting Signed-off-by: Edresson --- .../speechlm2/models/duplex_ear_tts.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index acede20c4d20..3e0fdd4620b4 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -533,12 +533,9 @@ def prepare_inputs(self, batch: dict): # Replace all non-BOS/EOS with PAD input_text_tokens = torch.where( - full_dropout_mask, - torch.full_like(input_text_tokens, self.text_pad_id), - input_text_tokens + full_dropout_mask, torch.full_like(input_text_tokens, self.text_pad_id), input_text_tokens ) - # extract target audio codes target_audio, target_audio_lens = self.pad_audio_to_factor( target_audio, target_audio_lens, self.target_samples_per_frame, 1 @@ -613,9 +610,7 @@ def pad_or_truncate(x, pad_value=0): N = eos_positions.shape[0] # One random decision per EOS occurrence - duplicate_decision = ( - torch.rand(N, device=input_text_tokens.device) < p - ) # [N] + duplicate_decision = torch.rand(N, device=input_text_tokens.device) < p # [N] # Filter only EOS tokens that will be duplicated and are not at position t=0 valid = (eos_positions[:, 1] > 0) & duplicate_decision # [N] @@ -673,9 +668,11 @@ def pad_or_truncate(x, pad_value=0): # debug samples: if ( self.cfg.get("debug_dataloader_audios_path", None) - and self.training and batch["formatter"][0] == "s2s_duplex_silence_augmented" + and self.training + and batch["formatter"][0] == "s2s_duplex_silence_augmented" ): from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str + def count_leading_silence_tokens(tensor: torch.Tensor, silence_token: int = 0) -> int: """ Count the number of consecutive silence tokens at the beginning of a 1D tensor. 
@@ -717,20 +714,19 @@ def write_wave(one_audio_signal, file_name, sr=None): ) # reconstruct wav print("target_codes_aligned:", target_codes_aligned.shape) - target_codes_aligned_ = replace_control_speech_codes(target_codes_aligned, self._control_codes, self.codec_silence_tokens) + target_codes_aligned_ = replace_control_speech_codes( + target_codes_aligned, self._control_codes, self.codec_silence_tokens + ) print(self._control_codes, target_codes_aligned_.shape) with fp32_precision(), torch.no_grad(): lengths = torch.tensor([target_codes_aligned_.shape[1]] * target_codes_aligned_.shape[0]).to( self.device ) print(target_codes_aligned_.max(), target_codes.max()) - reconstructed_audio_from_tokens, _ = self.audio_codec.decode( - target_codes_aligned_, lengths - ) + reconstructed_audio_from_tokens, _ = self.audio_codec.decode(target_codes_aligned_, lengths) reconstructed_audio_from_tokens = reconstructed_audio_from_tokens.squeeze(1) print(reconstructed_audio_from_tokens.shape, batch["target_audio"].shape) - # Uses batch["input_text_tokens"] instead of subword_ids, because in subword_ids the first prompt BOS is replace, so it will breaks the generate_multiturn_speaking_mask eou_labels = generate_multiturn_speaking_mask( batch["input_text_tokens"], bos_token_id=self.text_bos_id, eos_token_id=self.text_eos_id @@ -762,9 +758,7 @@ def write_wave(one_audio_signal, file_name, sr=None): ) repeat_factor = int(self.target_sample_rate / self.target_fps) - eou_wav = ( - eou_labels[i].unsqueeze(0).unsqueeze(-1).repeat(1, 1, repeat_factor) - ) # (B, T, repeat_factor) + eou_wav = eou_labels[i].unsqueeze(0).unsqueeze(-1).repeat(1, 1, repeat_factor) # (B, T, repeat_factor) eou_wav = eou_wav.view(1, -1) # (B, T * repeat_factor) eou_wav = eou_wav.float() * 0.8 # make 1 audible and keep 0 as total silence write_wave( @@ -793,18 +787,18 @@ def write_wave(one_audio_signal, file_name, sr=None): positions = (text_labels[batch_idx] == self.text_bos_id).nonzero(as_tuple=True)[0] first_pos = positions[0].item() if len(positions) > 0 else None - print( - "First BOS is in:", first_pos, "for batch idx:", batch_idx - ) + print("First BOS is in:", first_pos, "for batch idx:", batch_idx) audio_mask = non_prompt_mask batch_idx = -1 positions = (audio_mask[batch_idx] == 1).nonzero(as_tuple=True)[0] first_audio_mask = positions[0].item() if len(positions) > 0 else None - print("Last EOS in text (for TTS data it is the end of prompt):", (text_labels[batch_idx] == self.text_eos_id).nonzero(as_tuple=True)[0][-1].item()) + print( + "Last EOS in text (for TTS data it is the end of prompt):", + (text_labels[batch_idx] == self.text_eos_id).nonzero(as_tuple=True)[0][-1].item(), + ) print("First one in audio mask:", first_audio_mask) print("First one in subword_mask:", (subword_mask[batch_idx] == 1).nonzero(as_tuple=True)[0][0].item()) - print(batch["formatter"]) if target_codes_aligned_.shape[0] > 1: From bc3eb8cd6c5025b1c0fc76876767ec905ff17162 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 12 Dec 2025 08:50:50 -0800 Subject: [PATCH 084/102] Remove debug dataloader code Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 139 ------------------ 1 file changed, 139 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 3e0fdd4620b4..43a8535b20b4 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -665,145 +665,6 @@ def pad_or_truncate(x, 
pad_value=0): subword_ids = subword_ids[:, :-remainder] subword_mask = subword_mask[:, :-remainder] - # debug samples: - if ( - self.cfg.get("debug_dataloader_audios_path", None) - and self.training - and batch["formatter"][0] == "s2s_duplex_silence_augmented" - ): - from nemo.collections.speechlm2.models.duplex_s2s_model import tokens_to_str - - def count_leading_silence_tokens(tensor: torch.Tensor, silence_token: int = 0) -> int: - """ - Count the number of consecutive silence tokens at the beginning of a 1D tensor. - - Args: - tensor (torch.Tensor): 1D tensor of tokens. - silence_token (int): The token considered as silence (default: 0). - - Returns: - int: Number of consecutive silence tokens at the beginning. - """ - if tensor.ndim != 1: - raise ValueError("Input tensor must be 1D.") - - count = 0 - for token in tensor: - if token.item() == silence_token: - count += 1 - else: - break - return count - - def write_wave(one_audio_signal, file_name, sr=None): - import numpy as np - import soundfile as sf - - one_audio_signal = one_audio_signal.cpu().numpy() - one_audio_signal = one_audio_signal.astype(np.float32) - if sr is None: - sr = self.target_sample_rate - # one_audio_signal = np.clip(one_audio_signal, -1.0, 1.0) - sf.write(file_name, one_audio_signal, sr) - - # encode and decode the audio - with fp32_precision(), torch.no_grad(): - print(batch["target_audio"].shape) - lengths = torch.tensor([batch["target_audio"].shape[1]] * batch["target_audio"].shape[0]).to( - self.device - ) - # reconstruct wav - print("target_codes_aligned:", target_codes_aligned.shape) - target_codes_aligned_ = replace_control_speech_codes( - target_codes_aligned, self._control_codes, self.codec_silence_tokens - ) - print(self._control_codes, target_codes_aligned_.shape) - with fp32_precision(), torch.no_grad(): - lengths = torch.tensor([target_codes_aligned_.shape[1]] * target_codes_aligned_.shape[0]).to( - self.device - ) - print(target_codes_aligned_.max(), target_codes.max()) - reconstructed_audio_from_tokens, _ = self.audio_codec.decode(target_codes_aligned_, lengths) - reconstructed_audio_from_tokens = reconstructed_audio_from_tokens.squeeze(1) - print(reconstructed_audio_from_tokens.shape, batch["target_audio"].shape) - - # Uses batch["input_text_tokens"] instead of subword_ids, because in subword_ids the first prompt BOS is replace, so it will breaks the generate_multiturn_speaking_mask - eou_labels = generate_multiturn_speaking_mask( - batch["input_text_tokens"], bos_token_id=self.text_bos_id, eos_token_id=self.text_eos_id - ) - - for i in range(target_codes_aligned_.shape[0]): - write_wave( - batch["target_audio"][i], - os.path.join(self.cfg.get("debug_dataloader_audios_path"), f"target_audio_{i}.wav"), - sr=self.target_sample_rate, - ) - write_wave( - batch["audio_prompt"][i], - os.path.join(self.cfg.get("debug_dataloader_audios_path"), f"speaker_ref_{i}.wav"), - sr=self.target_sample_rate, - ) - write_wave( - batch["source_audio"][i], - os.path.join(self.cfg.get("debug_dataloader_audios_path"), f"source_audio_{i}.wav"), - sr=self.source_sample_rate, - ) - - write_wave( - reconstructed_audio_from_tokens[i], - os.path.join( - self.cfg.get("debug_dataloader_audios_path"), f"target_audio_reconstructed_from_tokens_{i}.wav" - ), - sr=self.target_sample_rate, - ) - - repeat_factor = int(self.target_sample_rate / self.target_fps) - eou_wav = eou_labels[i].unsqueeze(0).unsqueeze(-1).repeat(1, 1, repeat_factor) # (B, T, repeat_factor) - eou_wav = eou_wav.view(1, -1) # (B, T * repeat_factor) - eou_wav = 
eou_wav.float() * 0.8 # make 1 audible and keep 0 as total silence - write_wave( - eou_wav.squeeze(), - os.path.join(self.cfg.get("debug_dataloader_audios_path"), f"eou_{i}.wav"), - sr=self.target_sample_rate, - ) - - print( - "target labels from dataloader decoded:", - tokens_to_str( - batch["input_text_tokens"][-1:], - target_codes_lens, - tokenizer=self.tokenizer, - pad_id=self.text_pad_id, - ), - ) - text_labels = batch["input_text_tokens"] - num_bos_tokens = (text_labels.unsqueeze(-1) == self.text_bos_id).flatten(1, 2).sum(-1) - # Count how many EOS tokens are present per sequence - # Shape: [B] - num_eos_tokens = (text_labels.unsqueeze(-1) == self.text_eos_id).flatten(1, 2).sum(-1) - print("Num eos:", num_eos_tokens, "num bos:", num_bos_tokens) - - batch_idx = -1 - positions = (text_labels[batch_idx] == self.text_bos_id).nonzero(as_tuple=True)[0] - first_pos = positions[0].item() if len(positions) > 0 else None - - print("First BOS is in:", first_pos, "for batch idx:", batch_idx) - audio_mask = non_prompt_mask - batch_idx = -1 - positions = (audio_mask[batch_idx] == 1).nonzero(as_tuple=True)[0] - first_audio_mask = positions[0].item() if len(positions) > 0 else None - - print( - "Last EOS in text (for TTS data it is the end of prompt):", - (text_labels[batch_idx] == self.text_eos_id).nonzero(as_tuple=True)[0][-1].item(), - ) - print("First one in audio mask:", first_audio_mask) - print("First one in subword_mask:", (subword_mask[batch_idx] == 1).nonzero(as_tuple=True)[0][0].item()) - - print(batch["formatter"]) - if target_codes_aligned_.shape[0] > 1: - exit() - return { "code": target_codes_aligned, "audio_mask": non_prompt_mask, # set audio_mask as non_prompt_mask to avoid the audio prompt in loss computation From b71490a50875fa31f53651ce8fe36fed5446deb1 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 15 Dec 2025 08:42:29 -0800 Subject: [PATCH 085/102] Remove unecessary operations on n samples_per_frame computation Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py | 6 ++---- nemo/collections/speechlm2/models/duplex_ear_tts.py | 7 ++----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index ffc51b2a5edf..ac9f9c7be5d9 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -204,10 +204,8 @@ def __init__( self.num_delay_speech_tokens = num_delay_speech_tokens # compute source and target samples_per_frame - source_fps = self.source_sample_rate / (self.source_sample_rate * self.frame_length) - self.source_samples_per_frame = int(self.source_sample_rate // source_fps) - target_fps = self.target_sample_rate / (self.target_sample_rate * self.frame_length) - self.target_samples_per_frame = int(self.target_sample_rate // target_fps) + self.source_samples_per_frame = int(self.source_sample_rate * self.frame_length) + self.target_samples_per_frame = int(self.target_sample_rate * self.frame_length) assert tokenizer.bos is not None, "BOS support in the tokenizer is required for S2S models." assert tokenizer.eos is not None, "EOS support in the tokenizer is required for S2S models." 
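# Sanity check for the simplification in the hunk above: the removed code computed
# fps = sample_rate / (sample_rate * frame_length) = 1 / frame_length and then
# samples_per_frame = int(sample_rate // fps), which reduces (up to floating-point
# rounding) to int(sample_rate * frame_length). The sample rates and frame length
# below are illustrative values, not taken from an actual config.
frame_length = 0.08  # seconds per frame (assumed)
for sample_rate in (16000, 22050):
    old = int(sample_rate // (sample_rate / (sample_rate * frame_length)))
    new = int(sample_rate * frame_length)
    assert old == new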
diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 43a8535b20b4..0944283be2f6 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -310,11 +310,8 @@ def __init__(self, cfg: dict) -> None: self._codebook_size = self.tts_model.config.codebook_size - # compute source fps - self.source_fps = self.source_sample_rate / ( - self.source_sample_rate * cfg.data.frame_length - ) # conver frame rate in fps - self.source_samples_per_frame = int(self.source_sample_rate // self.source_fps) + # compute samples per frame + self.source_samples_per_frame = int(self.source_sample_rate * cfg.data.frame_length) # get codec silence tokens codec_silence_tokens = self.get_codec_silence_frame() From ca277ab79ac2a873440cfb0d0d25d5ee440b0e67 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 16 Dec 2025 10:25:28 -0800 Subject: [PATCH 086/102] Ignore sample on sample_audio_segments_repeat when audio is shorter Signed-off-by: Edresson Casanova --- nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index ac9f9c7be5d9..66b6b4b3eed8 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -71,11 +71,7 @@ def sample_audio_segments_repeat( else: # Audio shorter than target → repeat - if sample: - # Random start position - start = torch.randint(0, length, (1,), device=device).item() - else: - start = 0 + start = 0 segment = prompt_audio[b, start:length] repeat_times = (n_sample + (length - start) - 1) // (length - start) From bbdf250bdaec7a958c452818334356b41b99f361 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 17 Dec 2025 03:01:13 -0800 Subject: [PATCH 087/102] Move data utils to the end of the file Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 109 +++++++++--------- .../speechlm2/models/duplex_ear_tts.py | 18 +-- 2 files changed, 63 insertions(+), 64 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 66b6b4b3eed8..59db2d6b2d15 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -28,59 +28,6 @@ from nemo.utils import logging -def sample_audio_segments_repeat( - prompt_audio: torch.Tensor, - prompt_audio_lens: torch.Tensor, - n_sample: int, - sample: bool = True, -) -> torch.Tensor: - """ - Extract audio segments of length n_sample. - If sample=True: randomly sample segments (repeating if shorter). - If sample=False: always take from the beginning (repeating if shorter). 
- - Args: - prompt_audio: Tensor [B, T] - prompt_audio_lens: Tensor [B] with valid lengths - n_sample: int, target length per segment - sample: bool, whether to randomly sample (True) or take first seconds (False) - - Returns: - Tensor [B, n_sample] - """ - B, T = prompt_audio.shape - device = prompt_audio.device - out = torch.zeros(B, n_sample, device=device, dtype=prompt_audio.dtype) - - for b in range(B): - length = min(prompt_audio_lens[b].item(), T) - - # Case: empty audio - if length <= 0: - continue - - if length >= n_sample: - if sample: - # Random start (safe bounds) - max_start = max(1, length - n_sample + 1) - start = torch.randint(0, max_start, (1,), device=device).item() - else: - # Deterministic: take from start - start = 0 - out[b] = prompt_audio[b, start : start + n_sample] - - else: - # Audio shorter than target → repeat - start = 0 - segment = prompt_audio[b, start:length] - - repeat_times = (n_sample + (length - start) - 1) // (length - start) - repeated = segment.repeat(repeat_times)[:n_sample] - out[b] = repeated - - return out - - class DuplexEARTTSDataset(torch.utils.data.Dataset): """ A dataset for duplex speech-to-speech models that handles bidirectional conversations. @@ -231,9 +178,6 @@ def __getitem__(self, cuts: CutSet) -> dict: cuts, self.target_sample_rate, roles=self.output_roles, recording_field="target_audio" ) - # ensures that input_text_tokens is not longer than its duration - input_text_tokens = input_text_tokens[:, : target_token_lens.max()] - # add speech channel delay if needed if self.num_delay_speech_tokens: source_audio, source_audio_lens, target_audio, target_audio_lens = add_speech_delay( @@ -867,3 +811,56 @@ def _strip_timestamps( # Regexp pattern args are cached compiled patterns (micro-optimization). text = _TIMESTAMP_PATTERN.sub("", text) # strip timestamp tokens if present return _SPACE_PATTERN.sub(" ", text).strip() # strip multi-whitespaces + + +def sample_audio_segments_repeat( + prompt_audio: torch.Tensor, + prompt_audio_lens: torch.Tensor, + n_sample: int, + sample: bool = True, +) -> torch.Tensor: + """ + Extract audio segments of length n_sample. + If sample=True: randomly sample segments (repeating if shorter). + If sample=False: always take from the beginning (repeating if shorter). 
+ + Args: + prompt_audio: Tensor [B, T] + prompt_audio_lens: Tensor [B] with valid lengths + n_sample: int, target length per segment + sample: bool, whether to randomly sample (True) or take first seconds (False) + + Returns: + Tensor [B, n_sample] + """ + B, T = prompt_audio.shape + device = prompt_audio.device + out = torch.zeros(B, n_sample, device=device, dtype=prompt_audio.dtype) + + for b in range(B): + length = min(prompt_audio_lens[b].item(), T) + + # Case: empty audio + if length <= 0: + continue + + if length >= n_sample: + if sample: + # Random start (safe bounds) + max_start = max(1, length - n_sample + 1) + start = torch.randint(0, max_start, (1,), device=device).item() + else: + # Deterministic: take from start + start = 0 + out[b] = prompt_audio[b, start : start + n_sample] + + else: + # Audio shorter than target → repeat + start = 0 + segment = prompt_audio[b, start:length] + + repeat_times = (n_sample + (length - start) - 1) // (length - start) + repeated = segment.repeat(repeat_times)[:n_sample] + out[b] = repeated + + return out diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 0944283be2f6..1867ad3382bb 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -735,7 +735,7 @@ def log_model_stats(self): continue # ----- weights ----- - w = p.detach().cpu().float() # ✅ safe offline copy + w = p.detach().cpu().float() # safe offline copy total_w_sq += (w * w).sum().item() total_w_params += w.numel() max_abs_w = max(max_abs_w, w.abs().max().item()) @@ -1553,13 +1553,15 @@ def configure_model(self) -> None: self.tts_model.local_transformer_in_projection = fully_shard( self.tts_model.local_transformer_in_projection, **fsdp_config ) - else: - self.embed_text_tokens = fully_shard(self.embed_text_tokens, **fsdp_config) - self.tts_model.mog_head = fully_shard(self.tts_model.mog_head, **fsdp_config) - self.tts_model.embed_subword = fully_shard(self.tts_model.embed_subword, **fsdp_config) - self.tts_model.embed_context = fully_shard(self.tts_model.embed_context, **fsdp_config) - self.tts_model.embed_code = fully_shard(self.tts_model.embed_code, **fsdp_config) - self.tts_model.lm_head = fully_shard(self.tts_model.lm_head, **fsdp_config) + + self.embed_text_tokens = fully_shard(self.embed_text_tokens, **fsdp_config) + self.tts_model.mog_head = fully_shard(self.tts_model.mog_head, **fsdp_config) + self.tts_model.embed_subword = fully_shard(self.tts_model.embed_subword, **fsdp_config) + self.tts_model.embed_context = fully_shard(self.tts_model.embed_context, **fsdp_config) + self.tts_model.embed_code = fully_shard(self.tts_model.embed_code, **fsdp_config) + self.tts_model.lm_head = fully_shard(self.tts_model.lm_head, **fsdp_config) + self.llm = fully_shard(self.llm, **fsdp_config) + self.tts_model = fully_shard(self.tts_model, **fsdp_config) def load_state_dict(self, state_dict, strict: bool = True): try: From 6b9966cf4d0818e0cab91e52cd6b59fffe90e1f0 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 17 Dec 2025 04:49:20 -0800 Subject: [PATCH 088/102] Remove duplicated code Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 432 +++++++++--------- .../speechlm2/modules/ear_tts_model.py | 107 +---- 2 files changed, 217 insertions(+), 322 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 1867ad3382bb..0e1ffa7528d1 100644 --- 
a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -55,222 +55,6 @@ from nemo.utils import logging -def load_audio_librosa(path, sr=None): - """ - Load audio using librosa with torchaudio-like behavior. - - Returns: - audio_tensor: torch.FloatTensor of shape [channels, time] - sr: sampling rate - """ - # Load with librosa (preserve original sampling rate) - audio, sr = librosa.load(path, sr=sr, mono=False) - - # Ensure shape is [channels, time] - if audio.ndim == 1: - # Mono: (time,) -> (1, time) - audio = audio[None, :] - - # Convert to torch float32 (torchaudio behavior) - audio_tensor = torch.from_numpy(audio).float() - return audio_tensor, sr - - -def maybe_to(x, dtype): - if x is None: - return None - if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.to(dtype) - return x - - -@contextmanager -def ensures_target_precision(target_dtype): - """ - Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting. - In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models. - This context manager restores default float32 precision and runs the computation in float32 autocast context. - """ - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(target_dtype) - try: - with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=target_dtype): - yield - finally: - torch.set_default_dtype(default_dtype) - - -def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int = 0, eos_token_id: int = 1): - """ - Efficient, batched speaking mask generator that marks 1 between and pairs. - If is missing after a , mask continues to end. Handles multiple turns. - - Args: - input_ids (torch.Tensor): LongTensor of shape (B, T) - bos_token_id (int): Token ID for - eos_token_id (int): Token ID for - - Returns: - torch.Tensor: FloatTensor of shape (B, T), with 1.0 for speaking, 0.0 for silence. - - Note BOS is considered as speaking (1) and EOS as non speaking 0 - """ - device = input_ids.device - bos_mask = (input_ids == bos_token_id).to(torch.int32).to(device) - eos_mask = (input_ids == eos_token_id).to(torch.int32).to(device) - bos_cumsum = torch.cumsum(bos_mask, dim=1) - eos_cumsum = torch.cumsum(eos_mask, dim=1) - speaking_mask = (bos_cumsum > eos_cumsum).to(torch.float32) - return speaking_mask.long() - - -def replace_control_speech_codes( - speech_codes: torch.Tensor, control_codes: torch.Tensor, silence_tokens: torch.Tensor = None -) -> torch.Tensor: - """ - Replaces control codes (speech BOS, EOS, etc) in `speech_codes` with the first frame which is - assumed to consist of 'valid' codes representing silence. - """ - if silence_tokens is not None: - # Expand to [B, 1, 74] - silence_tokens_expanded = silence_tokens.unsqueeze(0).unsqueeze(1).expand(speech_codes.shape[0], 1, -1) - return torch.where(torch.isin(speech_codes, control_codes), silence_tokens_expanded, speech_codes) - - if torch.isin(speech_codes[:, :1], control_codes).any(): - return torch.where( - torch.isin(speech_codes, control_codes), torch.zeros_like(speech_codes[:, :1]), speech_codes - ) - else: - return torch.where(torch.isin(speech_codes, control_codes), speech_codes[:, :1], speech_codes) - - -def ensures_codec_target_dtype(model): - """ - Ensures the audio codec is instantiated with the target dtype. 
- - This function checks whether `model.audio_codec` exists and whether its - parameters match `model.audio_codec_run_dtype`. If the codec is missing - or is running with the wrong dtype (e.g., due to PTL auto-downcasting), - the codec is reloaded by calling `setup_audio_codec()`. - - Intended to be called at runtime boundaries such as: - - `on_train_epoch_start` - - `on_validation_epoch_start` - - Args: - model: Model instance of DuplexEARTTS - - """ - if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype: - return # already correct precision → no-op - - setup_audio_codec(model) - - -def setup_audio_codec(model): - """ - Instantiates the RVQ audio codec and injects codec embeddings into the TTS model. - - This function is responsible only for: - - Instantiating the codec model (`RVQVAEModel`). - - Loading pretrained codec weights (if configured). - - Freezing codec parameters. - - Registering RVQ embeddings inside the TTS model via `set_rvq_embs`. - - Args: - model: Model instance of DuplexEARTTS - """ - with ensures_target_precision(model.audio_codec_run_dtype): - model.audio_codec = RVQVAEModel(DictConfig(model.cfg.codec_config)) - # load pretrained codec checkpoint - if model.cfg.get("pretrained_codec_model", None): - checkpoint_state = load_checkpoint(model.cfg.pretrained_codec_model) - checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, model.audio_codec.state_dict()) - model.audio_codec.load_state_dict(checkpoint_state, strict=True) - - for p in model.audio_codec.parameters(): - p.requires_grad = False - - assert callable(model.tts_model.set_rvq_embs) - - model.tts_model.set_rvq_embs(torch.stack([x.detach() for x in model.audio_codec.prvq.mus_list], 0)) - model.tts_model.rvq_embs = model.tts_model.rvq_embs.to(next(model.tts_model.parameters()).dtype) - # compute target fps - model.target_fps = model.target_sample_rate / model.audio_codec.config.wav_to_token_ratio - model.target_samples_per_frame = model.audio_codec.config.wav_to_token_ratio - - -def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_prefix="tts_model.backbone.layers."): - """ - Rescale trainable weights in a state_dict for BF16/FP16 stability. 
- - Args: - state_dict: PyTorch state_dict - target_std: desired target std for weights - first_n_layers: if not None, rescale only the first N transformer blocks - layer_prefix: prefix for layer names (default: "tts_model.backbone.layers.") - Returns: - new_state_dict - """ - weight_tensors = [] - - # Compute which prefixes to match if first_n_layers is set - prefixes_to_match = [] - if first_n_layers is not None: - prefixes_to_match = [f"{layer_prefix}{i}" for i in range(first_n_layers)] - - for name, param in state_dict.items(): - if not torch.is_tensor(param): - continue - - if "rvq_embs" in name: - continue - - # Skip biases & 1-dim params (norm weights/gates) - if param.ndim <= 1: - continue - - # Skip layers not in the first N - if first_n_layers is not None and not any(name.startswith(pfx) for pfx in prefixes_to_match): - continue - - weight_tensors.append(param.float()) - - if not weight_tensors: - if first_n_layers is not None: - logging.info(f"No weights found for first {first_n_layers} layers with prefix '{layer_prefix}'.") - else: - logging.info("No weights found to rescale in state_dict.") - return state_dict - - # Compute global std across selected weights (on CPU) - cpu_weights = [p.detach().cpu() for p in weight_tensors] - flat = torch.cat([p.flatten() for p in cpu_weights]) - current_std = float(torch.std(flat)) - scale = target_std / (current_std + 1e-8) - - logging.info( - f"Rescaling state_dict " - f"{'(first N layers)' if first_n_layers else '(all layers)'}: " - f"current std = {current_std:.6f}, target = {target_std}, scale = {scale:.6f}" - ) - - # Apply scaling - new_state_dict = {} - for name, param in state_dict.items(): - if ( - torch.is_tensor(param) - and param.ndim > 1 - and (first_n_layers is None or any(name.startswith(pfx) for pfx in prefixes_to_match)) - ): - new_state_dict[name] = param * scale - else: - new_state_dict[name] = param - - logging.info("Done: weights rescaled.") - return new_state_dict - - class DuplexEARTTS(LightningModule, HFHubMixin): def __init__(self, cfg: dict) -> None: assert isinstance(cfg, dict), ( @@ -1570,3 +1354,219 @@ def load_state_dict(self, state_dict, strict: bool = True): logging.info("Error loading model state_dict !! Retrying with partial initialization!") model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) return super().load_state_dict(model_dict, strict=False) + + +def load_audio_librosa(path, sr=None): + """ + Load audio using librosa with torchaudio-like behavior. + + Returns: + audio_tensor: torch.FloatTensor of shape [channels, time] + sr: sampling rate + """ + # Load with librosa (preserve original sampling rate) + audio, sr = librosa.load(path, sr=sr, mono=False) + + # Ensure shape is [channels, time] + if audio.ndim == 1: + # Mono: (time,) -> (1, time) + audio = audio[None, :] + + # Convert to torch float32 (torchaudio behavior) + audio_tensor = torch.from_numpy(audio).float() + return audio_tensor, sr + + +def maybe_to(x, dtype): + if x is None: + return None + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + +@contextmanager +def ensures_target_precision(target_dtype): + """ + Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting. + In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models. + This context manager restores default float32 precision and runs the computation in float32 autocast context. 
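+
+    Minimal usage sketch (mirroring how ``setup_audio_codec`` below calls it;
+    ``codec_cfg`` is a placeholder codec DictConfig):
+
+        with ensures_target_precision(torch.float32):
+            codec = RVQVAEModel(codec_cfg)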
+ """ + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(target_dtype) + try: + with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=target_dtype): + yield + finally: + torch.set_default_dtype(default_dtype) + + +def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int = 0, eos_token_id: int = 1): + """ + Efficient, batched speaking mask generator that marks 1 between and pairs. + If is missing after a , mask continues to end. Handles multiple turns. + + Args: + input_ids (torch.Tensor): LongTensor of shape (B, T) + bos_token_id (int): Token ID for + eos_token_id (int): Token ID for + + Returns: + torch.Tensor: FloatTensor of shape (B, T), with 1.0 for speaking, 0.0 for silence. + + Note BOS is considered as speaking (1) and EOS as non speaking 0 + """ + device = input_ids.device + bos_mask = (input_ids == bos_token_id).to(torch.int32).to(device) + eos_mask = (input_ids == eos_token_id).to(torch.int32).to(device) + bos_cumsum = torch.cumsum(bos_mask, dim=1) + eos_cumsum = torch.cumsum(eos_mask, dim=1) + speaking_mask = (bos_cumsum > eos_cumsum).to(torch.float32) + return speaking_mask.long() + + +def replace_control_speech_codes( + speech_codes: torch.Tensor, control_codes: torch.Tensor, silence_tokens: torch.Tensor = None +) -> torch.Tensor: + """ + Replaces control codes (speech BOS, EOS, etc) in `speech_codes` with the first frame which is + assumed to consist of 'valid' codes representing silence. + """ + if silence_tokens is not None: + # Expand to [B, 1, 74] + silence_tokens_expanded = silence_tokens.unsqueeze(0).unsqueeze(1).expand(speech_codes.shape[0], 1, -1) + return torch.where(torch.isin(speech_codes, control_codes), silence_tokens_expanded, speech_codes) + + if torch.isin(speech_codes[:, :1], control_codes).any(): + return torch.where( + torch.isin(speech_codes, control_codes), torch.zeros_like(speech_codes[:, :1]), speech_codes + ) + else: + return torch.where(torch.isin(speech_codes, control_codes), speech_codes[:, :1], speech_codes) + + +def ensures_codec_target_dtype(model): + """ + Ensures the audio codec is instantiated with the target dtype. + + This function checks whether `model.audio_codec` exists and whether its + parameters match `model.audio_codec_run_dtype`. If the codec is missing + or is running with the wrong dtype (e.g., due to PTL auto-downcasting), + the codec is reloaded by calling `setup_audio_codec()`. + + Intended to be called at runtime boundaries such as: + - `on_train_epoch_start` + - `on_validation_epoch_start` + + Args: + model: Model instance of DuplexEARTTS + + """ + if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype: + return # already correct precision → no-op + + setup_audio_codec(model) + + +def setup_audio_codec(model): + """ + Instantiates the RVQ audio codec and injects codec embeddings into the TTS model. + + This function is responsible only for: + - Instantiating the codec model (`RVQVAEModel`). + - Loading pretrained codec weights (if configured). + - Freezing codec parameters. + - Registering RVQ embeddings inside the TTS model via `set_rvq_embs`. 
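+
+    As a side effect it also sets ``model.target_fps`` and ``model.target_samples_per_frame``:
+    with the codec config shipped in this PR (wav_to_token_ratio=1764 at a 22050 Hz target
+    sample rate) that works out to 22050 / 1764 = 12.5 codec frames per second.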
+ + Args: + model: Model instance of DuplexEARTTS + """ + with ensures_target_precision(model.audio_codec_run_dtype): + model.audio_codec = RVQVAEModel(DictConfig(model.cfg.codec_config)) + # load pretrained codec checkpoint + if model.cfg.get("pretrained_codec_model", None): + checkpoint_state = load_checkpoint(model.cfg.pretrained_codec_model) + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, model.audio_codec.state_dict()) + model.audio_codec.load_state_dict(checkpoint_state, strict=True) + + for p in model.audio_codec.parameters(): + p.requires_grad = False + + assert callable(model.tts_model.set_rvq_embs) + + model.tts_model.set_rvq_embs(torch.stack([x.detach() for x in model.audio_codec.prvq.mus_list], 0)) + model.tts_model.rvq_embs = model.tts_model.rvq_embs.to(next(model.tts_model.parameters()).dtype) + # compute target fps + model.target_fps = model.target_sample_rate / model.audio_codec.config.wav_to_token_ratio + model.target_samples_per_frame = model.audio_codec.config.wav_to_token_ratio + + +def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_prefix="tts_model.backbone.layers."): + """ + Rescale trainable weights in a state_dict for BF16/FP16 stability. + + Args: + state_dict: PyTorch state_dict + target_std: desired target std for weights + first_n_layers: if not None, rescale only the first N transformer blocks + layer_prefix: prefix for layer names (default: "tts_model.backbone.layers.") + Returns: + new_state_dict + """ + weight_tensors = [] + + # Compute which prefixes to match if first_n_layers is set + prefixes_to_match = [] + if first_n_layers is not None: + prefixes_to_match = [f"{layer_prefix}{i}" for i in range(first_n_layers)] + + for name, param in state_dict.items(): + if not torch.is_tensor(param): + continue + + if "rvq_embs" in name: + continue + + # Skip biases & 1-dim params (norm weights/gates) + if param.ndim <= 1: + continue + + # Skip layers not in the first N + if first_n_layers is not None and not any(name.startswith(pfx) for pfx in prefixes_to_match): + continue + + weight_tensors.append(param.float()) + + if not weight_tensors: + if first_n_layers is not None: + logging.info(f"No weights found for first {first_n_layers} layers with prefix '{layer_prefix}'.") + else: + logging.info("No weights found to rescale in state_dict.") + return state_dict + + # Compute global std across selected weights (on CPU) + cpu_weights = [p.detach().cpu() for p in weight_tensors] + flat = torch.cat([p.flatten() for p in cpu_weights]) + current_std = float(torch.std(flat)) + scale = target_std / (current_std + 1e-8) + + logging.info( + f"Rescaling state_dict " + f"{'(first N layers)' if first_n_layers else '(all layers)'}: " + f"current std = {current_std:.6f}, target = {target_std}, scale = {scale:.6f}" + ) + + # Apply scaling + new_state_dict = {} + for name, param in state_dict.items(): + if ( + torch.is_tensor(param) + and param.ndim > 1 + and (first_n_layers is None or any(name.startswith(pfx) for pfx in prefixes_to_match)) + ): + new_state_dict[name] = param * scale + else: + new_state_dict[name] = param + + logging.info("Done: weights rescaled.") + return new_state_dict diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 1e00b9a77cff..67644dda76be 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -105,113 +105,9 @@ def forward(self, x: Tensor) -> Tensor: TRITON_IMPORTED 
= False USE_TRITON = TRITON_IMPORTED and torch.cuda.is_available() -logging.info("Triton available & CUDA detected. Using Triton kernel for batch_matmul.") if USE_TRITON: - - @triton.jit - def batch_matmul_kernel( - x_ptr, # Pointer to input tensor x: [batch_size, d_in] - w_ptr, # Pointer to weight tensor w: [num_weights, d_out, d_in] - y_ptr, # Pointer to index tensor y: [batch_size] - result_ptr, # Pointer to output tensor result: [batch_size, d_out] - b, - d_in, - d_out, - n, # Dimensions - BLOCK_SIZE_DIN: tl.constexpr, - BLOCK_SIZE_DOUT: tl.constexpr, - ): - """ - Triton kernel for performing a batched matrix multiplication where each row - of the input `x` is multiplied by a different weight matrix selected from `w` - by an index in `y`. - """ - # Get the program IDs for the batch and output dimensions - batch_id = tl.program_id(axis=0) - dout_block_id = tl.program_id(axis=1) - - # Early exit for out-of-bounds batch IDs - if batch_id >= b: - return - - # Load the index for the current batch item - idx = tl.load(y_ptr + batch_id) - - # Compute base offsets for the current batch item - x_offset = x_ptr + batch_id * d_in - w_offset = w_ptr + idx * d_out * d_in - - # Define the block of output dimensions to compute - dout_offsets = dout_block_id * BLOCK_SIZE_DOUT + tl.arange(0, BLOCK_SIZE_DOUT) - dout_mask = dout_offsets < d_out - - # Initialize accumulator for the result block - result_block = tl.zeros([BLOCK_SIZE_DOUT], dtype=tl.float32) - - # Loop over the input dimension in blocks - for din_start in range(0, d_in, BLOCK_SIZE_DIN): - din_offsets = din_start + tl.arange(0, BLOCK_SIZE_DIN) - din_mask = din_offsets < d_in - - # Load a block of the input vector x - x_i = tl.load(x_offset + din_offsets, mask=din_mask, other=0.0) - - # Load a block of the selected weight matrix w - w_i_block = tl.load( - w_offset + dout_offsets[:, None] * d_in + din_offsets[None, :], - mask=(dout_mask[:, None] & din_mask[None, :]), - other=0.0, - ) - - # Compute the partial dot product and accumulate - partial = tl.sum(w_i_block * x_i[None, :], axis=1) - result_block += partial - - # Store the final result block - result_offset = result_ptr + batch_id * d_out + dout_offsets - tl.store(result_offset, result_block, mask=dout_mask) - - def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int = 64): - """Wrapper function to launch the Triton kernel for batch_matmul.""" - assert x.is_contiguous() and w.is_contiguous() and y.is_contiguous() - assert math.log2(BLOCK_SIZE_DIN).is_integer() and math.log2(BLOCK_SIZE_DOUT).is_integer() - - b, d_in = x.shape - n, d_out, _ = w.shape - result = torch.empty(b, d_out, device=x.device, dtype=torch.float32) - - batch_matmul_kernel[lambda meta: (b, triton.cdiv(d_out, meta["BLOCK_SIZE_DOUT"]))]( - x.float(), - w.float(), - y, - result, - b, - d_in, - d_out, - n, - BLOCK_SIZE_DIN=BLOCK_SIZE_DIN, - BLOCK_SIZE_DOUT=BLOCK_SIZE_DOUT, - ) - return result.to(dtype=x.dtype) - - # Set batch_matmul to the optimized Triton version - batch_matmul = batch_matmul_triton - logging.info("Triton is available. Using optimized Triton kernel for batch_matmul.") - - -try: - import triton - import triton.language as tl - - TRITON_IMPORTED = True -except ImportError: - TRITON_IMPORTED = False - -USE_TRITON = TRITON_IMPORTED and torch.cuda.is_available() - -if USE_TRITON: - logging.info("Triton+CUDA detected. Using Triton kernel for batch_matmul.") + logging.info("Triton available & CUDA detected. 
Using Triton kernel for batch_matmul.") @triton.jit def batch_matmul_kernel( @@ -587,7 +483,6 @@ def depthsum_encoding_step( emb_i = F.embedding(idx_sel, embs[i]) r = r - emb_i - # FIX: assign correctly without shape mismatch code[..., i] = idx_sel return code From 5efbadfeaf21e487b6eef35235911e47bf98e5cf Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 17 Dec 2025 05:07:26 -0800 Subject: [PATCH 089/102] Rename input_text_tokens to target_text_tokens Signed-off-by: Edresson Casanova --- .../speechlm2/data/duplex_ear_tts_dataset.py | 42 ++++++------- .../speechlm2/models/duplex_ear_tts.py | 62 +++++++++---------- .../speechlm2/test_duplex_eartts.py | 6 +- 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py index 59db2d6b2d15..530d2df27563 100644 --- a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -94,7 +94,7 @@ class DuplexEARTTSDataset(torch.utils.data.Dataset): - target_audio: Tensor of target waveform samples [B, T] - target_audio_lens: Tensor of target audio lengths [B] - - input_text_tokens: Tensor of frame-aligned input text tokens [B, T], + - target_text_tokens: Tensor of frame-aligned input text tokens [B, T], including BOS/EOS/PAD when enabled - target_token_lens: Tensor of target token sequence lengths [B] @@ -159,7 +159,7 @@ def __getitem__(self, cuts: CutSet) -> dict: target_audio, target_audio_lens = collate_audio( cuts.resample(self.target_sample_rate), recording_field="target_audio" ) - input_text_tokens, target_token_lens = collate_token_channel( + target_text_tokens, target_token_lens = collate_token_channel( cuts, self.tokenizer, self.frame_length, @@ -192,7 +192,7 @@ def __getitem__(self, cuts: CutSet) -> dict: # add audio prompt if needed ( - input_text_tokens, + target_text_tokens, target_token_lens, source_tokens, source_token_lens, @@ -202,7 +202,7 @@ def __getitem__(self, cuts: CutSet) -> dict: target_audio_lens, prompt_lens, ) = self.maybe_add_audio_prompt( - input_text_tokens, + target_text_tokens, target_token_lens, source_tokens, source_token_lens, @@ -260,7 +260,7 @@ def __getitem__(self, cuts: CutSet) -> dict: "source_audio_lens": source_audio_lens, "target_audio": target_audio, "target_audio_lens": target_audio_lens, - "input_text_tokens": input_text_tokens, + "target_text_tokens": target_text_tokens, "target_token_lens": target_token_lens, "source_tokens": source_tokens, "source_token_lens": source_token_lens, @@ -274,7 +274,7 @@ def __getitem__(self, cuts: CutSet) -> dict: def maybe_add_audio_prompt( self, - input_text_tokens: torch.Tensor, + target_text_tokens: torch.Tensor, target_token_lens: torch.Tensor, source_tokens: torch.Tensor, source_token_lens: torch.Tensor, @@ -293,12 +293,12 @@ def maybe_add_audio_prompt( padding is inserted into the text-token streams (input text tokens and source tokens). Args: - input_text_tokens (torch.Tensor): + target_text_tokens (torch.Tensor): Tensor of input text tokens with shape [B, T_text]. dtype: torch.long. target_token_lens (torch.Tensor): - Lengths of input_text_tokens per batch element (before padding). shape [B]. + Lengths of target_text_tokens per batch element (before padding). shape [B]. source_tokens (torch.Tensor): Source-side text tokens, shape [B, T_src_text], dtype torch.long. 
@@ -326,7 +326,7 @@ def maybe_add_audio_prompt( Returns: Tuple containing: - input_text_tokens (torch.Tensor): + target_text_tokens (torch.Tensor): Updated text tokens with prepended prompt-aligned tokens. Shape [B, T']. target_token_lens (torch.Tensor): @@ -356,17 +356,17 @@ def maybe_add_audio_prompt( text_pad_id = get_pad_id(self.tokenizer) - input_text_tokens_ = [] + target_text_tokens_ = [] source_tokens_ = [] source_audio_ = [] target_audio_ = [] prompt_lens = [] - for i in range(input_text_tokens.size(0)): + for i in range(target_text_tokens.size(0)): first_text_frame = torch.tensor( [self.tokenizer.eos], dtype=torch.long, - device=input_text_tokens.device, + device=target_text_tokens.device, ) if self.add_audio_prompt: @@ -389,21 +389,21 @@ def maybe_add_audio_prompt( prompt_audio_text_pad = ( torch.ones( prompt_audio_text_pad_size, - device=input_text_tokens.device, - dtype=input_text_tokens.dtype, + device=target_text_tokens.device, + dtype=target_text_tokens.dtype, ) * text_pad_id ) prompt_audio_text_pad[-1] = self.tokenizer.eos - new_input_text_tokens = torch.cat( + new_target_text_tokens = torch.cat( [ - first_text_frame.to(input_text_tokens.dtype), + first_text_frame.to(target_text_tokens.dtype), prompt_audio_text_pad, - input_text_tokens[i], + target_text_tokens[i], ] ) - input_text_tokens_.append(new_input_text_tokens) + target_text_tokens_.append(new_target_text_tokens) target_token_lens[i] += len(first_text_frame) + prompt_audio_text_pad_size new_source_tokens = torch.cat([first_text_frame, prompt_audio_text_pad, source_tokens[i]]) @@ -434,7 +434,7 @@ def maybe_add_audio_prompt( else: # Add only a single text-frame (EOS) as prompt - input_text_tokens_.append(torch.cat([first_text_frame, input_text_tokens[i]])) + target_text_tokens_.append(torch.cat([first_text_frame, target_text_tokens[i]])) target_token_lens[i] += len(first_text_frame) source_tokens_.append(torch.cat([first_text_frame, source_tokens[i]])) @@ -460,13 +460,13 @@ def maybe_add_audio_prompt( prompt_lens.append(len(first_text_frame)) - input_text_tokens = collate_vectors(input_text_tokens_, padding_value=text_pad_id) + target_text_tokens = collate_vectors(target_text_tokens_, padding_value=text_pad_id) source_tokens = collate_vectors(source_tokens_, padding_value=text_pad_id) source_audio = collate_vectors(source_audio_, padding_value=0) target_audio = collate_vectors(target_audio_, padding_value=0) return ( - input_text_tokens, + target_text_tokens, target_token_lens, source_tokens, source_token_lens, diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 0e1ffa7528d1..0520b59557ba 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -296,7 +296,7 @@ def prepare_inputs(self, batch: dict): target_audio = batch["target_audio"] target_audio_lens = batch["target_audio_lens"] - input_text_tokens = batch["input_text_tokens"] + target_text_tokens = batch["target_text_tokens"] non_prompt_mask = batch["non_prompt_mask"] aligned_attention_mask = batch["aligned_attention_mask"] aligned_position_ids = batch["aligned_position_ids"] @@ -309,12 +309,12 @@ def prepare_inputs(self, batch: dict): # Create mask for tokens we want to drop # Keep BOS and EOS, drop the rest. 
- keep_mask = (input_text_tokens == self.text_bos_id) | (input_text_tokens == self.text_eos_id) + keep_mask = (target_text_tokens == self.text_bos_id) | (target_text_tokens == self.text_eos_id) full_dropout_mask = ~keep_mask # True = positions to replace with PAD # Replace all non-BOS/EOS with PAD - input_text_tokens = torch.where( - full_dropout_mask, torch.full_like(input_text_tokens, self.text_pad_id), input_text_tokens + target_text_tokens = torch.where( + full_dropout_mask, torch.full_like(target_text_tokens, self.text_pad_id), target_text_tokens ) # extract target audio codes @@ -337,7 +337,7 @@ def pad_or_truncate(x, pad_value=0): return x[:, :target_len] return x # leave others for now - input_text_tokens = pad_or_truncate(input_text_tokens, pad_value=self.text_pad_id) + target_text_tokens = pad_or_truncate(target_text_tokens, pad_value=self.text_pad_id) non_prompt_mask = pad_or_truncate(non_prompt_mask, pad_value=0) aligned_position_ids = pad_or_truncate(aligned_position_ids, pad_value=0) @@ -364,25 +364,25 @@ def pad_or_truncate(x, pad_value=0): # EOS dropout to make the model more robust if self.training and self.cfg.get("text_eos_dropout_prob", 0.0) > 0: # Mask EOS positions - eos_mask = input_text_tokens == self.text_eos_id + eos_mask = target_text_tokens == self.text_eos_id # Random dropout only on EOS positions - dropout_mask = torch.rand(eos_mask.sum(), device=input_text_tokens.device) < self.cfg.text_eos_dropout_prob + dropout_mask = torch.rand(eos_mask.sum(), device=target_text_tokens.device) < self.cfg.text_eos_dropout_prob # Scatter dropout decisions into [B, T] - full_dropout_mask = torch.zeros_like(input_text_tokens, dtype=torch.bool) + full_dropout_mask = torch.zeros_like(target_text_tokens, dtype=torch.bool) full_dropout_mask[eos_mask] = dropout_mask # Replace dropped EOS with PAD - input_text_tokens = torch.where( - full_dropout_mask, torch.full_like(input_text_tokens, self.text_pad_id), input_text_tokens + target_text_tokens = torch.where( + full_dropout_mask, torch.full_like(target_text_tokens, self.text_pad_id), target_text_tokens ) if self.training and self.cfg.get("text_eos_duplicate_prob", 0.0) > 0: p = self.cfg.text_eos_duplicate_prob # [B, T] mask of EOS positions - eos_mask = input_text_tokens == self.text_eos_id + eos_mask = target_text_tokens == self.text_eos_id # Flatten EOS positions: tensor of shape [N, 2] where each row = (batch_idx, time_idx) eos_positions = eos_mask.nonzero(as_tuple=False) # [N, 2] @@ -391,7 +391,7 @@ def pad_or_truncate(x, pad_value=0): N = eos_positions.shape[0] # One random decision per EOS occurrence - duplicate_decision = torch.rand(N, device=input_text_tokens.device) < p # [N] + duplicate_decision = torch.rand(N, device=target_text_tokens.device) < p # [N] # Filter only EOS tokens that will be duplicated and are not at position t=0 valid = (eos_positions[:, 1] > 0) & duplicate_decision # [N] @@ -405,42 +405,42 @@ def pad_or_truncate(x, pad_value=0): t_idx = valid_positions[:, 1] - 1 # Replace token before EOS with an EOS - input_text_tokens[b_idx, t_idx] = self.text_eos_id + target_text_tokens[b_idx, t_idx] = self.text_eos_id # BOS dropout to make the model more robust if self.training and self.cfg.get("text_bos_dropout_prob", 0.0) > 0: # Mask BOS positions - bos_mask = input_text_tokens == self.text_bos_id + bos_mask = target_text_tokens == self.text_bos_id # Random dropout only on BOS positions - dropout_mask = torch.rand(bos_mask.sum(), device=input_text_tokens.device) < self.cfg.text_bos_dropout_prob + dropout_mask = 
torch.rand(bos_mask.sum(), device=target_text_tokens.device) < self.cfg.text_bos_dropout_prob # Scatter dropout decisions into [B, T] - full_dropout_mask = torch.zeros_like(input_text_tokens, dtype=torch.bool) + full_dropout_mask = torch.zeros_like(target_text_tokens, dtype=torch.bool) full_dropout_mask[bos_mask] = dropout_mask # Replace dropped BOS with PAD - input_text_tokens = torch.where( + target_text_tokens = torch.where( full_dropout_mask, - torch.full_like(input_text_tokens, self.text_pad_id), - input_text_tokens, + torch.full_like(target_text_tokens, self.text_pad_id), + target_text_tokens, ) # shift text tokens - subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1]) + subword_ids = F.pad(target_text_tokens[:, 1:], [0, 1]) # note that we are using a text mask where we are ignoring the desc + audio prompt but we are keeping 1 until the audio ends to support duplex subword_mask = F.pad(non_prompt_mask[:, 1:], [0, 1]) # detach embedding as in eartts if self.cfg.tts_config.context_hidden_size is not None: - context_hidden_state = self.embed_tokens(input_text_tokens).detach() + context_hidden_state = self.embed_tokens(target_text_tokens).detach() else: context_hidden_state = None if self._use_tp: tp_world_size = self.device_mesh["tensor_parallel"].size() - if (remainder := (input_text_tokens.shape[1] - 1) % tp_world_size) != 0: - input_text_tokens = input_text_tokens[:, :-remainder] + if (remainder := (target_text_tokens.shape[1] - 1) % tp_world_size) != 0: + target_text_tokens = target_text_tokens[:, :-remainder] target_codes_aligned = target_codes_aligned[:, :-remainder] target_codes_aligned = target_codes_aligned[:, :-remainder] subword_ids = subword_ids[:, :-remainder] @@ -456,7 +456,7 @@ def pad_or_truncate(x, pad_value=0): "context_hidden_state": context_hidden_state, "output_lens": target_codes_lens, "non_prompt_mask": non_prompt_mask, - "input_text_tokens": input_text_tokens, + "target_text_tokens": target_text_tokens, } def training_step(self, batch: dict, batch_idx: int): @@ -644,7 +644,7 @@ def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=Fals ) init_inputs = self.get_init_inputs(B=inputs["subword_ids"].size(0)) - # remove the prompt from the input_text_tokens to emulate S2S connected inference + # remove the prompt from the target_text_tokens to emulate S2S connected inference next_subword_ids = torch.stack( [ inputs["subword_ids"][i, plen:] # slice each element @@ -862,7 +862,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, # Prepend an initial text EOS token followed by padding tokens that match # the number of audio-prompt frames (in text-token units). 
- input_text_tokens = torch.cat([first_text_frame, prompt_audio_text_pad.to(first_text_frame.dtype)]) + target_text_tokens = torch.cat([first_text_frame, prompt_audio_text_pad.to(first_text_frame.dtype)]) # create pad audio for the description pad_size = first_text_frame.size(-1) * self.target_samples_per_frame @@ -873,7 +873,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, ) # repeat to reaches the batch size - input_text_tokens = input_text_tokens.unsqueeze(0).repeat(prompt_audio.size(0), 1) + target_text_tokens = target_text_tokens.unsqueeze(0).repeat(prompt_audio.size(0), 1) target_audio = torch.cat([pad_audio, prompt_audio], dim=1) # extract code codes @@ -885,16 +885,16 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, # get context hidden if self.cfg.tts_config.context_hidden_size is not None: - context_hidden_state = self.embed_tokens(input_text_tokens) + context_hidden_state = self.embed_tokens(target_text_tokens) else: context_hidden_state = None # create masks # non_prompt_mask is all zeros, because all processed is prompt - non_prompt_mask = torch.zeros_like(input_text_tokens) + non_prompt_mask = torch.zeros_like(target_text_tokens) non_prompt_mask[:, -2:] = 1 # set last valid prompt frame as 1 to allow the addition of BOS in the right place subword_mask = torch.zeros_like( - input_text_tokens + target_text_tokens ) # subword_mask is almost all zeros because on the warmup there is only the prompt subword_mask[:, -3:] = ( 1 # -3 because of the it start right after the first valid prompt token and it is shifted by 1 @@ -904,7 +904,7 @@ def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, code[:, 0] = self.speech_pad_id # shift subword_ids - subword_ids = F.pad(input_text_tokens[:, 1:], [0, 1], value=0.0) + subword_ids = F.pad(target_text_tokens[:, 1:], [0, 1], value=0.0) # set special token in the last audio prompt (it will works as a BOS token) pos = non_prompt_mask.float().argmax(dim=1) # shape: [B] diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 8381ce89de40..e27c450f3e33 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -259,7 +259,7 @@ def test_eartts_dataset(dataset, training_cutset_batch): "source_audio_lens", "target_audio", "target_audio_lens", - "input_text_tokens", + "target_text_tokens", "target_token_lens", "source_tokens", "source_token_lens", @@ -280,7 +280,7 @@ def test_eartts_dataset(dataset, training_cutset_batch): "source_audio_lens", "target_audio", "target_audio_lens", - "input_text_tokens", + "target_text_tokens", "target_token_lens", "source_tokens", "source_token_lens", @@ -352,7 +352,7 @@ def test_eartts_dataset(dataset, training_cutset_batch): ] ] - assert batch["input_text_tokens"].tolist() == [ + assert batch["target_text_tokens"].tolist() == [ [ 2, 12, From d316cc95e55451cd41886557896fd9b6251cd66b Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 17 Dec 2025 06:42:55 -0800 Subject: [PATCH 090/102] Move the bos/eos/pad token definition to AutoTokenizer and reuse the tokenizer instance Signed-off-by: Edresson Casanova --- .../speechlm2/models/duplex_ear_tts.py | 16 ++--- .../speechlm2/modules/ear_tts_model.py | 64 +++++++++---------- .../speechlm2/test_duplex_eartts.py | 1 - 3 files changed, 35 insertions(+), 46 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py 
b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 0520b59557ba..3faae6cc36b6 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -87,8 +87,13 @@ def __init__(self, cfg: dict) -> None: # get codec run precision self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "float32"), torch.float32) + # Load tokenizer + self.tokenizer = AutoTokenizer( + self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True, bos_token=self.cfg.get("bos_token", ''), eos_token=self.cfg.get("eos_token", ''), pad_token=self.cfg.get("pad_token", '') + ) # Note that we are using fast tokenizer + # Instantiate TTS model - self.tts_model = RVQEARTTSModel(DictConfig(self.cfg.tts_config)) + self.tts_model = RVQEARTTSModel(DictConfig(self.cfg.tts_config), tokenizer=self.tokenizer) # Load and initialize audio codec, and bind RVQ embeddings to the TTS model setup_audio_codec(self) @@ -101,15 +106,6 @@ def __init__(self, cfg: dict) -> None: codec_silence_tokens = self.get_codec_silence_frame() self.register_buffer("codec_silence_tokens", codec_silence_tokens) - # Load tokenizer - self.tokenizer = AutoTokenizer( - self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True - ) # Note that we are using fast tokenizer - - # set tokenizer special tokens - self.tokenizer.bos_token = self.cfg.get("bos_token", '') - self.tokenizer.eos_token = self.cfg.get("eos_token", '') - self.tokenizer.pad_token = self.cfg.get("pad_token", '') # cached for quicker audio decoding self.register_buffer( diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 67644dda76be..6fd62d423f1e 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -23,9 +23,10 @@ from omegaconf import DictConfig, OmegaConf from torch import Tensor, nn from torch.nn import functional as F -from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, AutoTokenizer, Cache +from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, Cache from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper +from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.utils import logging @@ -394,7 +395,7 @@ def find_and_delete_module(parent_module: nn.Module, target_module: nn.Module, p def build_vocabs( - pretrained_tokenizer_name: str, vocab_dir: str | None = None + tokenizer: AutoTokenizer, vocab_dir: str | None = None ) -> tuple[dict[int, tuple[int, ...]], dict[str, int], int]: """ Builds or loads a character-level vocabulary derived from a subword tokenizer. @@ -407,7 +408,7 @@ def build_vocabs( loaded. Otherwise, it's created from the pretrained tokenizer and saved. Args: - pretrained_tokenizer_name (str): The name or path of the pretrained Hugging Face tokenizer. + tokenizer (AutoTokenizer): The pretrained Hugging Face tokenizer class. vocab_dir (str | None, optional): The directory to save or load the character vocabulary from. Defaults to None. @@ -417,11 +418,10 @@ def build_vocabs( - The character-to-ID vocabulary dictionary. - The ID for the subword padding token. 
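+
+    Note: only the character vocabulary itself is cached in ``vocab_dir``; the
+    subword-id-to-character-ids mapping is rebuilt from the tokenizer vocabulary
+    on every call.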
""" - tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer_name, trust_remote_code=True) def _build_char_vocab() -> dict[str, int]: # Find all single-character tokens in the original tokenizer's vocabulary - single_chars = {subword: subword_id for subword, subword_id in tokenizer.vocab.items() if len(subword) == 1} + single_chars = {subword: subword_id for subword, subword_id in tokenizer.tokenizer.vocab.items() if len(subword) == 1} # Create a new, dense character vocabulary sorted by the original token ID sorted_chars = sorted(single_chars.keys(), key=lambda k: single_chars[k]) char_vocab = {char: i for i, char in enumerate(sorted_chars)} @@ -447,13 +447,13 @@ def _build_char_vocab() -> dict[str, int]: char_vocab = json.load(f) else: # No cache directory provided, build in memory. - logging.info(f"Building character vocabulary from tokenizer '{pretrained_tokenizer_name}'.") + logging.info("Building character vocabulary from tokenizer.") char_vocab = _build_char_vocab() # 2. Reconstruct the subword-to-character mapping on the fly subword_id_to_char_ids = { subword_id: tuple(char_vocab[char] for char in subword if char in char_vocab) - for subword, subword_id in tokenizer.vocab.items() + for subword, subword_id in tokenizer.tokenizer.vocab.items() } # Filter out subwords that contain characters not in our character vocabulary subword_id_to_char_ids = {k: v for k, v in subword_id_to_char_ids.items() if v} @@ -670,13 +670,10 @@ class NeMoSubwordFlagEmbedding(nn.Module): Compatible with NeMo AutoTokenizer. """ - def __init__(self, model_name: str, d_model: int): + def __init__(self, tokenizer: AutoTokenizer, d_model: int): super().__init__() - # Load tokenizer from NeMo - # self.tokenizer_hf = AutoTokenizer.from_pretrained(model_name) - from nemo.collections.common.tokenizers import AutoTokenizer as NeMoAutoTokenizer - self.tokenizer = NeMoAutoTokenizer(model_name, use_fast=True, trust_remote_code=True) + self.tokenizer = tokenizer self.vocab_size = self.tokenizer.vocab_size self.d_model = d_model @@ -713,9 +710,10 @@ class SubwordFlagEmbedding(nn.Module): Ignores special tokens (starting with '<') when computing continuation flags. """ - def __init__(self, model_name: str, d_model: int): + def __init__(self, tokenizer: AutoTokenizer, d_model: int): super().__init__() - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + + self.tokenizer = tokenizer self.vocab_size = self.tokenizer.vocab_size self.d_model = d_model @@ -725,7 +723,7 @@ def __init__(self, model_name: str, d_model: int): self.register_buffer("pad_tensor", torch.tensor(self.pad_id, dtype=torch.long)) # Precompute continuation flags - tokens = [self.tokenizer.convert_ids_to_tokens(i) for i in range(self.vocab_size)] + tokens = [self.tokenizer.ids_to_tokens(i) for i in range(self.vocab_size)] cont_flags = [ 1 if not (tok.startswith("Ġ") or tok.startswith("▁") or tok.startswith("<")) else 0 for tok in tokens ] @@ -755,11 +753,12 @@ class BOSEOSEmbedding(nn.Module): Compatible with Hugging Face tokenizers that may or may not have BOS/EOS. 
""" - def __init__(self, model_name: str, d_model: int): + def __init__(self, tokenizer: AutoTokenizer, d_model: int): super().__init__() - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + + self.tokenizer = tokenizer # vocab size that includes special tokens - vocab_dict = self.tokenizer.get_vocab() + vocab_dict = self.tokenizer.tokenizer.get_vocab() self.vocab_size = max(vocab_dict.values()) self.d_model = d_model @@ -768,14 +767,7 @@ def __init__(self, model_name: str, d_model: int): self.register_buffer("pad_tensor", torch.tensor(self.pad_id, dtype=torch.long)) # Identify BOS and EOS tokens (may be None) - tokens = [self.tokenizer.convert_ids_to_tokens(i) for i in range(self.vocab_size)] - - if 'Qwen2.5' in model_name: - # For Qwen, '<|im_start|>' is a common choice for a BOS token. - # You can check your tokenizer's vocabulary for the best candidate. - logging.warning("Tokenizer does not have a `bos_token`. Setting it to '<|im_start|>'.") - self.tokenizer.bos_token = '<|im_start|>' - self.tokenizer.eos_token = '<|im_end|>' + tokens = [self.tokenizer.ids_to_tokens(i) for i in range(self.vocab_size)] special_flags = [] for tok in tokens: @@ -812,12 +804,12 @@ class SubwordEmbedding(nn.Module): No special handling for OOVs or padding — assumes token_ids are valid. """ - def __init__(self, model_name: str, d_model: int): + def __init__(self, tokenizer: AutoTokenizer, d_model: int): super().__init__() - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = tokenizer # Get vocab size from tokenizer - vocab_dict = self.tokenizer.get_vocab() + vocab_dict = self.tokenizer.tokenizer.get_vocab() self.vocab_size = max(vocab_dict.values()) + 1 # +1 for safety self.d_model = d_model @@ -848,7 +840,7 @@ class CharAwareSubwordEncoder(nn.Module): Args: out_size (int): The dimensionality of the output embedding vectors. - pretrained_tokenizer_name (str): The name of the base Hugging Face tokenizer. + tokenizer (AutoTokenizer): The Hugging Face tokenizer class. vocab_dir (str | None): Directory to save/load the character vocabulary. backbone_type (str | None): The type of backbone model from Hugging Face (e.g., "t5gemma"). backbone_model_class (str | None): The class name of the backbone model if not using AutoModel. @@ -859,7 +851,7 @@ class CharAwareSubwordEncoder(nn.Module): def __init__( self, out_size: int, - pretrained_tokenizer_name: str, + tokenizer: AutoTokenizer, vocab_dir: str | None = None, backbone_type: str | None = "t5gemma", backbone_model_class: str | None = None, @@ -868,12 +860,13 @@ def __init__( use_subword_flag_emb: bool = True, use_bos_eos_emb: bool = True, use_cumulative_word_emb: bool = False, + ): super().__init__() # 1. 
Build or load the character vocabulary self.subword_id_to_char_ids, self.char_vocab, self.subword_padding_idx = build_vocabs( - pretrained_tokenizer_name, + tokenizer, vocab_dir, ) @@ -902,10 +895,10 @@ def __init__( self.proj_embedding = nn.Linear(self.hidden_size, out_size, bias=False) if self.use_subword_flag_emb: - self.subword_flag_emb = SubwordFlagEmbedding(pretrained_tokenizer_name, self.hidden_size) + self.subword_flag_emb = SubwordFlagEmbedding(tokenizer, self.hidden_size) if self.use_bos_eos_emb: - self.bos_eos_emb = BOSEOSEmbedding(pretrained_tokenizer_name, self.hidden_size) + self.bos_eos_emb = BOSEOSEmbedding(tokenizer, self.hidden_size) def prepare_inputs(self, subword_ids: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]: """ @@ -1050,7 +1043,7 @@ class RVQEARTTSModel(nn.Module): config_class = DictConfig rvq_embs: Tensor - def __init__(self, config: DictConfig | dict[str, Any]): + def __init__(self, config: DictConfig | dict[str, Any], tokenizer: AutoTokenizer = None): super().__init__() self.config = config @@ -1105,6 +1098,7 @@ def __init__(self, config: DictConfig | dict[str, Any]): self.embed_subword = ( CharAwareSubwordEncoder( + tokenizer=tokenizer, out_size=self.hidden_size, use_subword_flag_emb=self.config.use_subword_flag_emb, use_bos_eos_emb=self.config.use_bos_eos_emb, diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index e27c450f3e33..03098e842741 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -96,7 +96,6 @@ "num_quantizers": 31, "context_hidden_size": None, "cas_config": { - "pretrained_tokenizer_name": "nvidia/NVIDIA-Nemotron-Nano-9B-v2", "backbone_type": "t5gemma", "backbone_model_class": None, "backbone_config_class": None, From b14ef663a4d272ab85a19cf3c33b1ee359face3c Mon Sep 17 00:00:00 2001 From: chtruong814 Date: Wed, 17 Dec 2025 14:48:44 +0000 Subject: [PATCH 091/102] Apply isort and black reformatting Signed-off-by: chtruong814 --- .../speechlm2/models/duplex_ear_tts.py | 16 ++++++++++++---- .../speechlm2/modules/ear_tts_model.py | 7 ++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 3faae6cc36b6..5f660953baec 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -89,7 +89,12 @@ def __init__(self, cfg: dict) -> None: # Load tokenizer self.tokenizer = AutoTokenizer( - self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True, bos_token=self.cfg.get("bos_token", ''), eos_token=self.cfg.get("eos_token", ''), pad_token=self.cfg.get("pad_token", '') + self.cfg.pretrained_lm_name, + use_fast=True, + trust_remote_code=True, + bos_token=self.cfg.get("bos_token", ''), + eos_token=self.cfg.get("eos_token", ''), + pad_token=self.cfg.get("pad_token", ''), ) # Note that we are using fast tokenizer # Instantiate TTS model @@ -106,7 +111,6 @@ def __init__(self, cfg: dict) -> None: codec_silence_tokens = self.get_codec_silence_frame() self.register_buffer("codec_silence_tokens", codec_silence_tokens) - # cached for quicker audio decoding self.register_buffer( "_control_codes", @@ -363,7 +367,9 @@ def pad_or_truncate(x, pad_value=0): eos_mask = target_text_tokens == self.text_eos_id # Random dropout only on EOS positions - dropout_mask = torch.rand(eos_mask.sum(), device=target_text_tokens.device) < 
self.cfg.text_eos_dropout_prob + dropout_mask = ( + torch.rand(eos_mask.sum(), device=target_text_tokens.device) < self.cfg.text_eos_dropout_prob + ) # Scatter dropout decisions into [B, T] full_dropout_mask = torch.zeros_like(target_text_tokens, dtype=torch.bool) @@ -409,7 +415,9 @@ def pad_or_truncate(x, pad_value=0): bos_mask = target_text_tokens == self.text_bos_id # Random dropout only on BOS positions - dropout_mask = torch.rand(bos_mask.sum(), device=target_text_tokens.device) < self.cfg.text_bos_dropout_prob + dropout_mask = ( + torch.rand(bos_mask.sum(), device=target_text_tokens.device) < self.cfg.text_bos_dropout_prob + ) # Scatter dropout decisions into [B, T] full_dropout_mask = torch.zeros_like(target_text_tokens, dtype=torch.bool) diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py index 6fd62d423f1e..f38e77fcb832 100644 --- a/nemo/collections/speechlm2/modules/ear_tts_model.py +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -421,7 +421,9 @@ def build_vocabs( def _build_char_vocab() -> dict[str, int]: # Find all single-character tokens in the original tokenizer's vocabulary - single_chars = {subword: subword_id for subword, subword_id in tokenizer.tokenizer.vocab.items() if len(subword) == 1} + single_chars = { + subword: subword_id for subword, subword_id in tokenizer.tokenizer.vocab.items() if len(subword) == 1 + } # Create a new, dense character vocabulary sorted by the original token ID sorted_chars = sorted(single_chars.keys(), key=lambda k: single_chars[k]) char_vocab = {char: i for i, char in enumerate(sorted_chars)} @@ -712,7 +714,7 @@ class SubwordFlagEmbedding(nn.Module): def __init__(self, tokenizer: AutoTokenizer, d_model: int): super().__init__() - + self.tokenizer = tokenizer self.vocab_size = self.tokenizer.vocab_size self.d_model = d_model @@ -860,7 +862,6 @@ def __init__( use_subword_flag_emb: bool = True, use_bos_eos_emb: bool = True, use_cumulative_word_emb: bool = False, - ): super().__init__() From 8aa65efa05c38f72c355d12159b28f51542784de Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 17 Dec 2025 08:48:11 -0800 Subject: [PATCH 092/102] Add docs and config for duplex EARTTS evaluation and set bos eos pad custom tokens to None Signed-off-by: Edresson Casanova --- docs/source/speechlm2/intro.rst | 2 +- examples/speechlm2/conf/duplex_eartts.yaml | 200 ++++++++++++++++++ examples/speechlm2/duplex_eartts_eval.py | 34 ++- examples/speechlm2/duplex_eartts_train.py | 2 +- .../speechlm2/models/duplex_ear_tts.py | 6 +- 5 files changed, 238 insertions(+), 6 deletions(-) create mode 100644 examples/speechlm2/conf/duplex_eartts.yaml diff --git a/docs/source/speechlm2/intro.rst b/docs/source/speechlm2/intro.rst index bb0f898ebefd..921db362d871 100644 --- a/docs/source/speechlm2/intro.rst +++ b/docs/source/speechlm2/intro.rst @@ -100,7 +100,7 @@ You can run inference using the loaded pretrained DuplexS2SModel: # Prepare audio for model audio_signal = audio_signal.to(model.device) audio_len = torch.tensor([audio_signal.shape[1]], device=model.device) - + # Run offline inference results = model.offline_inference( input_signal=audio_signal, diff --git a/examples/speechlm2/conf/duplex_eartts.yaml b/examples/speechlm2/conf/duplex_eartts.yaml new file mode 100644 index 000000000000..a2765fa36364 --- /dev/null +++ b/examples/speechlm2/conf/duplex_eartts.yaml @@ -0,0 +1,200 @@ +model: + pretrained_lm_name: "nvidia/NVIDIA-Nemotron-Nano-9B-v2" + pretrained_audio_codec: ??? 
# to be released + pretrained_tts_model: null + scoring_asr: stt_en_fastconformer_transducer_large # used only in validation/evaluation + + # Regexp (re.compile) patterns matching parameters to be frozen. + freeze_params: + - "^audio_codec\\..+$" # Keep audio codec frozen as it only provides supervision for training. + - "^embed_tokens\\..+$" # Keep embed_tokens frozen as done in eartts + + prevent_freeze_params: [] # Use to make specific submodules trainable; overrides freeze_params + + # set custom text eos/bos/pad tokens + bos_token: "" + eos_token: "" + pad_token: "" + + # inference params + inference_guidance_scale: 0.5 + inference_noise_scale: 0.8 + inference_top_p_or_k: 0.8 + inference_guidance_enabled: true + + + optimizer: + _target_: torch.optim.AdamW + lr: 4e-05 + betas: [0.9, 0.98] + weight_decay: 0 + foreach: true # set to false if having issues with tensor-parallelism + + lr_scheduler: + _target_: nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing + warmup_steps: 2500 + min_lr: 1e-6 + max_steps: ${trainer.max_steps} + + codec_config: + latent_size: 512 + n_fft: 16 + hop_length: 4 + base_hidden_size: 384 + channel_mult: + - 1 + - 2 + - 4 + rates: + - 7 + - 7 + - 9 + num_blocks: 3 + kernel_size: 7 + groups: 1 + codebook_size: 1024 + num_quantizers: 31 + wav_to_token_ratio: 1764 + + tts_config: + # extra configs added + use_gated_fusion_for_text_audio: true + disable_eos_prediction: true # disable eos prediction + use_bos_eos_emb: true + use_subword_flag_emb: true + num_delay_speech_tokens: 2 + # EAR-TTS configs + backbone_type: gemma3_text + backbone_model_class: null + backbone_config_class: null + backbone_config: + hidden_size: 1152 + intermediate_size: 4608 + num_hidden_layers: 28 + num_attention_heads: 16 + num_key_value_heads: 16 + head_dim: 72 + attention_dropout: 0.1 + use_cache: false + latent_size: 512 + codebook_size: 1024 + num_quantizers: 31 + context_hidden_size: null + cas_config: + backbone_type: t5gemma + backbone_model_class: null + backbone_config_class: null + backbone_config: + is_encoder_decoder: false + encoder: + hidden_size: 1152 + intermediate_size: 4608 + num_hidden_layers: 1 + num_attention_heads: 16 + num_key_value_heads: 16 + head_dim: 72 + use_cache: false + attention_dropout: 0.1 + mog_head_config: + intermediate_size: 4608 + num_layers: 3 + low_rank: 64 + num_predictions: 1024 + min_log_std: -4.0 + eps: 1e-06 + p_uncond: 0.1 + label_smoothing: 0.01 + max_training_rate: 0.8 + quantizer_dropout: 0.5 + random_target_masking: false + exponent: 3.0 +trainer: + devices: -1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_steps: 1000000 + val_check_interval: 2000 + limit_train_batches: ${trainer.val_check_interval} # an "epoch" + limit_val_batches: 2 + log_every_n_steps: 20 + num_sanity_val_steps: 0 + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + gradient_as_bucket_view: true + find_unused_parameters: true + +data: + # data loader configs + add_text_bos_and_eos_in_each_turn: true + add_audio_prompt_after_description: true + audio_prompt_duration: 3.0 + frame_length: 0.08 + source_sample_rate: 22050 + target_sample_rate: 22050 + input_roles: ["user", "User"] + output_roles: ["agent", "Assistant", "assistant","Agent"] + + train_ds: + sample_rate: ${data.target_sample_rate} + input_cfg: + - type: lhotse_shar + shar_path: ??? 
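+        # '???' is an OmegaConf mandatory value; replace it with the path to your
+        # Lhotse shar directory (model.pretrained_audio_codec above uses the same marker).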
+ seed: 42 + shard_seed: "randomized" + num_workers: 2 + batch_size: 4 + # Optional bucketing: + # batch_size: null + # batch_duration: 100 + # bucket_duration_bins: [8.94766,10.1551,11.64118,19.30376,42.85] + # use_bucketing: true + # num_buckets: 5 + # bucket_buffer_size: 5000 + + validation_ds: + # The entries under 'datasets' are a list of separate dataloaders. + # The structure is : {} + # They inherit all settings from validation_ds, but can individually override them. + datasets: + val_set_0: # rename to your dataset name, add more as needed + shar_path: ??? + sample_rate: ${data.target_sample_rate} + batch_size: 1 + seed: 42 + shard_seed: "randomized" + +exp_manager: + exp_dir: null + explicit_log_dir: duplex_eartts_results/ + name: eartts + create_tensorboard_logger: false + create_checkpoint_callback: true + use_datetime_version: true + max_time_per_run: 00:03:50:00 + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: development-run + project: duplex_eartts + resume: true + + checkpoint_callback_params: + filename: "{step}" + monitor: val_asr_bleu + mode: max + every_n_train_steps: null + every_n_epochs: 1 + save_top_k: 1 + always_save_nemo: false diff --git a/examples/speechlm2/duplex_eartts_eval.py b/examples/speechlm2/duplex_eartts_eval.py index b55ad5219f26..f28fdd65413b 100644 --- a/examples/speechlm2/duplex_eartts_eval.py +++ b/examples/speechlm2/duplex_eartts_eval.py @@ -11,6 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Evaluation script for Duplex EARTTS models. + +This script computes standard speech evaluation metrics for a given Duplex +EARTTS checkpoint, including Word Error Rate (WER), Character Error Rate (CER), +speaker encoder cosine similarity (SECS), and ASR BLEU score. + +The configuration file must define a valid ``validation_ds`` based on a Lhotse +dataset using one of the following dataset formats: +- Duplex S2S standard format +- ``s2s_duplex_overlap_as_s2s_duplex`` +- ``lhotse_magpietts_data_as_continuation`` + +During evaluation, the script saves generated audio samples to +``exp_manager.explicit_log_dir`` as specified in the configuration. For each +utterance, the following audio files may be produced: + +- Autoregressive inference output (``*.wav``) +- Teacher-forced output (``*_tf.wav``) +- Ground-truth reference audio (``*_gt.wav``) + +Args: + config-path (str): Path to the directory containing the YAML configuration file. + config-name (str): Name of the YAML configuration file. 
+ +Usage: + python duplex_eartts_eval.py \ + --config-path=conf/ \ + --config-name=duplex_eartts.yaml +""" + import os import torch @@ -27,7 +59,7 @@ torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) -@hydra_runner(config_path="conf", config_name="s2s_duplex_speech_decoder") +@hydra_runner(config_path="conf", config_name="duplex_eartts") def inference(cfg): OmegaConf.resolve(cfg) torch.distributed.init_process_group(backend="nccl") diff --git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py index 2229be46e144..33774cc3ac48 100644 --- a/examples/speechlm2/duplex_eartts_train.py +++ b/examples/speechlm2/duplex_eartts_train.py @@ -27,7 +27,7 @@ torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) -@hydra_runner(config_path="conf", config_name="s2s_duplex_speech_decoder") +@hydra_runner(config_path="conf", config_name="duplex_eartts") def train(cfg): OmegaConf.resolve(cfg) torch.distributed.init_process_group(backend="nccl") diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index 5f660953baec..a7523baccc83 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -92,9 +92,9 @@ def __init__(self, cfg: dict) -> None: self.cfg.pretrained_lm_name, use_fast=True, trust_remote_code=True, - bos_token=self.cfg.get("bos_token", ''), - eos_token=self.cfg.get("eos_token", ''), - pad_token=self.cfg.get("pad_token", ''), + bos_token=self.cfg.get("bos_token", None), + eos_token=self.cfg.get("eos_token", None), + pad_token=self.cfg.get("pad_token", None), ) # Note that we are using fast tokenizer # Instantiate TTS model From 3ca5e287f9c0ea464348924184a8673205d83c49 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 17 Dec 2025 08:55:45 -0800 Subject: [PATCH 093/102] Update docs Signed-off-by: Edresson Casanova --- docs/source/speechlm2/intro.rst | 2 +- examples/speechlm2/duplex_eartts_eval.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/speechlm2/intro.rst b/docs/source/speechlm2/intro.rst index 921db362d871..bb0f898ebefd 100644 --- a/docs/source/speechlm2/intro.rst +++ b/docs/source/speechlm2/intro.rst @@ -100,7 +100,7 @@ You can run inference using the loaded pretrained DuplexS2SModel: # Prepare audio for model audio_signal = audio_signal.to(model.device) audio_len = torch.tensor([audio_signal.shape[1]], device=model.device) - + # Run offline inference results = model.offline_inference( input_signal=audio_signal, diff --git a/examples/speechlm2/duplex_eartts_eval.py b/examples/speechlm2/duplex_eartts_eval.py index f28fdd65413b..972b47c15d20 100644 --- a/examples/speechlm2/duplex_eartts_eval.py +++ b/examples/speechlm2/duplex_eartts_eval.py @@ -36,11 +36,13 @@ Args: config-path (str): Path to the directory containing the YAML configuration file. config-name (str): Name of the YAML configuration file. + checkpoint_path (str): Path to the Duplex EARTTS checkpoint file. 
Usage: python duplex_eartts_eval.py \ --config-path=conf/ \ - --config-name=duplex_eartts.yaml + --config-name=duplex_eartts.yaml \ + ++checkpoint_path=duplex_eartts_results/duplex_eartts/model.ckpt """ import os From dfa7f1694f5968515c167a92c653cc7bc433e716 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 17 Dec 2025 09:34:15 -0800 Subject: [PATCH 094/102] Add extra unit tests Signed-off-by: Edresson Casanova --- .../speechlm2/test_duplex_eartts.py | 102 +++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 03098e842741..0da05d9d4f11 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -18,7 +18,7 @@ from lhotse.testing.dummies import dummy_cut, dummy_recording from nemo.collections.common.data.utils import move_data_to_device -from nemo.collections.speechlm2.data import DuplexEARTTSDataset +from nemo.collections.speechlm2.data.duplex_ear_tts_dataset import DuplexEARTTSDataset, add_speech_delay, sample_audio_segments_repeat from nemo.collections.speechlm2.models import DuplexEARTTS @@ -36,6 +36,9 @@ r"^audio_codec\..+$", # Keep audio codec frozen as it only provides supervision for training. r"^embed_tokens\..+$", # Keep embed_tokens frozen as done in eartts ], + "bos_token": "", + "eos_token": "", + "pad_token": "", "audio_codec_run_dtype": "float32", "prevent_freeze_params": [], "audio_save_path": "", @@ -410,6 +413,103 @@ def test_eartts_dataset(dataset, training_cutset_batch): # Check formatter assert batch["formatter"] == ["s2s_duplex"] +# test extra functions inside of eartts dataset +def test_add_speech_delay(): + source_audio = torch.ones(1, 16000) + target_audio = torch.ones(1, 22050) + + source_lens = torch.tensor([16000]) + target_lens = torch.tensor([22050]) + + num_delays = 2 + + # samples per frame (float → int handled explicitly) + target_samples_per_frame = source_audio.size(1) / 12.5 + source_samples_per_frame = target_audio.size(1) / 12.5 + + expected_extra_src_size = int(source_samples_per_frame * num_delays) + expected_extra_tgt_size = int(target_samples_per_frame * num_delays) + + out_src, out_src_lens, out_tgt, out_tgt_lens = add_speech_delay( + source_audio=source_audio, + source_audio_lens=source_lens, + target_audio=target_audio, + target_audio_lens=target_lens, + num_delay_speech_tokens=num_delays, + target_samples_per_frame=target_samples_per_frame, + source_samples_per_frame=source_samples_per_frame, + ) + + # -------------------------------------------------- + # Shape & length bookkeeping + # -------------------------------------------------- + assert out_src.shape == (1, source_audio.size(1) + expected_extra_src_size) + assert out_tgt.shape == (1, target_audio.size(1) + expected_extra_tgt_size) + assert out_src_lens.item() == source_lens.item() + expected_extra_src_size + assert out_tgt_lens.item() == target_lens.item() + expected_extra_tgt_size + + + # -------------------------------------------------- + # Padding direction & content + # -------------------------------------------------- + # Target audio is left-padded + assert torch.all(out_tgt[:, :expected_extra_tgt_size] == 0) + assert torch.all(out_tgt[:, expected_extra_tgt_size:] == 1) + + # Source audio is right-padded + assert torch.all(out_src[:, :source_audio.size(1)] == 1) + assert torch.all(out_src[:, source_audio.size(1):] == 0) + + +def test_sample_audio_segments_repeat(): + cases = 
[ + # (audio, lens, n_sample, expected_when_sample_false) + ( + torch.tensor([[1., 2., 3., 4., 5.]]), + torch.tensor([5]), + 3, + torch.tensor([[1., 2., 3.]]), + ), + ( + torch.tensor([[1., 2.]]), + torch.tensor([2]), + 5, + torch.tensor([[1., 2., 1., 2., 1.]]), + ), + ( + torch.zeros(1, 10), + torch.tensor([0]), + 4, + torch.zeros(1, 4), + ), + ] + + for prompt_audio, prompt_audio_lens, n_sample, expected in cases: + # -------------------------------------------------- + # sample=False → deterministic + sequence check + # -------------------------------------------------- + out = sample_audio_segments_repeat( + prompt_audio, + prompt_audio_lens, + n_sample=n_sample, + sample=False, + ) + + assert out.shape == expected.shape + assert torch.equal(out, expected) + + # -------------------------------------------------- + # sample=True → stochastic, shape only + # -------------------------------------------------- + out = sample_audio_segments_repeat( + prompt_audio, + prompt_audio_lens, + n_sample=n_sample, + sample=True, + ) + + assert out.shape == expected.shape + def test_eartts_training_step(model, dataset, training_cutset_batch): model.train() From 6460486b442a2c02a6707274a138bebc34703b72 Mon Sep 17 00:00:00 2001 From: Edresson Date: Wed, 17 Dec 2025 17:35:03 +0000 Subject: [PATCH 095/102] Apply isort and black reformatting Signed-off-by: Edresson --- .../speechlm2/test_duplex_eartts.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 0da05d9d4f11..dfc0a3b2c317 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -18,7 +18,11 @@ from lhotse.testing.dummies import dummy_cut, dummy_recording from nemo.collections.common.data.utils import move_data_to_device -from nemo.collections.speechlm2.data.duplex_ear_tts_dataset import DuplexEARTTSDataset, add_speech_delay, sample_audio_segments_repeat +from nemo.collections.speechlm2.data.duplex_ear_tts_dataset import ( + DuplexEARTTSDataset, + add_speech_delay, + sample_audio_segments_repeat, +) from nemo.collections.speechlm2.models import DuplexEARTTS @@ -413,6 +417,7 @@ def test_eartts_dataset(dataset, training_cutset_batch): # Check formatter assert batch["formatter"] == ["s2s_duplex"] + # test extra functions inside of eartts dataset def test_add_speech_delay(): source_audio = torch.ones(1, 16000) @@ -448,7 +453,6 @@ def test_add_speech_delay(): assert out_src_lens.item() == source_lens.item() + expected_extra_src_size assert out_tgt_lens.item() == target_lens.item() + expected_extra_tgt_size - # -------------------------------------------------- # Padding direction & content # -------------------------------------------------- @@ -457,24 +461,24 @@ def test_add_speech_delay(): assert torch.all(out_tgt[:, expected_extra_tgt_size:] == 1) # Source audio is right-padded - assert torch.all(out_src[:, :source_audio.size(1)] == 1) - assert torch.all(out_src[:, source_audio.size(1):] == 0) + assert torch.all(out_src[:, : source_audio.size(1)] == 1) + assert torch.all(out_src[:, source_audio.size(1) :] == 0) def test_sample_audio_segments_repeat(): cases = [ # (audio, lens, n_sample, expected_when_sample_false) ( - torch.tensor([[1., 2., 3., 4., 5.]]), + torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]), torch.tensor([5]), 3, - torch.tensor([[1., 2., 3.]]), + torch.tensor([[1.0, 2.0, 3.0]]), ), ( - torch.tensor([[1., 2.]]), + torch.tensor([[1.0, 
2.0]]), torch.tensor([2]), 5, - torch.tensor([[1., 2., 1., 2., 1.]]), + torch.tensor([[1.0, 2.0, 1.0, 2.0, 1.0]]), ), ( torch.zeros(1, 10), From 19705d1e4135541d92dfe67fb593bf3b07fc9fff Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 20 Dec 2025 09:31:56 -0800 Subject: [PATCH 096/102] Use CI cached path for Duplex EARTTS tests Signed-off-by: Edresson Casanova --- tests/collections/speechlm2/test_duplex_eartts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index dfc0a3b2c317..1bf615b02abc 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -188,6 +188,9 @@ }, } +# set CI cached path +if os.path.exists("/home/TestData/"): + test_eartts_config["model"]["pretrained_lm_name"] = "/home/TestData/HF_HOME/hub/models--nvidia--NVIDIA-Nemotron-Nano-9B-v2/" @pytest.fixture(scope="session") def model(): From edf8e143f38d1426162b2b7a087c33fa9c30962b Mon Sep 17 00:00:00 2001 From: Edresson Date: Sat, 20 Dec 2025 17:33:06 +0000 Subject: [PATCH 097/102] Apply isort and black reformatting Signed-off-by: Edresson --- tests/collections/speechlm2/test_duplex_eartts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 1bf615b02abc..a13a27256b71 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -190,7 +190,10 @@ # set CI cached path if os.path.exists("/home/TestData/"): - test_eartts_config["model"]["pretrained_lm_name"] = "/home/TestData/HF_HOME/hub/models--nvidia--NVIDIA-Nemotron-Nano-9B-v2/" + test_eartts_config["model"][ + "pretrained_lm_name" + ] = "/home/TestData/HF_HOME/hub/models--nvidia--NVIDIA-Nemotron-Nano-9B-v2/" + @pytest.fixture(scope="session") def model(): From 421d15e85da808af461d617bb008c5ed79329126 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 5 Jan 2026 02:53:58 -0800 Subject: [PATCH 098/102] Fix eartts tests Signed-off-by: Edresson Casanova --- tests/collections/speechlm2/test_duplex_eartts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index a13a27256b71..c71cedce68bf 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
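+# `os` is required by the module-level CI cached-path check
+# (os.path.exists("/home/TestData/")) added earlier in this series.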
+import os import pytest import torch from lhotse import CutSet, SupervisionSegment From f55fad00d620f2a76de65ae50d819cf5d684f1b9 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 5 Jan 2026 12:56:29 -0800 Subject: [PATCH 099/102] Update CI nanov2 path Signed-off-by: Edresson Casanova --- tests/collections/speechlm2/test_duplex_eartts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index c71cedce68bf..a6544cadb04e 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -193,7 +193,7 @@ if os.path.exists("/home/TestData/"): test_eartts_config["model"][ "pretrained_lm_name" - ] = "/home/TestData/HF_HOME/hub/models--nvidia--NVIDIA-Nemotron-Nano-9B-v2/" + ] = "/home/TestData/nvidia--NVIDIA-Nemotron-Nano-9B-v2/" @pytest.fixture(scope="session") From e204d7c3190115ad17e50ed21e8ba75857f568fb Mon Sep 17 00:00:00 2001 From: Edresson Date: Mon, 5 Jan 2026 20:57:12 +0000 Subject: [PATCH 100/102] Apply isort and black reformatting Signed-off-by: Edresson --- tests/collections/speechlm2/test_duplex_eartts.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index a6544cadb04e..3817b581330f 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -191,9 +191,7 @@ # set CI cached path if os.path.exists("/home/TestData/"): - test_eartts_config["model"][ - "pretrained_lm_name" - ] = "/home/TestData/nvidia--NVIDIA-Nemotron-Nano-9B-v2/" + test_eartts_config["model"]["pretrained_lm_name"] = "/home/TestData/nvidia--NVIDIA-Nemotron-Nano-9B-v2/" @pytest.fixture(scope="session") From b51de02cd8c5dc3d9180b43af4468378faa08ea9 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 6 Jan 2026 05:32:22 -0800 Subject: [PATCH 101/102] Fix eartts dataset unittest Signed-off-by: Edresson Casanova --- tests/collections/speechlm2/test_duplex_eartts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 3817b581330f..4950c4400121 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -250,7 +250,7 @@ def training_cutset_batch(): id=cut.id, recording_id=cut.recording_id, start=0.6, - duration=0.4, + duration=0.1, text='okay', speaker="assistant", ), @@ -412,7 +412,7 @@ def test_eartts_dataset(dataset, training_cutset_batch): 12, 12, 1, - 1662, + 2, 1417, 12, 12, From befea3257055ae4f6d51c80f60ae5cd515cbfc85 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 7 Jan 2026 04:46:07 -0800 Subject: [PATCH 102/102] Fix unit test Signed-off-by: Edresson Casanova --- tests/collections/speechlm2/test_duplex_eartts.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py index 4950c4400121..2a56d2afe996 100644 --- a/tests/collections/speechlm2/test_duplex_eartts.py +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -219,8 +219,8 @@ def dataset(model): @pytest.fixture(scope="session") def training_cutset_batch(): - cut = dummy_cut(0, recording=dummy_recording(0, with_data=True)) - cut.target_audio = dummy_recording(1, 
with_data=True) + cut = dummy_cut(0, recording=dummy_recording(0, with_data=True, duration=1.0, sampling_rate=22050)) + cut.target_audio = dummy_recording(1, with_data=True, duration=1.0, sampling_rate=22050) cut.supervisions = [ SupervisionSegment( id=cut.id, @@ -302,9 +302,6 @@ def test_eartts_dataset(dataset, training_cutset_batch): for key in tensor_keys: assert torch.is_tensor(batch[key]), f"{key} must be a tensor" - assert batch["source_audio"].shape == (1, 89082) - assert batch["target_audio"].shape == (1, 89082) - # Check target text consistency assert batch["target_texts"] == ["hello okay"] assert batch["source_tokens"].tolist() == [