diff --git a/docs/source/speechlm2/datasets.rst b/docs/source/speechlm2/datasets.rst
index c31a51a28006..01feec9ec9eb 100644
--- a/docs/source/speechlm2/datasets.rst
+++ b/docs/source/speechlm2/datasets.rst
@@ -11,6 +11,7 @@ Duplex S2S models use the Lhotse framework for audio data management. The primar
 1. **DuplexS2SDataset**: For general duplex speech-to-speech models
 2. **SALMDataset**: Specifically for the Speech-Augmented Language Model (SALM), which processes speech+text and outputs text.
+3. **DuplexEARTTSDataset**: Dataset for the Duplex EARTTS model, extending DuplexS2SDataset with additional TTS output fields, including audio prompting. It can optionally prepend an audio prompt (speaker reference) to target_audio, which is used to initialize speaker conditioning in the EARTTS model. The dataset provides audio_prompt, audio_prompt_lens, non_prompt_mask, aligned_attention_mask, and aligned_position_ids, supports custom speaker reference audio through the context_audio field, and remains fully compatible with the original data format.
 
 DuplexS2S Dataset Structure
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/docs/source/speechlm2/models.rst b/docs/source/speechlm2/models.rst
index 3e970c83a760..5439220b8520 100644
--- a/docs/source/speechlm2/models.rst
+++ b/docs/source/speechlm2/models.rst
@@ -8,6 +8,30 @@ Core Model Architectures
 
 The collection includes the following core model architectures:
 
+
+DuplexEARTTS
+^^^^^^^^^^^^
+
+DuplexEARTTS is a streaming text-to-speech model designed for duplex speech-to-speech systems. It focuses on low-latency, fully streamable speech generation by converting text tokens into audio representations in real time.
+
+The architecture is based on the Streaming TTS model proposed in `Audio Flamingo 3`_, with several extensions for duplex interaction:
+
+* **Gated fusion of text and audio representations** (`GatedProjectedSumRMSNorm`) for better multimodal integration.
+* **Subword-aware embeddings** (`SubwordFlagEmbedding`) that improve pronunciation of words composed of multiple text tokens.
+* **Custom BOS/EOS embeddings** (`BOSEOSEmbedding`) for interruption-aware, multi-turn duplex generation.
+
+
+Key components:
+
+* **RVQVAEModel**: An RVQ-based neural audio codec that compresses speech into discrete acoustic tokens using a convolutional encoder and reconstructs high-quality audio via a convolutional decoder.
+* **RVQEARTTSModel**: A streaming speech generation model that predicts multiple RVQ codebooks in parallel using a Mixture-of-Gaussians (MoG) prediction head. It produces audio tokens autoregressively from text representations with minimal latency.
+
+DuplexEARTTS is particularly useful for:
+
+* Duplex speech-to-speech systems requiring interruption-aware synthesis.
+* Low-latency text-to-speech generation.
+* Real-time conversational agents with streamed audio output.
+
+
 SALM (Speech-Augmented Language Model)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -78,6 +102,7 @@ Speech generation components convert text or token representations back into spe
 1. **TransformerARSpeechDecoder**: An autoregressive transformer-based speech decoder
 2. **Audio Codec Integration**: Works with audio codecs to generate natural speech from discrete tokens
+3. **DuplexEARTTS**: A ready-to-use duplex text-to-speech model that supports user interruption via a special text interruption token. The model integrates an RVQ-based audio codec with a streaming speech generation module to enable low-latency, real-time synthesis.
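+
+A minimal sketch of wiring DuplexEARTTS together with its dataset and datamodule for training (paths are placeholders and the example config's mandatory ``???`` fields must be filled in; see ``examples/speechlm2/duplex_eartts_train.py`` for the full Hydra-driven script):
+
+.. code-block:: python
+
+    from omegaconf import OmegaConf
+
+    from nemo.collections.speechlm2 import DataModule, DuplexEARTTSDataset
+    from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS
+
+    cfg = OmegaConf.load("examples/speechlm2/conf/duplex_eartts.yaml")
+    model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True))
+
+    dataset = DuplexEARTTSDataset(
+        tokenizer=model.tokenizer,
+        frame_length=cfg.data.frame_length,
+        source_sample_rate=cfg.data.source_sample_rate,
+        target_sample_rate=cfg.data.target_sample_rate,
+        input_roles=cfg.data.input_roles,
+        output_roles=cfg.data.output_roles,
+    )
+    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)
+
+The training script additionally handles distributed initialization, experiment management, and optional restoration of pretrained TTS weights before calling ``trainer.fit(model, datamodule)``.
+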
Implementation Details -------------------- @@ -200,6 +225,9 @@ All models in the speechlm2 collection can be instantiated from pretrained check # Load DuplexS2SSpeechDecoderModel decoder_model = slm.models.DuplexS2SSpeechDecoderModel.from_pretrained("path/to/checkpoint") + # Load DuplexEARTTS + decoder_model = slm.models.DuplexEARTTS.from_pretrained("path/to/checkpoint") + Model Configuration ----------------- diff --git a/examples/speechlm2/conf/duplex_eartts.yaml b/examples/speechlm2/conf/duplex_eartts.yaml new file mode 100644 index 000000000000..a2765fa36364 --- /dev/null +++ b/examples/speechlm2/conf/duplex_eartts.yaml @@ -0,0 +1,200 @@ +model: + pretrained_lm_name: "nvidia/NVIDIA-Nemotron-Nano-9B-v2" + pretrained_audio_codec: ??? # to be released + pretrained_tts_model: null + scoring_asr: stt_en_fastconformer_transducer_large # used only in validation/evaluation + + # Regexp (re.compile) patterns matching parameters to be frozen. + freeze_params: + - "^audio_codec\\..+$" # Keep audio codec frozen as it only provides supervision for training. + - "^embed_tokens\\..+$" # Keep embed_tokens frozen as done in eartts + + prevent_freeze_params: [] # Use to make specific submodules trainable; overrides freeze_params + + # set custom text eos/bos/pad tokens + bos_token: "" + eos_token: "" + pad_token: "" + + # inference params + inference_guidance_scale: 0.5 + inference_noise_scale: 0.8 + inference_top_p_or_k: 0.8 + inference_guidance_enabled: true + + + optimizer: + _target_: torch.optim.AdamW + lr: 4e-05 + betas: [0.9, 0.98] + weight_decay: 0 + foreach: true # set to false if having issues with tensor-parallelism + + lr_scheduler: + _target_: nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing + warmup_steps: 2500 + min_lr: 1e-6 + max_steps: ${trainer.max_steps} + + codec_config: + latent_size: 512 + n_fft: 16 + hop_length: 4 + base_hidden_size: 384 + channel_mult: + - 1 + - 2 + - 4 + rates: + - 7 + - 7 + - 9 + num_blocks: 3 + kernel_size: 7 + groups: 1 + codebook_size: 1024 + num_quantizers: 31 + wav_to_token_ratio: 1764 + + tts_config: + # extra configs added + use_gated_fusion_for_text_audio: true + disable_eos_prediction: true # disable eos prediction + use_bos_eos_emb: true + use_subword_flag_emb: true + num_delay_speech_tokens: 2 + # EAR-TTS configs + backbone_type: gemma3_text + backbone_model_class: null + backbone_config_class: null + backbone_config: + hidden_size: 1152 + intermediate_size: 4608 + num_hidden_layers: 28 + num_attention_heads: 16 + num_key_value_heads: 16 + head_dim: 72 + attention_dropout: 0.1 + use_cache: false + latent_size: 512 + codebook_size: 1024 + num_quantizers: 31 + context_hidden_size: null + cas_config: + backbone_type: t5gemma + backbone_model_class: null + backbone_config_class: null + backbone_config: + is_encoder_decoder: false + encoder: + hidden_size: 1152 + intermediate_size: 4608 + num_hidden_layers: 1 + num_attention_heads: 16 + num_key_value_heads: 16 + head_dim: 72 + use_cache: false + attention_dropout: 0.1 + mog_head_config: + intermediate_size: 4608 + num_layers: 3 + low_rank: 64 + num_predictions: 1024 + min_log_std: -4.0 + eps: 1e-06 + p_uncond: 0.1 + label_smoothing: 0.01 + max_training_rate: 0.8 + quantizer_dropout: 0.5 + random_target_masking: false + exponent: 3.0 +trainer: + devices: -1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_steps: 1000000 + val_check_interval: 2000 + limit_train_batches: 
${trainer.val_check_interval} # an "epoch" + limit_val_batches: 2 + log_every_n_steps: 20 + num_sanity_val_steps: 0 + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + gradient_as_bucket_view: true + find_unused_parameters: true + +data: + # data loader configs + add_text_bos_and_eos_in_each_turn: true + add_audio_prompt_after_description: true + audio_prompt_duration: 3.0 + frame_length: 0.08 + source_sample_rate: 22050 + target_sample_rate: 22050 + input_roles: ["user", "User"] + output_roles: ["agent", "Assistant", "assistant","Agent"] + + train_ds: + sample_rate: ${data.target_sample_rate} + input_cfg: + - type: lhotse_shar + shar_path: ??? + seed: 42 + shard_seed: "randomized" + num_workers: 2 + batch_size: 4 + # Optional bucketing: + # batch_size: null + # batch_duration: 100 + # bucket_duration_bins: [8.94766,10.1551,11.64118,19.30376,42.85] + # use_bucketing: true + # num_buckets: 5 + # bucket_buffer_size: 5000 + + validation_ds: + # The entries under 'datasets' are a list of separate dataloaders. + # The structure is : {} + # They inherit all settings from validation_ds, but can individually override them. + datasets: + val_set_0: # rename to your dataset name, add more as needed + shar_path: ??? + sample_rate: ${data.target_sample_rate} + batch_size: 1 + seed: 42 + shard_seed: "randomized" + +exp_manager: + exp_dir: null + explicit_log_dir: duplex_eartts_results/ + name: eartts + create_tensorboard_logger: false + create_checkpoint_callback: true + use_datetime_version: true + max_time_per_run: 00:03:50:00 + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: development-run + project: duplex_eartts + resume: true + + checkpoint_callback_params: + filename: "{step}" + monitor: val_asr_bleu + mode: max + every_n_train_steps: null + every_n_epochs: 1 + save_top_k: 1 + always_save_nemo: false diff --git a/examples/speechlm2/duplex_eartts_eval.py b/examples/speechlm2/duplex_eartts_eval.py new file mode 100644 index 000000000000..972b47c15d20 --- /dev/null +++ b/examples/speechlm2/duplex_eartts_eval.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Evaluation script for Duplex EARTTS models. + +This script computes standard speech evaluation metrics for a given Duplex +EARTTS checkpoint, including Word Error Rate (WER), Character Error Rate (CER), +speaker encoder cosine similarity (SECS), and ASR BLEU score. 
+ +The configuration file must define a valid ``validation_ds`` based on a Lhotse +dataset using one of the following dataset formats: +- Duplex S2S standard format +- ``s2s_duplex_overlap_as_s2s_duplex`` +- ``lhotse_magpietts_data_as_continuation`` + +During evaluation, the script saves generated audio samples to +``exp_manager.explicit_log_dir`` as specified in the configuration. For each +utterance, the following audio files may be produced: + +- Autoregressive inference output (``*.wav``) +- Teacher-forced output (``*_tf.wav``) +- Ground-truth reference audio (``*_gt.wav``) + +Args: + config-path (str): Path to the directory containing the YAML configuration file. + config-name (str): Name of the YAML configuration file. + checkpoint_path (str): Path to the Duplex EARTTS checkpoint file. + +Usage: + python duplex_eartts_eval.py \ + --config-path=conf/ \ + --config-name=duplex_eartts.yaml \ + ++checkpoint_path=duplex_eartts_results/duplex_eartts/model.ckpt +""" + +import os + +import torch +from lightning.pytorch import Trainer +from omegaconf import OmegaConf + +from nemo.collections.speechlm2 import DataModule, DuplexEARTTSDataset + +from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager +from nemo.utils.trainer_utils import resolve_trainer_cfg + +torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +@hydra_runner(config_path="conf", config_name="duplex_eartts") +def inference(cfg): + OmegaConf.resolve(cfg) + torch.distributed.init_process_group(backend="nccl") + torch.set_float32_matmul_precision("medium") + torch.backends.cudnn.allow_tf32 = True + trainer = Trainer(**resolve_trainer_cfg(cfg.trainer)) + log_dir = exp_manager(trainer, cfg.get("exp_manager", None)) + OmegaConf.save(cfg, log_dir / "exp_config.yaml") + + with trainer.init_module(): + if cfg.get("checkpoint_path", None): + model = DuplexEARTTS.load_from_checkpoint( + cfg.checkpoint_path, + cfg=OmegaConf.to_container(cfg, resolve=True), + ) + else: + raise ValueError("For evaluation, you must provide `cfg.checkpoint_path`.") + + dataset = DuplexEARTTSDataset( + tokenizer=model.tokenizer, + frame_length=cfg.data.frame_length, + source_sample_rate=cfg.data.source_sample_rate, + target_sample_rate=cfg.data.target_sample_rate, + input_roles=cfg.data.input_roles, + output_roles=cfg.data.output_roles, + add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), + add_audio_prompt=cfg.data.get("add_audio_prompt", True), + audio_prompt_duration=cfg.data.get("audio_prompt_duration", 3), + num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), + ) + datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset) + + trainer.validate(model, datamodule) + + +if __name__ == "__main__": + inference() diff --git a/examples/speechlm2/duplex_eartts_train.py b/examples/speechlm2/duplex_eartts_train.py new file mode 100644 index 000000000000..33774cc3ac48 --- /dev/null +++ b/examples/speechlm2/duplex_eartts_train.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch +from lightning.pytorch import Trainer +from omegaconf import OmegaConf + +from nemo.collections.speechlm2 import DataModule, DuplexEARTTSDataset +from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS +from nemo.collections.speechlm2.parts.pretrained import load_checkpoint, set_model_dict_for_partial_init +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager +from nemo.utils.trainer_utils import resolve_trainer_cfg + +torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +@hydra_runner(config_path="conf", config_name="duplex_eartts") +def train(cfg): + OmegaConf.resolve(cfg) + torch.distributed.init_process_group(backend="nccl") + torch.set_float32_matmul_precision("medium") + torch.backends.cudnn.allow_tf32 = True + trainer = Trainer(**resolve_trainer_cfg(cfg.trainer)) + log_dir = exp_manager(trainer, cfg.get("exp_manager", None)) + OmegaConf.save(cfg, log_dir / "exp_config.yaml") + + with trainer.init_module(): + model = DuplexEARTTS(OmegaConf.to_container(cfg, resolve=True)) + + # load pretrained tts checkpoint if available + if model.cfg.get("pretrained_tts_model", None): + checkpoint_state = load_checkpoint(model.cfg.pretrained_tts_model) + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, model.tts_model.state_dict()) + model.tts_model.load_state_dict(checkpoint_state, strict=True) + + # load pretrained checkpoint and rescale the weights if needed + if model.cfg.get("pretrained_model", None): + model.restore_from_pretrained_checkpoint(model.cfg.pretrained_model) + + dataset = DuplexEARTTSDataset( + tokenizer=model.tokenizer, + frame_length=cfg.data.frame_length, + source_sample_rate=cfg.data.source_sample_rate, + target_sample_rate=cfg.data.target_sample_rate, + input_roles=cfg.data.input_roles, + output_roles=cfg.data.output_roles, + add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), + add_audio_prompt=cfg.data.get("add_audio_prompt", True), + audio_prompt_duration=cfg.data.get("audio_prompt_duration", 3), + num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), + ) + datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset) + + trainer.fit(model, datamodule) + + +if __name__ == "__main__": + train() diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 4ba644c97002..b8929c82b366 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import io import logging +import random import re import warnings from functools import partial @@ -20,11 +22,14 @@ from pathlib import Path from typing import KeysView, Mapping, Sequence, Tuple, Union +import numpy as np import omegaconf -from lhotse import CutSet, Features, Recording +import soundfile as sf +from lhotse import CutSet, Features, MonoCut, Recording, SupervisionSegment from lhotse.array import Array, TemporalArray from lhotse.cut import Cut, MixedCut, PaddingCut from lhotse.serialization import load_yaml +from lhotse.utils import fastcopy from omegaconf import DictConfig, ListConfig, OmegaConf from nemo.collections.common.data.lhotse.nemo_adapters import ( @@ -535,6 +540,201 @@ def cut_to_conversation( ) +@data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) +def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: + """ + Convert a CutSet with overlapping agent/user segments into a standard S2S duplex format. + + Args: + config: Dictionary containing parser options: + - move_agent_text_back_by (float): Time offset to shift agent text back. + - filter_samples_starting_with_agent (bool): Whether to remove samples starting with agent. + - agent_roles (List[str]): Roles considered as agent. + + Returns: + Tuple[CutSet, bool]: Converted cuts and a flag indicating if the data was tarred. + """ + move_agent_text_back_by = config.get("move_agent_text_back_by", 0) + filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) + agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) + + cuts, is_tarred = read_cutset_from_config(config) + + def filter_cuts_starting_with_agent_fn(cuts: CutSet, agent_roles: Tuple[str, ...]) -> CutSet: + """Remove cuts where the first supervision belongs to an agent role.""" + + def _filter_fn(cut: Cut) -> bool: + if not cut.supervisions: + return False + cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) + return cut.supervisions[0].speaker not in agent_roles + + return cuts.filter(_filter_fn) + + def convert_overlap_cut_fn(cut: Cut) -> Cut: + """Convert agent/user overlapping segments into sequential SupervisionSegments.""" + agent_segments = [ + SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=seg["start"] - move_agent_text_back_by, + duration=seg["end"] - seg["start"] + move_agent_text_back_by, + text=seg["text"], + speaker="agent", + ) + for seg in cut.agent_segments + ] + + user_segments = [ + SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=seg["start"], + duration=seg["end"] - seg["start"], + text=seg["text"], + speaker="user", + ) + for seg in cut.user_segments + ] + + cut.supervisions = sorted(agent_segments + user_segments, key=lambda s: s.start) + cut.formatter = "s2s_duplex_overlap_as_s2s_duplex" + return cut + + cuts = cuts.map(convert_overlap_cut_fn) + if filter_samples_starting_with_agent: + cuts = filter_cuts_starting_with_agent_fn(cuts, tuple(agent_roles)) + + return cuts, is_tarred + + +@data_type_parser(["lhotse_magpietts_data_as_continuation"]) +def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: + """ + Convert MagpieTTS dataset cuts into the Duplex S2S format, with optional + `context_audio` that can be used as a speaker reference. + + Args: + config: Dictionary containing parser options: + - add_extra_end_silence (bool): Whether to add extra silence at the end. + - extra_end_silence_range (List[float]): Range of extra silence duration. 
+ - max_cer (float): Maximum allowed character error rate. + - min_context_speaker_similarity (float): Minimum similarity score. + - target_speaker (str, optional): Target speaker filter. + - sample_rate (int): Audio sample rate for resampling. + + Returns: + Tuple[CutSet, bool]: Converted cuts and a flag indicating if data was tarred. + """ + cuts, is_tarred = read_cutset_from_config(config) + + add_extra_end_sil = config.get("add_extra_end_silence", False) + extra_end_silence_range = config.get("extra_end_silence_range", [0.5, 6.0]) + sample_rate = config.get("sample_rate", 22050) + + max_cer = config.get("max_cer", 0.03) + min_context_speaker_similarity = config.get("min_context_speaker_similarity", 0.6) + target_speaker = config.get("target_speaker", None) + keep_flag = "pass" + + def create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: + """Convert a numpy array into a Lhotse Recording object.""" + with io.BytesIO() as buffer: + sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') + buffer.seek(0) + return Recording.from_bytes(buffer.read(), recording_id=recording_id) + + def convert_cut_fn(cut: Cut) -> Cut: + """Convert a single cut into the continuation format.""" + orig_agent_sup = fastcopy(cut.supervisions[0]) + target_audio_orig_dur = cut.target_audio.duration + + # Resample audios + cut.target_audio = cut.target_audio.resample(sample_rate) + cut.context_audio = cut.context_audio.resample(sample_rate) + total_duration = cut.target_audio.duration + + # Prepare MonoCuts + cut_target = MonoCut( + id=f"{cut.id}_target", + start=0.0, + duration=total_duration, + channel=0, + recording=cut.target_audio, + supervisions=[], + ) + + zero_audio = np.zeros((1, int(total_duration * sample_rate)), dtype=np.float32) + source_recording = create_recording_from_array(zero_audio, sample_rate, recording_id=f"{cut.id}_source") + + cut_source = MonoCut( + id=f"{cut.id}_source", + start=0.0, + duration=total_duration, + channel=0, + recording=source_recording, + supervisions=[], + ) + + # Save to memory + cut_source = cut_source.move_to_memory(audio_format='wav') + cut_target = cut_target.move_to_memory(audio_format='wav') + + # Create user and agent supervisions + user_sup = fastcopy(orig_agent_sup, start=0.0, duration=0.08, speaker="user", text="dummy text") + agent_sup = fastcopy(orig_agent_sup, start=0.0, duration=target_audio_orig_dur - 0.08, speaker="agent") + + # Optionally add extra silence + if add_extra_end_sil: + sil_duration = random.uniform(*extra_end_silence_range) + cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right") + cut_source = cut_source.pad(duration=total_duration + sil_duration, direction="right") + cut_source = cut_source.to_mono().move_to_memory(audio_format='wav') + cut_target = cut_target.to_mono().move_to_memory(audio_format='wav') + agent_sup.duration += sil_duration + 1.0 + user_sup.duration += sil_duration + + # Assemble final cut + cut_source.supervisions = [user_sup, agent_sup] + cut_source.target_audio = cut_target.recording + cut_source.duration = cut_target.duration + cut_source.context_audio = cut.context_audio + cut_source.formatter = "lhotse_magpietts_data_as_continuation" + + return cut_source + + # Filters + def filter_cer_fn(cut: Cut) -> bool: + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("cer") + or cut.supervisions[0].cer <= max_cer + ) + + def filter_val_flag_fn(cut: Cut) -> bool: + return not 
cut.has_custom("validation_status") or cut.validation_status == keep_flag + + def filter_secs_fn(cut: Cut) -> bool: + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("context_speaker_similarity") + or cut.supervisions[0].context_speaker_similarity >= min_context_speaker_similarity + ) + + def filter_target_speaker_fn(cut: Cut) -> bool: + return len(cut.supervisions) == 0 or target_speaker is None or target_speaker in cut.supervisions[0].speaker + + # Apply filters + cuts = ( + cuts.filter(filter_cer_fn).filter(filter_val_flag_fn).filter(filter_secs_fn).filter(filter_target_speaker_fn) + ) + + # Convert cuts + cuts = cuts.map(convert_cut_fn) + + return cuts, is_tarred + + @data_type_parser(["lhotse_as_conversation"]) def read_lhotse_as_conversation(config) -> tuple[CutSet, bool]: cuts, is_tarred = read_cutset_from_config(config) diff --git a/nemo/collections/speechlm2/__init__.py b/nemo/collections/speechlm2/__init__.py index b638dd13f08d..bd9267c08fe3 100644 --- a/nemo/collections/speechlm2/__init__.py +++ b/nemo/collections/speechlm2/__init__.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .data import DataModule, DuplexS2SDataset, SALMDataset +from .data import DataModule, DuplexEARTTSDataset, DuplexS2SDataset, SALMDataset from .models import SALM, DuplexS2SModel, DuplexS2SSpeechDecoderModel __all__ = [ 'DataModule', 'DuplexS2SDataset', + 'DuplexEARTTSDataset', 'SALMDataset', 'DuplexS2SModel', 'DuplexS2SSpeechDecoderModel', diff --git a/nemo/collections/speechlm2/data/__init__.py b/nemo/collections/speechlm2/data/__init__.py index f698b6bcf12a..802c199462d7 100644 --- a/nemo/collections/speechlm2/data/__init__.py +++ b/nemo/collections/speechlm2/data/__init__.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from .datamodule import DataModule +from .duplex_ear_tts_dataset import DuplexEARTTSDataset from .s2s_dataset import DuplexS2SDataset from .salm_dataset import SALMDataset __all__ = [ 'DataModule', 'DuplexS2SDataset', + 'DuplexEARTTSDataset', 'SALMDataset', ] diff --git a/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py new file mode 100644 index 000000000000..530d2df27563 --- /dev/null +++ b/nemo/collections/speechlm2/data/duplex_ear_tts_dataset.py @@ -0,0 +1,866 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
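+
+"""
+Lhotse dataset for the Duplex EAR-TTS model.
+
+This module provides ``DuplexEARTTSDataset``, which extends the duplex S2S data pipeline
+with TTS-specific outputs (frame-aligned text tokens, target audio, an optional
+speaker-reference audio prompt, and alignment masks / position ids), plus helpers for
+token-channel collation, audio-prompt sampling, and speech-delay padding.
+
+Minimal usage sketch (values mirror ``examples/speechlm2/conf/duplex_eartts.yaml``; see
+``examples/speechlm2/duplex_eartts_train.py`` for the full pipeline)::
+
+    dataset = DuplexEARTTSDataset(
+        tokenizer=model.tokenizer,
+        frame_length=0.08,
+        source_sample_rate=22050,
+        target_sample_rate=22050,
+    )
+    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)
+"""
+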
+import random +import re + +import torch +import torch.nn.functional as F +import torch.utils.data +from lhotse import CutSet, Seconds, compute_num_frames +from lhotse.cut import Cut +from lhotse.dataset.collation import collate_audio, collate_vectors +from lhotse.utils import ifnone + +from nemo.collections.common.tokenizers import TokenizerSpec +from nemo.collections.speechlm2.data.utils import get_pad_id +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.utils import logging + + +class DuplexEARTTSDataset(torch.utils.data.Dataset): + """ + A dataset for duplex speech-to-speech models that handles bidirectional conversations. + + This dataset processes Lhotse CutSet objects containing recordings with supervision segments + from different speakers (roles). It creates aligned representations of audio and text for + both source (input) and target (output) channels, preserving temporal alignment between + audio frames and text tokens. + + Args: + tokenizer (TokenizerSpec): + Tokenizer for converting text to token IDs and vice versa. Must support BOS and EOS tokens. + It's expected to support PAD token as well, otherwise we will use 0 as the pad token + and emit a warning. + + frame_length (Seconds): + Duration of a single frame in seconds. Used to calculate frame positions for token alignment. + + source_sample_rate (int): + Sample rate for source audio (e.g., 16000 Hz). + + target_sample_rate (int): + Sample rate for target audio (e.g., 22050 Hz). + + input_roles (list[str], optional): + List of speaker roles (cut.supervisions[:].speaker) to consider as inputs. Defaults to ["user"]. + + output_roles (list[str], optional): + List of speaker roles (cut.supervisions[:].speaker) to consider as outputs. Defaults to ["agent"]. + + p_drop_description (float, optional): + Probability of dropping text descriptions. Default: `0.0`. + + add_text_bos_and_eos_in_each_turn (bool, optional): + If True, each conversational turn from any speaker is explicitly delimited + with BOS and EOS tokens in the text stream. + Default: `True`. + + add_audio_prompt (bool, optional): + If True, an optional audio/speaker prompt is appended after the description. + Default: `True`. + + audio_prompt_duration (float, optional): + Duration (in seconds) of the audio prompt appended when + `add_audio_prompt=True`. Default: `3.0`. + + num_delay_speech_tokens (int, optional): + Number of PAD tokens to insert before speech tokens to artificially + delay the start of speech. Default: `0`. 
+ + Returns: + A dictionary with the following keys: + - sample_id: List of sample IDs for each cut in the batch [B] + + - non_prompt_mask: Bool tensor [B, T] marking positions that are not part of the prompt + - prompt_lens: Tensor of description + audio prompt lengths [B] + + - aligned_attention_mask: Bool tensor [B, T] used by alignment-aware transformer models + - aligned_position_ids: Tensor of position indices aligned to audio frames [B, T] + + - source_audio: Tensor of source waveform samples [B, T] + - source_audio_lens: Tensor of source audio lengths [B] + + - target_audio: Tensor of target waveform samples [B, T] + - target_audio_lens: Tensor of target audio lengths [B] + + - target_text_tokens: Tensor of frame-aligned input text tokens [B, T], + including BOS/EOS/PAD when enabled + - target_token_lens: Tensor of target token sequence lengths [B] + + - source_tokens: Tensor of frame-aligned source text tokens [B, T], + including BOS/EOS/PAD + - source_token_lens: Tensor of source token sequence lengths [B] + + - target_texts: List of full target texts joined from output_roles supervisions [B] + + - audio_prompt: Tensor of optional speaker reference waveform samples [B, T] + - audio_prompt_lens: Tensor of speaker reference audio lengths [B] + + - formatter: List indicating the formatter to use for each cut (default "s2s_duplex") [B] + + Notes: + - The dataset ensures frame-level alignment between audio and text by inserting tokens at + specific frame positions based on the timing of supervision segments. + - PAD tokens (typically 0) are used to fill gaps where there's no text. + - BOS tokens mark the beginning of each speech segment. + - EOS tokens mark the end of each speech segment. + - Text tokens from each speaker are placed at frame positions corresponding to their + timestamp in the original recording, preserving the temporal relationship. + This is a segment-level alignment only, not word-level alignment. + """ + + def __init__( + self, + tokenizer, + frame_length: Seconds, + source_sample_rate: int, + target_sample_rate: int, + input_roles: list[str] = None, + output_roles: list[str] = None, + p_drop_description: float = 0.0, + add_text_bos_and_eos_in_each_turn: bool = True, + add_audio_prompt: bool = True, + audio_prompt_duration: float = 3.0, + num_delay_speech_tokens: int = 0, + ): + self.tokenizer = tokenizer + self.frame_length = frame_length + self.source_sample_rate = source_sample_rate + self.target_sample_rate = target_sample_rate + self.input_roles = set(ifnone(input_roles, ["user"])) + self.output_roles = set(ifnone(output_roles, ["agent"])) + self.p_drop_description = p_drop_description + self.add_text_bos_and_eos_in_each_turn = add_text_bos_and_eos_in_each_turn + self.add_audio_prompt = add_audio_prompt + self.audio_prompt_duration = audio_prompt_duration + self.num_delay_speech_tokens = num_delay_speech_tokens + + # compute source and target samples_per_frame + self.source_samples_per_frame = int(self.source_sample_rate * self.frame_length) + self.target_samples_per_frame = int(self.target_sample_rate * self.frame_length) + + assert tokenizer.bos is not None, "BOS support in the tokenizer is required for S2S models." + assert tokenizer.eos is not None, "EOS support in the tokenizer is required for S2S models." 
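+        # Note: `frame_length` is expected to yield an integer number of samples per frame
+        # at both sample rates; e.g. frame_length=0.08 at 22050 Hz gives 1764 samples per
+        # frame (matching `wav_to_token_ratio` in the example codec config).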
+ + def __getitem__(self, cuts: CutSet) -> dict: + cuts = cuts.transform_text(_strip_timestamps) + source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) + target_audio, target_audio_lens = collate_audio( + cuts.resample(self.target_sample_rate), recording_field="target_audio" + ) + target_text_tokens, target_token_lens = collate_token_channel( + cuts, + self.tokenizer, + self.frame_length, + roles=self.output_roles, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + ) + source_tokens, source_token_lens = collate_token_channel( + cuts, + self.tokenizer, + self.frame_length, + roles=self.input_roles, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + ) + + audio_prompt, audio_prompt_lens = get_audio_prompt( + cuts, self.target_sample_rate, roles=self.output_roles, recording_field="target_audio" + ) + + # add speech channel delay if needed + if self.num_delay_speech_tokens: + source_audio, source_audio_lens, target_audio, target_audio_lens = add_speech_delay( + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + self.num_delay_speech_tokens, + self.target_samples_per_frame, + self.source_samples_per_frame, + ) + + # add audio prompt if needed + ( + target_text_tokens, + target_token_lens, + source_tokens, + source_token_lens, + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + prompt_lens, + ) = self.maybe_add_audio_prompt( + target_text_tokens, + target_token_lens, + source_tokens, + source_token_lens, + target_audio, + target_audio_lens, + source_audio, + source_audio_lens, + audio_prompt, + audio_prompt_lens, + ) + + # create non_prompt_mask that should mask desc plus audio prompt if used + non_prompt_mask = get_mask_from_lengths(target_token_lens) + for i, frame in enumerate(prompt_lens): + non_prompt_mask[i, : frame - 1] = 0.0 + + max_len = max(target_token_lens) + + # Segment IDs per sequence (padded) + aligned_segment_ids = torch.stack( + [ + torch.nn.functional.pad(torch.full((seq_len,), i), (0, max_len - seq_len), value=-1) # -1 for padding + for i, seq_len in enumerate(target_token_lens) + ], + dim=0, + ) # [B, max_len] + + # Attention mask: same-segment & causal + aligned_attention_mask = ( + aligned_segment_ids.unsqueeze(-2) == aligned_segment_ids.unsqueeze(-1) + ) & ( # [B, max_len, max_len] + torch.arange(max_len).unsqueeze(0).unsqueeze(1) <= torch.arange(max_len).unsqueeze(0).unsqueeze(-1) + ) # causal tril + + aligned_attention_mask = aligned_attention_mask.unsqueeze(1) # [B, 1, max_len, max_len] + + # create position IDs from the aligned length + aligned_position_ids = torch.stack( + [ + torch.nn.functional.pad( + torch.arange(seq_len), (0, max(target_token_lens) - seq_len), value=0 + ) # value=0 is safe for padding + for seq_len in target_token_lens + ], + dim=0, + ) + + return { + "sample_id": [str(cut.id) for cut in cuts], + "non_prompt_mask": non_prompt_mask.bool(), + "prompt_lens": prompt_lens, + "aligned_attention_mask": aligned_attention_mask.bool(), + "aligned_position_ids": aligned_position_ids, + "source_audio": source_audio, + "source_audio_lens": source_audio_lens, + "target_audio": target_audio, + "target_audio_lens": target_audio_lens, + "target_text_tokens": target_text_tokens, + "target_token_lens": target_token_lens, + "source_tokens": source_tokens, + "source_token_lens": source_token_lens, + "target_texts": [ + " ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts + ], + "audio_prompt": 
audio_prompt, + "audio_prompt_lens": audio_prompt_lens, + "formatter": [getattr(cut, "formatter", "s2s_duplex") for cut in cuts], + } + + def maybe_add_audio_prompt( + self, + target_text_tokens: torch.Tensor, + target_token_lens: torch.Tensor, + source_tokens: torch.Tensor, + source_token_lens: torch.Tensor, + target_audio: torch.Tensor, + target_audio_lens: torch.Tensor, + source_audio: torch.Tensor, + source_audio_lens: torch.Tensor, + audio_prompt: torch.Tensor, + audio_prompt_lens: torch.Tensor, + ): + """ + Prepend an audio-based speaker prompt and aligned text tokens to the Duplex S2S inputs. + + This method optionally injects a speaker-reference audio prompt at the beginning of each + sample in the batch. The prompt is inserted in the target-audio channel and aligned text + padding is inserted into the text-token streams (input text tokens and source tokens). + + Args: + target_text_tokens (torch.Tensor): + Tensor of input text tokens with shape [B, T_text]. + dtype: torch.long. + + target_token_lens (torch.Tensor): + Lengths of target_text_tokens per batch element (before padding). shape [B]. + + source_tokens (torch.Tensor): + Source-side text tokens, shape [B, T_src_text], dtype torch.long. + + source_token_lens (torch.Tensor): + Source text token lengths per batch element, shape [B]. + + target_audio (torch.Tensor): + Target-side audio waveforms, shape [B, T_audio], dtype torch.float32/float16. + + target_audio_lens (torch.Tensor): + Target audio lengths per batch element, shape [B]. + + source_audio (torch.Tensor): + Source-side audio waveforms, shape [B, T_audio], dtype torch.float32/float16. + + source_audio_lens (torch.Tensor): + Source audio lengths per batch element, shape [B]. + + audio_prompt (torch.Tensor): + Audio prompt waveforms to sample from, shape [B, T_prompt_audio]. + + audio_prompt_lens (torch.Tensor): + Valid lengths for audio_prompt per batch element. + + Returns: + Tuple containing: + target_text_tokens (torch.Tensor): + Updated text tokens with prepended prompt-aligned tokens. Shape [B, T']. + + target_token_lens (torch.Tensor): + Updated token lengths per batch element. + + source_tokens (torch.Tensor): + Updated source text tokens with prompt padding included. Shape [B, T']. + + source_token_lens (torch.Tensor): + Updated source token lengths per batch element. + + source_audio (torch.Tensor): + Updated source audio with silence padding. Shape [B, T_audio']. + + source_audio_lens (torch.Tensor): + Updated source audio lengths. + + target_audio (torch.Tensor): + Updated target audio with prompt audio and silence padding. Shape [B, T_audio']. + + target_audio_lens (torch.Tensor): + Updated target audio lengths. + + prompt_lens (list[int]): + Length (in text-token units) of the prompt region per batch item. 
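+
+        Note:
+            The returned ``prompt_lens`` are consumed in ``__getitem__`` to zero out the
+            description/audio-prompt region when building ``non_prompt_mask``.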
+ """ + + text_pad_id = get_pad_id(self.tokenizer) + + target_text_tokens_ = [] + source_tokens_ = [] + source_audio_ = [] + target_audio_ = [] + prompt_lens = [] + + for i in range(target_text_tokens.size(0)): + first_text_frame = torch.tensor( + [self.tokenizer.eos], + dtype=torch.long, + device=target_text_tokens.device, + ) + + if self.add_audio_prompt: + # Compute audio prompt duration in samples (rounded to frame boundaries) + prompt_audio_size = int( + ((self.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) + * self.target_samples_per_frame + ) + + prompt_audio = sample_audio_segments_repeat( + audio_prompt, audio_prompt_lens, prompt_audio_size, sample=True + ) + + # add silence at the end of the prompt + prompt_audio[:, -int(self.target_samples_per_frame * 2) :] = 0 + + # Number of text tokens to insert to align with prompt_audio frames + prompt_audio_text_pad_size = prompt_audio_size // self.target_samples_per_frame + + prompt_audio_text_pad = ( + torch.ones( + prompt_audio_text_pad_size, + device=target_text_tokens.device, + dtype=target_text_tokens.dtype, + ) + * text_pad_id + ) + prompt_audio_text_pad[-1] = self.tokenizer.eos + + new_target_text_tokens = torch.cat( + [ + first_text_frame.to(target_text_tokens.dtype), + prompt_audio_text_pad, + target_text_tokens[i], + ] + ) + target_text_tokens_.append(new_target_text_tokens) + target_token_lens[i] += len(first_text_frame) + prompt_audio_text_pad_size + + new_source_tokens = torch.cat([first_text_frame, prompt_audio_text_pad, source_tokens[i]]) + source_tokens_.append(new_source_tokens) + source_token_lens[i] += len(first_text_frame) + prompt_audio_text_pad_size + + # Silence in source audio during prompt processing + pad_size_src = (len(first_text_frame) * self.source_samples_per_frame) + prompt_audio.size(1) + pad_audio_src = torch.zeros( + pad_size_src, + device=source_audio.device, + dtype=source_audio.dtype, + ) + source_audio_.append(torch.cat([pad_audio_src, source_audio[i]])) + source_audio_lens[i] += pad_size_src + + # Insert prompt audio in the target channel + pad_size_tgt = len(first_text_frame) * self.target_samples_per_frame + pad_audio_tgt = torch.zeros( + pad_size_tgt, + device=target_audio.device, + dtype=target_audio.dtype, + ) + target_audio_.append(torch.cat([pad_audio_tgt, prompt_audio[i], target_audio[i]])) + target_audio_lens[i] += pad_size_tgt + prompt_audio.size(1) + + prompt_lens.append(len(first_text_frame) + prompt_audio_text_pad_size - 1) + + else: + # Add only a single text-frame (EOS) as prompt + target_text_tokens_.append(torch.cat([first_text_frame, target_text_tokens[i]])) + target_token_lens[i] += len(first_text_frame) + + source_tokens_.append(torch.cat([first_text_frame, source_tokens[i]])) + source_token_lens[i] += len(first_text_frame) + + pad_size_src = len(first_text_frame) * self.source_samples_per_frame + pad_audio_src = torch.zeros( + pad_size_src, + device=source_audio.device, + dtype=source_audio.dtype, + ) + source_audio_.append(torch.cat([pad_audio_src, source_audio[i]])) + source_audio_lens[i] += pad_size_src + + pad_size_tgt = len(first_text_frame) * self.target_samples_per_frame + pad_audio_tgt = torch.zeros( + pad_size_tgt, + device=target_audio.device, + dtype=target_audio.dtype, + ) + target_audio_.append(torch.cat([pad_audio_tgt, target_audio[i]])) + target_audio_lens[i] += pad_size_tgt + + prompt_lens.append(len(first_text_frame)) + + target_text_tokens = collate_vectors(target_text_tokens_, padding_value=text_pad_id) + source_tokens = 
collate_vectors(source_tokens_, padding_value=text_pad_id) + source_audio = collate_vectors(source_audio_, padding_value=0) + target_audio = collate_vectors(target_audio_, padding_value=0) + + return ( + target_text_tokens, + target_token_lens, + source_tokens, + source_token_lens, + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + prompt_lens, + ) + + +def add_speech_delay( + source_audio: torch.Tensor, + source_audio_lens: torch.Tensor, + target_audio: torch.Tensor, + target_audio_lens: torch.Tensor, + num_delay_speech_tokens: int, + target_samples_per_frame: int, + source_samples_per_frame: int, +): + """ + Apply a speech delay by padding audio waveforms based on the number of delay speech tokens. + + Behavior: + - Target audio is *left padded* to force the model to predict initial silence. + - Source audio is *right padded* to maintain size consistency for attention alignment. + + Args: + source_audio (FloatTensor [B, T_src]): + Source/input audio waveforms. + + source_audio_lens (LongTensor [B]): + Lengths of source audio in samples. + + target_audio (FloatTensor [B, T_tgt]): + Target/output audio waveforms. + + target_audio_lens (LongTensor [B]): + Lengths of target audio in samples. + + num_delay_speech_tokens (int): + Number of delay tokens inserted on the text side. + + target_samples_per_frame (int): + Number of audio samples per frame for target audio. + + source_samples_per_frame (int): + Number of audio samples per frame for source audio. + + Returns: + Tuple containing: + - source_audio (FloatTensor [B, T_src + pad]) + - source_audio_lens (LongTensor [B]) + - target_audio (FloatTensor [B, T_tgt + pad]) + - target_audio_lens (LongTensor [B]) + """ + # Compute target-side left padding for the delay + extra_target_samples = int(num_delay_speech_tokens * target_samples_per_frame) + + # Left-pad target audio: forces model to generate silence initially + target_audio = F.pad(target_audio, (extra_target_samples, 0)) + target_audio_lens = target_audio_lens + extra_target_samples + + # Compute source-side right padding to maintain alignment + extra_source_samples = int(num_delay_speech_tokens * source_samples_per_frame) + + # Right-pad source audio: avoids mismatch when aligning source/target frames + source_audio = F.pad(source_audio, (0, extra_source_samples)) + source_audio_lens = source_audio_lens + extra_source_samples + + return source_audio, source_audio_lens, target_audio, target_audio_lens + + +def get_audio_prompt( + cuts: CutSet, + target_sample_rate: int, + roles: set[str], + recording_field: str = "target_audio", +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Retrieve an audio prompt for speaker conditioning. + + This function returns: + - context audio if available (per-cut), + - otherwise a random turn belonging to the desired speaker roles. + + Behavior: + 1. If `cut.context_audio` exists, use it as the reference. + 2. Otherwise, select a random audio turn from the same target-role speakers. + + Args: + cuts (CutSet): + Batch of cuts from which to extract reference audio. + + target_sample_rate (int): + Sample rate to which reference audio is resampled. + + roles (set[str]): + Set of speaker roles to sample from when selecting random turns. + + recording_field (str, optional): + Name of the audio field in the cut ("recording", "target_audio", etc.). + Used when sampling random reference turns. + + Returns: + Tuple containing: + - audio_prompt (FloatTensor [B, T]): + Padded batch of reference waveforms. 
+ - audio_prompt_lens (LongTensor [B]): + Lengths of each reference waveform before padding. + """ + # use provided context audio directly + if hasattr(cuts[0], "context_audio"): + audio_prompt = [] + audio_prompt_lens = [] + + for cut in cuts: + ref_audio = cut.context_audio.resample(target_sample_rate).load_audio() + ref_audio = torch.tensor(ref_audio).float() # shape: [1, T] + ref_audio_len = ref_audio.shape[1] + + audio_prompt.append(ref_audio.squeeze(0)) # [T] + audio_prompt_lens.append(ref_audio_len) + + audio_prompt = collate_vectors(audio_prompt, padding_value=0).float() + audio_prompt_lens = torch.tensor(audio_prompt_lens).long() + + else: + # sample a reference turn from the target-role speakers + audio_prompt, audio_prompt_lens = collate_random_turn_audio( + cuts.resample(target_sample_rate), + roles=roles, + recording_field=recording_field, + ) + + return audio_prompt, audio_prompt_lens + + +def collate_random_turn_audio( + cuts: CutSet, + roles: set[str], + recording_field: str = "target_audio", +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sample and collate reference audio from random speaker turns. + + For each cut in the batch, this function: + - selects a random supervision belonging to one of the specified roles, + - extracts the corresponding audio segment, + - collates all segments into a padded batch. + + Args: + cuts (CutSet): + Batch of cuts to sample from. + + roles (set[str]): + Set of speaker roles to consider when selecting random turns. + + recording_field (str, optional): + Name of the audio field to load from the cut + (e.g., "recording", "target_audio"). + + Returns: + Tuple containing: + - audio (FloatTensor [B, T]): + Padded batch of sampled reference waveforms. + - audio_lens (LongTensor [B]): + Lengths of each waveform before padding. + """ + selected_turn_audios = [] + selected_turn_audios_lens = [] + for cut in cuts: + # Filter supervisions matching roles + matching_supervisions = [s for s in cut.supervisions if s.speaker in roles] + + # Randomly select one supervision + selected_supervision = random.choice(matching_supervisions) + + # Truncate audio according to supervision + truncated_audio = cut.truncate( + offset=max(0, selected_supervision.start), duration=selected_supervision.duration + ).load_custom(recording_field) + + selected_turn_audios.append(truncated_audio.squeeze(0)) + selected_turn_audios_lens.append(truncated_audio.shape[-1]) + + return collate_vectors(selected_turn_audios, padding_value=0), torch.tensor(selected_turn_audios_lens) + + +def collate_token_channel( + cuts: CutSet, + tokenizer: TokenizerSpec, + frame_length: Seconds, + roles: set[str], + add_text_bos_and_eos_in_each_turn: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Build and collate token channels aligned to the audio frame grid. + + This function converts text supervisions into frame-level token + representations for each cut, pads them to a uniform length, and + returns both the padded tokens and their true lengths. + + Args: + cuts (CutSet): + Batch of cuts from which to extract token channels. + + tokenizer (TokenizerSpec): + Tokenizer used to convert text into token IDs. + + frame_length (Seconds): + Duration of a single audio frame, used to align text tokens + to audio frames. + + roles (set[str]): + Speaker roles whose text will be included in the token channel. + + add_text_bos_and_eos_in_each_turn (bool, optional): + Whether to insert BOS at the beginning and EOS at the end of + each speaking turn. 
+ + Returns: + Tuple containing: + - tokens (LongTensor [B, T]): + Padded batch of frame-aligned token sequences. + - token_lens (LongTensor [B]): + Length of each sequence before padding. + """ + pad_id = get_pad_id(tokenizer) + tokens = [ + build_token_channel( + c, + tokenizer=tokenizer, + frame_length=frame_length, + roles=roles, + pad_id=pad_id, + add_text_bos_and_eos_in_each_turn=add_text_bos_and_eos_in_each_turn, + ) + for c in cuts + ] + token_lens = torch.tensor([len(tt) for tt in tokens]) + tokens = collate_vectors(tokens, padding_value=pad_id) + return tokens, token_lens + + +def build_token_channel( + cut: Cut, + tokenizer: TokenizerSpec, + frame_length: Seconds, + roles: set[str], + pad_id: int = -1, + add_text_bos_and_eos_in_each_turn: bool = True, +) -> torch.Tensor: + """ + Build a frame-aligned token sequence for a single cut. + + This function maps speaking turns into a token channel aligned to the + audio frame grid. Tokens are inserted at frame positions corresponding + to supervision start times, with optional BOS and EOS insertion. + + Any region not covered by text is filled with `pad_id`. + + Args: + cut (Cut): + Input cut containing audio and supervisions. + + tokenizer (TokenizerSpec): + Tokenizer used to encode text into tokens. + + frame_length (Seconds): + Duration of one frame, used to align text to the audio grid. + + roles (set[str]): + Speaker roles whose text should be included. + + pad_id (int, optional): + Token ID used for padding empty frames. + + add_text_bos_and_eos_in_each_turn (bool, optional): + Whether to insert BOS before and EOS after each supervision. + + Returns: + tokens (LongTensor [T]): + Frame-aligned token sequence for the cut. + """ + + diagnostic = f"Extra info: {cut.id=}" + if getattr(cut, "shard_origin", None) is not None: + diagnostic = f"{diagnostic} {cut.shard_origin=}" + + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + tokens = torch.ones(total, dtype=torch.long) * pad_id + for supervision in cut.supervisions: + if supervision.speaker in roles: + text = supervision.text + if add_text_bos_and_eos_in_each_turn: + text_ids = torch.as_tensor([tokenizer.bos] + tokenizer.text_to_ids(text)) + else: + text_ids = torch.as_tensor(tokenizer.text_to_ids(text)) + + # Determine the frame offset for the start of the supervision to insert the text tokens. + pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + if pos > len(tokens): + logging.warning( + f"Ill-constructed example: the beginning offset of a supervision {pos} is larger than the example's length {len(tokens)}. {diagnostic}" + ) + continue + + # Determine the frame offset for the last non-EOS text token to form a valid range for insertion; + # Note that EOS will be placed possibly much later, at the frame that coincides with end of speech, + # rather than end of text. The gap between last non-EOS token and EOS token will be filled with `pad_id`. + endpos = pos + len(text_ids) + if endpos > len(tokens): + trunc_len = len(tokens) - pos + logging.warning( + f"Truncating training example's text_ids of length {len(text_ids)} by {trunc_len} because {endpos=} > {len(tokens)=}. {diagnostic}" + ) + text_ids = text_ids[:trunc_len] + try: + tokens[pos:endpos] = text_ids + except Exception as e: + raise RuntimeError(f"{tokens.shape=} {pos=} {endpos=} {text_ids.shape=} {diagnostic}") from e + + # Insert EOS at the end of the supervision segment. 
+ if add_text_bos_and_eos_in_each_turn: + eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + if eospos < len(tokens): # skip otherwise - unfinished turn + tokens[eospos] = tokenizer.eos + + return tokens + + +def _strip_timestamps( + text: str, _TIMESTAMP_PATTERN=re.compile(r"<\|\d+\|>"), _SPACE_PATTERN=re.compile(r"\s+") +) -> str: + """ + Strips timestamp tokens from text, e.g. turns: + '<|0|> Hey <|3|> <|3|> how <|5|> <|7|> are <|8|> <|8|> <|10|> you? <|12|>' + into: + 'Hey how are you?' + """ + # Regexp pattern args are cached compiled patterns (micro-optimization). + text = _TIMESTAMP_PATTERN.sub("", text) # strip timestamp tokens if present + return _SPACE_PATTERN.sub(" ", text).strip() # strip multi-whitespaces + + +def sample_audio_segments_repeat( + prompt_audio: torch.Tensor, + prompt_audio_lens: torch.Tensor, + n_sample: int, + sample: bool = True, +) -> torch.Tensor: + """ + Extract audio segments of length n_sample. + If sample=True: randomly sample segments (repeating if shorter). + If sample=False: always take from the beginning (repeating if shorter). + + Args: + prompt_audio: Tensor [B, T] + prompt_audio_lens: Tensor [B] with valid lengths + n_sample: int, target length per segment + sample: bool, whether to randomly sample (True) or take first seconds (False) + + Returns: + Tensor [B, n_sample] + """ + B, T = prompt_audio.shape + device = prompt_audio.device + out = torch.zeros(B, n_sample, device=device, dtype=prompt_audio.dtype) + + for b in range(B): + length = min(prompt_audio_lens[b].item(), T) + + # Case: empty audio + if length <= 0: + continue + + if length >= n_sample: + if sample: + # Random start (safe bounds) + max_start = max(1, length - n_sample + 1) + start = torch.randint(0, max_start, (1,), device=device).item() + else: + # Deterministic: take from start + start = 0 + out[b] = prompt_audio[b, start : start + n_sample] + + else: + # Audio shorter than target → repeat + start = 0 + segment = prompt_audio[b, start:length] + + repeat_times = (n_sample + (length - start) - 1) // (length - start) + repeated = segment.repeat(repeat_times)[:n_sample] + out[b] = repeated + + return out diff --git a/nemo/collections/speechlm2/models/__init__.py b/nemo/collections/speechlm2/models/__init__.py index 144ea7774a6a..6fc06b6527ac 100644 --- a/nemo/collections/speechlm2/models/__init__.py +++ b/nemo/collections/speechlm2/models/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .duplex_ear_tts import DuplexEARTTS from .duplex_s2s_model import DuplexS2SModel from .duplex_s2s_speech_decoder_model import DuplexS2SSpeechDecoderModel from .salm import SALM @@ -18,5 +19,6 @@ __all__ = [ 'DuplexS2SModel', 'DuplexS2SSpeechDecoderModel', + 'DuplexEARTTS', 'SALM', ] diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py new file mode 100644 index 000000000000..a7523baccc83 --- /dev/null +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -0,0 +1,1576 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import os +import time +from collections import Counter +from contextlib import contextmanager + +import librosa +import torch +import torch.nn as nn +import torch.nn.functional as F +from lightning import LightningModule +from omegaconf import DictConfig +from peft import PeftModel +from torch.distributed.fsdp import fully_shard +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + loss_parallel, + parallelize_module, +) + +from nemo.collections.audio.parts.utils.resampling import resample +from nemo.collections.common.tokenizers import AutoTokenizer +from nemo.collections.speechlm2.data.utils import get_pad_id +from nemo.collections.speechlm2.modules.ear_tts_model import RVQEARTTSModel +from nemo.collections.speechlm2.modules.ear_tts_vae_codec import RVQVAEModel +from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin +from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.results_logger import ResultsLogger +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.speechlm2.parts.pretrained import ( + load_checkpoint, + load_pretrained_hf, + set_model_dict_for_partial_init, +) +from nemo.utils import logging + + +class DuplexEARTTS(LightningModule, HFHubMixin): + def __init__(self, cfg: dict) -> None: + assert isinstance(cfg, dict), ( + "You must pass the config to DuplexEARTTS as a Python dict to support hyperparameter serialization " + f"in PTL checkpoints (we got: '{type(cfg)=}')." 
+ ) + super().__init__() + self.save_hyperparameters() + # convert dict to config + cfg = DictConfig(cfg) + self.trainer_config = cfg.get("trainer", None) + self.data_cfg = cfg.data + self.cfg = cfg.model + self.target_sample_rate = cfg.data.target_sample_rate + self.source_sample_rate = cfg.data.source_sample_rate + self.normalize_text = cfg.data.get("normalize_text", False) + + self.validation_save_path = os.path.join(cfg.exp_manager.explicit_log_dir, "validation_logs") + + # move back text channel by x, in inference it advance the text channel prediction by x frames + self.advance_text_channel_by = self.cfg.get("advance_text_channel_by", None) + + # Load ForCausalLM + if self.cfg.tts_config.context_hidden_size is not None: + self.language_model = self._load_language_model(self.cfg) + self.embed_tokens = self._load_embed_tokens(self.cfg) + # delete llm because we use it only to get the embbeding tokens + del self.language_model + + # get codec run precision + self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "float32"), torch.float32) + + # Load tokenizer + self.tokenizer = AutoTokenizer( + self.cfg.pretrained_lm_name, + use_fast=True, + trust_remote_code=True, + bos_token=self.cfg.get("bos_token", None), + eos_token=self.cfg.get("eos_token", None), + pad_token=self.cfg.get("pad_token", None), + ) # Note that we are using fast tokenizer + + # Instantiate TTS model + self.tts_model = RVQEARTTSModel(DictConfig(self.cfg.tts_config), tokenizer=self.tokenizer) + # Load and initialize audio codec, and bind RVQ embeddings to the TTS model + setup_audio_codec(self) + + self._codebook_size = self.tts_model.config.codebook_size + + # compute samples per frame + self.source_samples_per_frame = int(self.source_sample_rate * cfg.data.frame_length) + + # get codec silence tokens + codec_silence_tokens = self.get_codec_silence_frame() + self.register_buffer("codec_silence_tokens", codec_silence_tokens) + + # cached for quicker audio decoding + self.register_buffer( + "_control_codes", + torch.tensor([self.speech_bos_id, self.speech_eos_id, self.speech_pad_id], device=self.device), + ) + + self._use_fsdp = False + self._use_tp = False + + def get_codec_silence_frame_last_one(self): + audio = torch.zeros(1, 10 * self.target_sample_rate).float().to(self.device) + audio_len = torch.tensor([audio.size(-1)]).long() + audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) + + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len) + return sil_codes[0, -1] + + def get_codec_silence_frame(self): + + # Generate long zero waveform (silence) + audio = torch.zeros(1, 10 * self.target_sample_rate).float().to(self.device) + audio_len = torch.tensor([audio.size(-1)]).long() + audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame) + + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1), audio_len) # [1, T, C] + sil_codes = sil_codes[0] # [T, C] + + # Convert each frame (C tokens) into a tuple + combos = [tuple(row.tolist()) for row in sil_codes] + + # Count frequencies + counter = Counter(combos) + + # Pick the most common combination + most_common_combo, freq = counter.most_common(1)[0] + + # Return as tensor [C] + return torch.tensor(most_common_combo, device=self.device, dtype=torch.long) + + def _load_embed_tokens(self, 
cfg) -> nn.Embedding: + """Load token embedding layer for RVQ-EAR-TTS.""" + if self.language_model: + assert callable(self.language_model.get_input_embeddings) + embed_tokens: nn.Embedding = self.language_model.get_input_embeddings() + else: + embed_tokens_state_dict = torch.load( + cfg.pretrained_lm_embedding_path, map_location="cpu", weights_only=True + ) + + # Create token embedding layer + vocab_size, hidden_size = embed_tokens_state_dict["weight"].size() + embed_tokens = nn.Embedding(vocab_size, hidden_size, dtype=torch.bfloat16) + embed_tokens.load_state_dict(embed_tokens_state_dict) + return embed_tokens + + def _load_language_model(self, cfg): + """Load language model for RVQ-EAR-TTS.""" + if cfg.pretrained_lm_name: + language_model = load_pretrained_hf( + self.cfg.pretrained_lm_name, pretrained_weights=True, trust_remote_code=True + ).eval() + else: + language_model = None + return language_model + + def restore_from_pretrained_checkpoint(self, checkpoint_path): + """ + Loads model weights a pretrained checkpoint file, supporting partial loading from safetensor and PyTorch formats. + + Args: + checkpoint_path (str): Path to checkpoint file. + + Returns: + None. The model is updated in-place. + """ + if checkpoint_path is not None: + checkpoint_state = load_checkpoint(checkpoint_path) + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict()) + + if self.cfg.get("rescale_pretrained_weights", None): + checkpoint_state = rescale_state_dict( + checkpoint_state, first_n_layers=self.cfg.get("rescale_first_n_layers", None) + ) + + self.load_state_dict(checkpoint_state, strict=True) + + @property + def device(self): + return next(self.parameters()).device + + @property + def speech_vocab_size(self): + """Return the size of the audio codec codebook including extra speech BOS and EOS tokens.""" + if self.use_local_transformer and self.local_transformer_type == "nar": # add extra token for mask + return self._codebook_size + 4 + return self._codebook_size + 3 + + @property + def speech_bos_id(self) -> int: + """Indicates start of utterance generation (not start of inference!).""" + if self.cfg.get("custom_speech_bos_id", None): + return self.cfg.get("custom_speech_bos_id") + return self._codebook_size + 2 + + @property + def speech_eos_id(self) -> int: + """Indicates end of utterance generation.""" + if self.cfg.get("custom_speech_eos_id", None): + return self.cfg.get("custom_speech_eos_id") + return self._codebook_size + 1 + + @property + def speech_pad_id(self) -> int: + """Indicates start of inference (the very first frame).""" + if self.cfg.get("custom_speech_pad_id", None): + return self.cfg.get("custom_speech_pad_id") + return self._codebook_size + + @property + def text_vocab_size(self): + """Return the size of the text tokenizer.""" + return self.tokenizer.vocab_size + + @property + def text_bos_id(self) -> int: + return self.tokenizer.bos_id + + @property + def text_eos_id(self) -> int: + return self.tokenizer.eos_id + + @property + def text_pad_id(self) -> int: + """ + Text pad ID is used as a 'blank' for frames when the model is not speaking + and for frames where the model is speaking but has already predicted the + entire text channel's content. 
+ + Example: + + flow: |---user---||-------assistant--------||-user-| + text channel: 0000000000 1xxxxxxx0000000000000002 000000 + + Where 0 indicates PAD ID, 1 indicates BOS ID, 2 indacates EOS ID, + and x indicates tokens corresponding to actual text + + """ + return get_pad_id(self.tokenizer) + + def pad_audio_to_factor(self, audio, audio_len, samples_per_frame, downsampling_factor: int = 1): + """ + Zero pad the end of the audio so that we do not have a partial end frame. + The output will be zero-padded to have an integer number of frames of + length `samples_per_frame * downsampling_factor`. + + Args: + audio: input time-domain signal (B, T) + audio_len: valid length for each example in the batch (B,) + samples_per_frame: number of samples per frame + downsampling_factor: how much each frame is downsampled in later processing + + Returns: + padded_audio: Padded time-domain signal (B, T') + padded_len: Adjusted valid lengths (B,) + """ + with fp32_precision(): + total_factor = samples_per_frame * downsampling_factor + padded_len = total_factor * torch.ceil(audio_len / total_factor).int() + max_len = padded_len.max().int().item() + num_padding = max_len - audio.shape[1] + padded_audio = F.pad(audio, (0, num_padding)) + return padded_audio, padded_len + + def prepare_inputs(self, batch: dict): + """ + Prepare inputs, extracting audio tokens and padding if needed. + """ + # check if audios has the same batch size + assert batch["source_audio"].size(0) == batch["target_audio"].size(0) + assert batch["audio_prompt"].size(0) == batch["target_audio"].size(0) + + target_audio = batch["target_audio"] + target_audio_lens = batch["target_audio_lens"] + target_text_tokens = batch["target_text_tokens"] + non_prompt_mask = batch["non_prompt_mask"] + aligned_attention_mask = batch["aligned_attention_mask"] + aligned_position_ids = batch["aligned_position_ids"] + + if self.training and (self.cfg.get("empty_turn_probability", 0.0) > 0): + # Randomly decide whether this batch gets emptied + if torch.rand(1).item() < self.cfg.empty_turn_probability: + # Zero out audio + target_audio = torch.zeros_like(target_audio) + + # Create mask for tokens we want to drop + # Keep BOS and EOS, drop the rest. 
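+                # (empty-turn augmentation: with probability `empty_turn_probability`, the target
+                #  audio is zeroed and every text token except BOS/EOS is replaced with PAD)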
+ keep_mask = (target_text_tokens == self.text_bos_id) | (target_text_tokens == self.text_eos_id) + full_dropout_mask = ~keep_mask # True = positions to replace with PAD + + # Replace all non-BOS/EOS with PAD + target_text_tokens = torch.where( + full_dropout_mask, torch.full_like(target_text_tokens, self.text_pad_id), target_text_tokens + ) + + # extract target audio codes + target_audio, target_audio_lens = self.pad_audio_to_factor( + target_audio, target_audio_lens, self.target_samples_per_frame, 1 + ) + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_lens) + + with fp32_precision(): + target_len = target_codes.shape[1] + + # Pad or truncate sequence variables + def pad_or_truncate(x, pad_value=0): + if x.dim() == 2: # [B, T] + L = x.shape[1] + if L < target_len: + return F.pad(x, (0, target_len - L), value=pad_value) + else: + return x[:, :target_len] + return x # leave others for now + + target_text_tokens = pad_or_truncate(target_text_tokens, pad_value=self.text_pad_id) + non_prompt_mask = pad_or_truncate(non_prompt_mask, pad_value=0) + aligned_position_ids = pad_or_truncate(aligned_position_ids, pad_value=0) + + # Correct attention mask padding/truncation + B, H, L1, L2 = aligned_attention_mask.shape + new_len = target_len + if L1 < new_len or L2 < new_len: + pad_rows = new_len - L1 + pad_cols = new_len - L2 + aligned_attention_mask = F.pad(aligned_attention_mask, (0, pad_cols, 0, pad_rows)) + elif L1 > new_len or L2 > new_len: + aligned_attention_mask = aligned_attention_mask[:, :, :new_len, :new_len] + + # set the pad token for the first BOS frame + target_codes_aligned = target_codes.clone() + target_codes_aligned[:, 0] = self.speech_pad_id + + # set special token in the last audio prompt (it will works as a BOS token) + pos = non_prompt_mask.float().argmax(dim=1) # shape: [B] + row_idx = torch.arange(B, device=self.device) + # set the extra self.speech_pad_id at first 1 position in non_prompt_mask + target_codes_aligned[row_idx, pos] = self.speech_pad_id + + # EOS dropout to make the model more robust + if self.training and self.cfg.get("text_eos_dropout_prob", 0.0) > 0: + # Mask EOS positions + eos_mask = target_text_tokens == self.text_eos_id + + # Random dropout only on EOS positions + dropout_mask = ( + torch.rand(eos_mask.sum(), device=target_text_tokens.device) < self.cfg.text_eos_dropout_prob + ) + + # Scatter dropout decisions into [B, T] + full_dropout_mask = torch.zeros_like(target_text_tokens, dtype=torch.bool) + full_dropout_mask[eos_mask] = dropout_mask + + # Replace dropped EOS with PAD + target_text_tokens = torch.where( + full_dropout_mask, torch.full_like(target_text_tokens, self.text_pad_id), target_text_tokens + ) + + if self.training and self.cfg.get("text_eos_duplicate_prob", 0.0) > 0: + p = self.cfg.text_eos_duplicate_prob + + # [B, T] mask of EOS positions + eos_mask = target_text_tokens == self.text_eos_id + + # Flatten EOS positions: tensor of shape [N, 2] where each row = (batch_idx, time_idx) + eos_positions = eos_mask.nonzero(as_tuple=False) # [N, 2] + + if eos_positions.numel() > 0: + N = eos_positions.shape[0] + + # One random decision per EOS occurrence + duplicate_decision = torch.rand(N, device=target_text_tokens.device) < p # [N] + + # Filter only EOS tokens that will be duplicated and are not at position t=0 + valid = (eos_positions[:, 1] > 0) & duplicate_decision # [N] + + if valid.any(): + # Select only valid EOS 
positions
+                    valid_positions = eos_positions[valid]  # [M, 2]
+
+                    # Indices for the token BEFORE the EOS (t-1)
+                    b_idx = valid_positions[:, 0]
+                    t_idx = valid_positions[:, 1] - 1
+
+                    # Replace the token before the EOS with an EOS
+                    target_text_tokens[b_idx, t_idx] = self.text_eos_id
+
+        # BOS dropout to make the model more robust
+        if self.training and self.cfg.get("text_bos_dropout_prob", 0.0) > 0:
+            # Mask BOS positions
+            bos_mask = target_text_tokens == self.text_bos_id
+
+            # Random dropout only on BOS positions
+            dropout_mask = (
+                torch.rand(bos_mask.sum(), device=target_text_tokens.device) < self.cfg.text_bos_dropout_prob
+            )
+
+            # Scatter dropout decisions into [B, T]
+            full_dropout_mask = torch.zeros_like(target_text_tokens, dtype=torch.bool)
+            full_dropout_mask[bos_mask] = dropout_mask
+
+            # Replace dropped BOS with PAD
+            target_text_tokens = torch.where(
+                full_dropout_mask,
+                torch.full_like(target_text_tokens, self.text_pad_id),
+                target_text_tokens,
+            )
+
+        # shift text tokens
+        subword_ids = F.pad(target_text_tokens[:, 1:], [0, 1])
+        # note that the text mask ignores the description + audio prompt but stays 1 until the audio ends, to support duplex
+        subword_mask = F.pad(non_prompt_mask[:, 1:], [0, 1])
+
+        # detach embeddings as in EAR-TTS
+        if self.cfg.tts_config.context_hidden_size is not None:
+            context_hidden_state = self.embed_tokens(target_text_tokens).detach()
+        else:
+            context_hidden_state = None
+
+        if self._use_tp:
+            tp_world_size = self.device_mesh["tensor_parallel"].size()
+            if (remainder := (target_text_tokens.shape[1] - 1) % tp_world_size) != 0:
+                target_text_tokens = target_text_tokens[:, :-remainder]
+                target_codes_aligned = target_codes_aligned[:, :-remainder]
+                subword_ids = subword_ids[:, :-remainder]
+                subword_mask = subword_mask[:, :-remainder]
+
+        return {
+            "code": target_codes_aligned,
+            "audio_mask": non_prompt_mask,  # set audio_mask to non_prompt_mask to exclude the audio prompt from loss computation
+            "attention_mask": aligned_attention_mask,
+            "position_ids": aligned_position_ids,
+            "subword_ids": subword_ids,
+            "subword_mask": subword_mask,
+            "context_hidden_state": context_hidden_state,
+            "output_lens": target_codes_lens,
+            "non_prompt_mask": non_prompt_mask,
+            "target_text_tokens": target_text_tokens,
+        }
+
+    def training_step(self, batch: dict, batch_idx: int):
+        for m in (self.tts_model,):
+            if is_frozen(m):
+                m.eval()
+
+        inputs = self.prepare_inputs(batch)
+
+        tts_output = self.tts_model(
+            code=inputs["code"],
+            audio_mask=inputs["audio_mask"],
+            attention_mask=inputs["attention_mask"],
+            position_ids=inputs["position_ids"],
+            context_hidden_state=inputs["context_hidden_state"],
+            subword_ids=inputs["subword_ids"],
+            subword_mask=inputs["subword_mask"],
+            non_prompt_mask=inputs["non_prompt_mask"],
+        )
+        loss_dict = {"lm_loss": tts_output.lm_loss, "c_loss": tts_output.c_loss, "k_loss": tts_output.k_loss}
+        loss = sum(loss_dict.values())
+
+        num_frames = inputs["output_lens"].sum()
+        B, T = inputs["code"].shape[:2]
+        ans = {
+            "loss": loss,
+            "learning_rate": (
+                torch.as_tensor(self.trainer.optimizers[0].param_groups[0]['lr'] if self._trainer is not None else 0)
+            ),
+            "batch_size": B,
+            "sequence_length": T,
+            "num_frames": num_frames.to(torch.float32),  # avoid warning
+            "padding_ratio": num_frames / (B * T),
+            **loss_dict,
+        }
+
+        self.log_dict(ans, on_step=True)
+        return ans
+
+    def on_train_epoch_start(self) -> None:
+        ensures_codec_target_dtype(
+            self
+        )  # 
potentially reloads the audio codec to make sure it's in target codec precision + + def on_train_epoch_end(self) -> None: + # log model stats to debug gradient weights issues + self.log_model_stats() + + def log_model_stats(self): + total_w_sq = 0.0 + total_w_params = 0 + max_abs_w = 0.0 + sum_w = 0.0 + + total_g_sq = 0.0 + total_g_params = 0 + + for p in self.parameters(): + if not p.requires_grad: + continue + + # ----- weights ----- + w = p.detach().cpu().float() # safe offline copy + total_w_sq += (w * w).sum().item() + total_w_params += w.numel() + max_abs_w = max(max_abs_w, w.abs().max().item()) + sum_w += w.sum().item() + + # ----- grads (optional, disabled for speed) ----- + if p.grad is not None: + g = p.grad.detach().cpu().float() + total_g_sq += (g * g).sum().item() + total_g_params += g.numel() + + # L2 norms + weight_l2 = (total_w_sq**0.5) if total_w_sq > 0 else 0.0 + + # RMS (global) + weight_rms = ((total_w_sq / total_w_params) ** 0.5) if total_w_params > 0 else 0.0 + + # Mean + weight_mean = sum_w / total_w_params if total_w_params > 0 else 0.0 + + # direct float logging avoids device sync penalty + self.log("weights/L2", weight_l2, on_epoch=True, sync_dist=True) + self.log("weights/RMS", weight_rms, on_epoch=True, sync_dist=True) + self.log("weights/max_abs", max_abs_w, on_epoch=True, sync_dist=True) + self.log("weights/mean", weight_mean, on_epoch=True, sync_dist=True) + + def on_validation_epoch_start(self) -> None: + ensures_codec_target_dtype( + self + ) # potentially reloads the audio codec to make sure it's in target codec precision + + self.results_logger = ResultsLogger(self.validation_save_path).reset() + self.asr_bleu = ASRBLEU(self.cfg.scoring_asr).reset() + self.intelligibility = Intelligibility(self.cfg.scoring_asr, reuse_asr_hyps=True).reset() + self.secs = SECS(self.cfg.get("scoring_se", "titanet_large")).reset() + + def on_validation_epoch_end(self, prefix="val") -> None: + asr_bleu = self.asr_bleu.compute() + for k, m in asr_bleu.items(): + self.log(f"{prefix}_{k}", m.to(self.device), on_epoch=True, sync_dist=True) + cer_wer = self.intelligibility.compute() + for k, m in cer_wer.items(): + self.log(f"{prefix}_{k}", m.to(self.device), on_epoch=True, sync_dist=True) + secs = self.secs.compute() + for k, m in secs.items(): + self.log(f"{prefix}_{k}", m.to(self.device), on_epoch=True, sync_dist=True) + + def get_teacher_force_inference_audio(self, batch, guidance_enabled=True): + inputs = self.prepare_inputs(batch) + + tts_output = self.tts_model( + code=inputs["code"], + audio_mask=inputs["audio_mask"], + attention_mask=inputs["attention_mask"], + position_ids=inputs["position_ids"], + context_hidden_state=inputs["context_hidden_state"], + subword_ids=inputs["subword_ids"], + subword_mask=inputs["subword_mask"], + non_prompt_mask=inputs["non_prompt_mask"], + generation_config=self._get_generation_config(guidance_enabled=guidance_enabled), + teacher_forcing_inference=True, + guidance_enabled=guidance_enabled, + ) + tf_audio_codes_pred = tts_output["codes"].squeeze(2) + + # decode audio + tf_audio_codes_pred = replace_control_speech_codes( + tf_audio_codes_pred, self._control_codes, self.codec_silence_tokens + ) + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + audio_pred, audio_len = self.audio_codec.decode(tf_audio_codes_pred, inputs["output_lens"]) + + return audio_pred.squeeze(1), audio_len + + def _get_generation_config(self, guidance_enabled: bool = False): + """Get default generation config for EAR-TTS.""" + return { + 
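+            # Defaults below mirror the `inference_*` fields of the model config;
+            # guidance_scale is None when classifier-free guidance is disabled.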
"num_iter": 8, + "guidance_scale": self.cfg.get("inference_guidance_scale", 0.5) if guidance_enabled else None, + "top_p_or_k": self.cfg.get("inference_top_p_or_k", 0.8), + "noise_scale": self.cfg.get("inference_noise_scale", 0.8), + "eos_threshold": -3.0, + } + + def run_evaluation_one_batch(self, name, dataset_batch, use_dataloader_init=False): + """ + Runs evaluation and scoring for a single data batch, logging metrics and updating result buffers. + + Args: + name (str): Name/id for the batch (for logging). + dataset_batch (dict): Batch of data inputs, supports batched text/audio/etc. + use_dataloader_init (bool, optional): If True, use dataloader initialization for prompts. + + Returns: + None. Outputs are logged and stored in result buffers. + """ + results = {} + inputs = self.prepare_inputs(dataset_batch) + + results["audio_tf"], results["audio_tf_len"] = self.get_teacher_force_inference_audio(dataset_batch) + if use_dataloader_init: + # cut it on prompt + init_inputs = { + "code": inputs["code"], + "audio_mask": inputs["audio_mask"], + "non_prompt_mask": inputs["non_prompt_mask"], + "context_hidden_state": inputs["context_hidden_state"], + "subword_ids": inputs["subword_ids"], + "subword_mask": inputs["subword_mask"], + } + # cut init_inputs to consider only the prompt + for key in init_inputs: + if init_inputs[key] is not None: + init_inputs[key] = torch.stack( + [init_inputs[key][i, :plen] for i, plen in enumerate(dataset_batch["prompt_lens"])] + ) + else: + # set init inputs and get it + self.set_init_inputs( + speaker_audio=dataset_batch["audio_prompt"], + speaker_audio_lens=dataset_batch["audio_prompt_lens"], + ) + init_inputs = self.get_init_inputs(B=inputs["subword_ids"].size(0)) + + # remove the prompt from the target_text_tokens to emulate S2S connected inference + next_subword_ids = torch.stack( + [ + inputs["subword_ids"][i, plen:] # slice each element + for i, plen in enumerate(dataset_batch["prompt_lens"]) + ] + ) + + results["audio"], results["audio_len"] = self.offline_inference( + next_subword_ids=next_subword_ids, + formatter=dataset_batch["formatter"][0], + init_inputs=init_inputs, + ) + + # remove prompt padding from the user audio as autoregressive inference does not return the prompt + dataset_batch["source_audio"] = dataset_batch["source_audio"][ + :, -int(next_subword_ids.size(-1) * self.source_samples_per_frame) : + ] + + # clean prompt from the audio + results["audio_tf"] = results["audio_tf"][:, -int(next_subword_ids.size(-1) * self.target_samples_per_frame) :] + # remove prompt from target audio + target_audio_no_prompt = dataset_batch["target_audio"][ + :, -int(next_subword_ids.size(-1) * self.target_samples_per_frame) : + ] + target_audio_no_prompt_lens = dataset_batch["target_audio_lens"] - ( + torch.tensor( + dataset_batch["prompt_lens"], + dtype=torch.long, + device=dataset_batch["target_audio_lens"].device, + ) + * self.target_samples_per_frame + ) + + with fp32_precision(): # resample is fragile to bfloat16 default dtype + metric_audio_pred = results["audio"] + metric_audio_pred_lens = results["audio_len"] + + # resample audio to the asr sampling rate + metric_audio_pred = resample(metric_audio_pred, self.target_sample_rate, 16000) + metric_audio_pred_lens = (metric_audio_pred_lens / self.target_sample_rate * 16000).to(torch.long) + # reshape target audio without prompt + target_audio_no_prompt_16khz = resample(target_audio_no_prompt, self.target_sample_rate, 16000) + target_audio_no_prompt_lens_16khz = (target_audio_no_prompt_lens / 
self.target_sample_rate * 16000).to( + torch.long + ) + if self.cfg.get("use_GT_transcriptions_for_metrics", True): + # use target audio transcription for metrics + target_asr_texts = self.asr_bleu.asr.transcribe( + [ + audio[:alen] + for audio, alen in zip(target_audio_no_prompt_16khz, target_audio_no_prompt_lens_16khz) + ], + batch_size=target_audio_no_prompt_16khz.shape[0], + verbose=False, + ) + metric_text = [asr_hyp.text for asr_hyp in target_asr_texts] + else: + metric_text = dataset_batch["target_texts"] + + asr_hyps = self.asr_bleu.update( + name=name, + refs=metric_text, + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + ) + + self.intelligibility.update( + name=name, + refs=metric_text, + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + asr_hyps=asr_hyps, + ) + + # add ground truth intelligibility metrics + self.intelligibility.update( + name=name + "_gt", + refs=dataset_batch["target_texts"], + pred_audio=target_audio_no_prompt_16khz, + pred_audio_lens=target_audio_no_prompt_lens_16khz, + asr_hyps=( + metric_text if self.cfg.get("use_GT_transcriptions_for_metrics", True) else None + ), # reuse GT transcription + ) + + self.secs.update( + name=name, + target_audio=resample(dataset_batch["target_audio"], self.target_sample_rate, 16000), + target_audio_lens=(dataset_batch["target_audio_lens"] / self.target_sample_rate * 16000).to( + torch.long + ), + pred_audio=resample(results["audio"], self.target_sample_rate, 16000), + pred_audio_lens=(results["audio_len"] / self.target_sample_rate * 16000).to(torch.long), + ) + + eou_labels = generate_multiturn_speaking_mask( + next_subword_ids, bos_token_id=self.text_bos_id, eos_token_id=self.text_eos_id + ) + + self.results_logger.update( + name=name, + refs=dataset_batch["target_texts"], + hyps=metric_text, + asr_hyps=asr_hyps, + samples_id=dataset_batch['sample_id'], + pred_audio=results["audio"].float(), + pred_audio_tf=results["audio_tf"].float(), + pre_audio_trimmed=None, + reference_audio=dataset_batch["audio_prompt"].float(), + target_audio=target_audio_no_prompt.float(), + pred_audio_sr=self.target_sample_rate, + user_audio=dataset_batch["source_audio"].float(), + user_audio_sr=self.source_sample_rate, + eou_pred=eou_labels, + fps=self.target_fps, + results=results if self.cfg.get("dump_tokens_text", False) else None, + tokenizer=self.tokenizer, + ) + + def validation_step(self, batch: dict, batch_idx: int): + for name, dataset_batch in batch.items(): + if dataset_batch is None: + continue # some dataset is exhausted + + B = len(dataset_batch['sample_id']) + + # run inference for a custom speaker reference + if self.cfg.get("inference_speaker_reference", None): + new_dataset_batch = copy.deepcopy(dataset_batch) + speaker_audio, sr = load_audio_librosa(self.cfg.inference_speaker_reference) + speaker_audio = resample(speaker_audio, sr, self.target_sample_rate) + speaker_audio = speaker_audio.repeat(B, 1).to(self.device) + # lengths -> [B] + speaker_audio_lens = torch.tensor([speaker_audio.size(1)], device=self.device).long().repeat(B) + new_dataset_batch["audio_prompt"] = speaker_audio + new_dataset_batch["audio_prompt_lens"] = speaker_audio_lens + self.run_evaluation_one_batch(name, new_dataset_batch) + + # run inference using dataloader speaker references + else: + self.run_evaluation_one_batch(name, dataset_batch, use_dataloader_init=False) + + def on_test_epoch_start(self) -> None: + return self.on_validation_epoch_start() + + def on_test_epoch_end(self) -> None: + return 
self.on_validation_epoch_end(prefix="test") + + def test_step(self, *args, **kwargs): + return self.validation_step(*args, **kwargs) + + def set_init_inputs(self, speaker_audio, speaker_audio_lens, system_prompt=None, user_prompt=None): + """ + Registers and prepares initial input buffers for text/audio prompt and context, to warm up AR inference. + + Args: + speaker_audio (torch.Tensor): Batch of prompt audio, (B, T). + speaker_audio_lens (torch.Tensor): Lengths for each sample in speaker_audio, (B,). + system_prompt (str, optional): System prompt for context. + user_prompt (str, optional): User message for context. + + Returns: + dict: Dictionary of input tensors to be passed to inference, with registered buffers. + """ + # compute prompt audio size and slice it + with fp32_precision(): + # compute the exact number of samples for the prompt duration + prompt_audio_size = int( + ((self.data_cfg.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) + * self.target_samples_per_frame + ) + + B, T = speaker_audio.shape + device = speaker_audio.device + dtype = speaker_audio.dtype + + # allocate result + prompt_audio = torch.zeros(B, prompt_audio_size, device=device, dtype=dtype) + + # process each example independently + for b in range(B): + valid_len = min(speaker_audio_lens[b].item(), T) + + # handle empty + if valid_len <= 0: + continue + + # valid (non-padded) segment + valid_segment = speaker_audio[b, :valid_len] + + if valid_len >= prompt_audio_size: + # enough valid audio → crop from start (no silence) + prompt_audio[b] = valid_segment[:prompt_audio_size] + else: + # too short → repeat and crop + repeat_factor = (prompt_audio_size + valid_len - 1) // valid_len # ceil division + expanded = valid_segment.repeat(repeat_factor) + prompt_audio[b] = expanded[:prompt_audio_size] + + # add a silence in the end to smooth the transition between prompt and audio tokens + prompt_audio[:, -int(self.target_samples_per_frame * 2) :] = 0 + + # get prompt audio size + with fp32_precision(): + prompt_audio_text_pad_size = int(prompt_audio_size // self.target_samples_per_frame) + + # create a eos token id + first_text_frame = torch.tensor([self.tokenizer.eos], dtype=torch.long, device=self.device) + + # create a padding tensor + prompt_audio_text_pad = ( + torch.ones(prompt_audio_text_pad_size, device=self.device, dtype=first_text_frame.dtype) * self.text_pad_id + ) + prompt_audio_text_pad[-1] = self.tokenizer.eos + + # Prepend an initial text EOS token followed by padding tokens that match + # the number of audio-prompt frames (in text-token units). 
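+        # The resulting text channel therefore looks like (one entry per codec frame):
+        #   [EOS, PAD, PAD, ..., PAD, EOS]   with length 1 + number of audio-prompt frames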
+        target_text_tokens = torch.cat([first_text_frame, prompt_audio_text_pad.to(first_text_frame.dtype)])
+
+        # create pad audio for the description
+        pad_size = first_text_frame.size(-1) * self.target_samples_per_frame
+        pad_audio = (
+            torch.zeros(pad_size, device=prompt_audio.device, dtype=prompt_audio.dtype)
+            .unsqueeze(0)
+            .repeat(prompt_audio.size(0), 1)
+        )
+
+        # repeat to reach the batch size
+        target_text_tokens = target_text_tokens.unsqueeze(0).repeat(prompt_audio.size(0), 1)
+        target_audio = torch.cat([pad_audio, prompt_audio], dim=1)
+
+        # extract codec codes
+        target_audio_len = torch.tensor(
+            [target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device
+        )
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
+            code, _ = self.audio_codec.encode(target_audio.unsqueeze(1), target_audio_len)
+
+        # get context hidden state
+        if self.cfg.tts_config.context_hidden_size is not None:
+            context_hidden_state = self.embed_tokens(target_text_tokens)
+        else:
+            context_hidden_state = None
+
+        # create masks
+        # non_prompt_mask is all zeros, because everything processed so far is prompt
+        non_prompt_mask = torch.zeros_like(target_text_tokens)
+        non_prompt_mask[:, -2:] = 1  # set last valid prompt frame to 1 to allow the addition of BOS in the right place
+        subword_mask = torch.zeros_like(
+            target_text_tokens
+        )  # subword_mask is almost all zeros because during warmup there is only the prompt
+        subword_mask[:, -3:] = (
+            1  # -3 because it starts right after the first valid prompt token and is shifted by 1
+        )
+
+        # set the pad token for the first BOS frame
+        code[:, 0] = self.speech_pad_id
+
+        # shift subword_ids
+        subword_ids = F.pad(target_text_tokens[:, 1:], [0, 1], value=0.0)
+
+        # set special token in the last audio prompt frame (it will work as a BOS token)
+        pos = non_prompt_mask.float().argmax(dim=1)  # shape: [B]
+        row_idx = torch.arange(B, device=self.device)
+        # set the extra self.speech_pad_id at the first 1 position in non_prompt_mask
+        code[row_idx, pos] = self.speech_pad_id
+
+        init_inputs = {
+            "code": code[:, :-1],
+            "audio_mask": non_prompt_mask.bool()[
+                :, :-1
+            ],  # set audio_mask to non_prompt_mask to exclude the audio prompt from loss computation
+            "context_hidden_state": context_hidden_state[:, :-1] if context_hidden_state is not None else None,
+            "subword_ids": subword_ids[:, :-1],
+            "subword_mask": subword_mask.bool()[:, :-1],
+            "non_prompt_mask": non_prompt_mask.bool()[:, :-1],
+        }
+        # register to access later
+        for k, v in init_inputs.items():
+            name = f"init_input_{k}"
+            if v is not None:
+                self.register_buffer(name, v)
+
+        return init_inputs
+
+    def get_init_inputs(
+        self,
+        B: int,
+        init_inputs_names=[
+            "code",
+            "audio_mask",
+            "context_hidden_state",
+            "subword_ids",
+            "subword_mask",
+            "non_prompt_mask",
+        ],
+    ):
+        """
+        Returns a dictionary of initial inputs for inference, using registered buffers.
+
+        Args:
+            B (int): Required batch size.
+            init_inputs_names (List[str], optional): Names of input buffers to fetch.
+
+        Returns:
+            dict: Each key is a name from init_inputs_names, and each value is a tensor of shape (B, ...).
+
+        Notes:
+            Expands batch-1 buffers to B if necessary.
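+
+        Example (illustrative sketch; assumes ``set_init_inputs()`` was called beforehand to
+        register the ``init_input_*`` buffers, and ``model`` is a DuplexEARTTS instance):
+
+            init_inputs = model.get_init_inputs(B=4)
+            assert init_inputs["code"].shape[0] == 4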
+ """ + if init_inputs_names is None: + init_inputs_names = [ + "code", + "audio_mask", + "context_hidden_state", + "subword_ids", + "subword_mask", + "non_prompt_mask", + ] + + init_inputs = {} + for name in init_inputs_names: + buf_name = f"init_input_{name}" + buf = getattr(self, buf_name, None) + + if buf is None: + init_inputs[name] = None + continue + + # Use as-is if batch matches + if buf.shape[0] == B: + init_inputs[name] = buf + else: + # Otherwise, assume batch=1 and expand to target B + init_inputs[name] = buf[:1].expand(B, *buf.shape[1:]) + + return init_inputs + + @torch.no_grad() + def infer_codes_one_step( + self, + current_subword_id, + prev_subword_id, + current_subword_mask, + prev_audio_tokens, + past_key_values, + guidance_enabled=True, + generation_config=None, + ignore_eos_flag_stop=True, + ): + """ + Runs a single autoregressive prediction step to infer audio codec codes. + + Args: + current_subword_id (torch.Tensor): Current text token IDs, shape (B, 1). + prev_subword_id (torch.Tensor): Previous text token IDs, shape (B, 1). + current_subword_mask (torch.Tensor): Current mask, shape (B, 1). + prev_audio_tokens (torch.Tensor): Previously generated audio tokens, shape (B, 1, C). + past_key_values: Key-value cache for transformer decoder state. + guidance_enabled (bool, optional): Enables classifier-free guidance. + generation_config (dict, optional): Generation hyperparameters. + ignore_eos_flag_stop (bool): If True, ignore EOS flag for stopping. + + Returns: + Tuple[torch.Tensor, Any]: + - Predicted audio codec token(s), shape (B, 1, C) + - Updated past_key_values for the next step. + """ + + if self.cfg.tts_config.context_hidden_size is not None: + # get context_hidden_state it is always one step behind current_subword_id + # for the first step uses the last step from warmup + context_hidden_state = self.embed_tokens(prev_subword_id) + else: + context_hidden_state = None + + # force silence as next token + if self.cfg.get('inference_force_speech_silence_on_eos', True): + silence_codes = self.codec_silence_tokens.view(1, 1, -1).expand(prev_audio_tokens.shape) + prev_audio_tokens = torch.where( + current_subword_id.unsqueeze(-1) == self.text_eos_id, + silence_codes, # silence + prev_audio_tokens, # keep original + ) + + # get subword_ids + inputs = { + "code": prev_audio_tokens, + "context_hidden_state": context_hidden_state, + "subword_ids": current_subword_id, + "subword_mask": current_subword_mask, + "past_key_values": past_key_values, + "use_cache": True, + "guidance_enabled": guidance_enabled, + "generation_config": generation_config, + "ignore_eos_flag_stop": ignore_eos_flag_stop, + } + + outputs = self.tts_model(**inputs) + + return outputs["codes"], outputs["past_key_values"] + + @torch.no_grad() + def decode_one_audio_step(self, gen_audio_codes_history, number_prev_tokens=None): + """ + Decodes one step of generated audio codec tokens to raw waveform. + + Args: + gen_audio_codes_history (torch.Tensor): Audio tokens history, shape (B, T, C). + number_prev_tokens (int, optional): Number of previous tokens to decode, for incremental decoding. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - audio_pred_cur_step: Latest decoded waveform chunk, shape (B, wav_to_token_ratio). + - audio_len: Lengths (number of samples), shape (B,). 
+ """ + with fp32_precision(), torch.no_grad(): + if number_prev_tokens: + gen_audio_codes_history = gen_audio_codes_history[:, -number_prev_tokens:] + + gen_audio_codes_history = replace_control_speech_codes( + gen_audio_codes_history, self._control_codes, self.codec_silence_tokens + ) + gen_audio_codes_lens = torch.tensor( + [gen_audio_codes_history.size(1)] * gen_audio_codes_history.size(0), device=self.device + ) + audio_pred, audio_len = self.audio_codec.decode(gen_audio_codes_history, gen_audio_codes_lens) + + # return only the current/lastest audio chunk + audio_pred_cur_step = audio_pred.squeeze(1)[:, -self.audio_codec.config.wav_to_token_ratio :] + audio_len[:] = self.audio_codec.config.wav_to_token_ratio + return audio_pred_cur_step, audio_len + + @torch.no_grad() + def offline_inference( + self, + next_subword_ids: torch.Tensor, + init_inputs: dict, + formatter: str = "", + guidance_enabled: bool = True, + generation_config: dict = None, + incremental_audio_decoding: bool = False, + ) -> dict[str, torch.Tensor]: + """ + Runs offline autoregressive inference for the Duplex EAR-TTS speech decoder. + + This method performs **text-to-speech (TTS)** generation: given subword/text + tokens and prompt-initialization states, it autoregressively generates + audio codec tokens and decodes them into a waveform. + + Args: + next_subword_ids (torch.Tensor): + Conditioning subword/text token IDs for the speech decoder. + Shape: (B, T_text). + + init_inputs (dict): + Dictionary of prompt-dependent initial states produced by + ``get_init_inputs()``. May include: + + • "code" — initial audio tokens (e.g., prompt audio) + • "audio_mask" — mask for prompt audio positions + • "context_hidden_state" — decoder hidden state at t = 0 + • "subword_ids" — prompt text tokens + • "subword_mask" — mask for prompt text + • "non_prompt_mask" — mask marking positions to be generated + + ``get_init_inputs()`` automatically expands batch-1 buffers to + batch size B. + + formatter (str, optional): + Optional formatter identifier used to customize the prompt structure. + + guidance_enabled (bool, optional): + Whether classifier-free guidance (CFG) is enabled. + If enabled and ``generation_config`` is ``None``, guidance parameters + are taken from ``_get_generation_config()``. + + generation_config (dict, optional): + Settings controlling autoregressive generation, including sampling + strategy, noise scale, refinement iterations, and EOS rules. + If ``None``, defaults are taken from + ``_get_generation_config(guidance_enabled)``. + + incremental_audio_decoding (bool, optional): + If True, codec-to-waveform decoding is performed incrementally during + autoregressive generation. + If False, waveform decoding occurs only after all audio tokens are produced. + + Returns: + dict[str, torch.Tensor]: + Contains: + + • **"audio"**: + Generated waveform of shape ``(B, T_audio)``, obtained via + ``audio_pred.squeeze(1)``. + + • **"audio_len"**: + Length of each generated waveform in samples, shape ``(B,)``. 
+ """ + B = next_subword_ids.size(0) + + if generation_config is None: + generation_config = self._get_generation_config(guidance_enabled) + logging.info(f"Doing inference using the following config: {generation_config} !") + + init_inputs.update({"use_cache": True, "past_key_values": None, "guidance_enabled": guidance_enabled}) + + # warmup the model and generate the very first audio token + outputs = self.tts_model(**init_inputs) + + if self.cfg.get("inference_skip_first_code_prediction_on_init", True): + # use the last token on init, because we are shifthing it in the model forward, so we dont really need to compute it + code = init_inputs["code"][:, -1:] + else: + code, _, _ = self.tts_model.generate_step(outputs.hidden_states[:, -1:], **generation_config) + + past_key_values = outputs["past_key_values"] + + # use the text tokens to stop generation + max_steps = next_subword_ids.size(-1) + # create variable to store the audios + gen_audio_codes = torch.zeros( + B, max_steps, self.tts_model.config.num_quantizers, device=self.device, dtype=torch.long + ) + + # init subwork as all ones + subword_mask = torch.ones(B, max_steps, device=self.device, dtype=torch.bool) + # get first context subword_id, that is the last subword_ids from the warmup + first_context_subword_id = init_inputs["subword_ids"][:, -1].unsqueeze(-1) + + # initialize variables used to save the output audio + audio_pred = None + audio_pred_len = torch.zeros(B, device=self.device, dtype=torch.long) + + for i in range(max_steps): + step_start = time.time() + # current subword id is always seem + current_subword_id = next_subword_ids[:, i].unsqueeze(-1) + + if i == 0: + prev_subword_id = first_context_subword_id + else: + prev_subword_id = next_subword_ids[:, i - 1].unsqueeze(-1) + + # create subword_mask + current_subword_mask = subword_mask[:, i].unsqueeze(-1) + + code, past_key_values = self.infer_codes_one_step( + current_subword_id=current_subword_id, + prev_subword_id=prev_subword_id, + current_subword_mask=current_subword_mask, + prev_audio_tokens=code, + past_key_values=past_key_values, + guidance_enabled=guidance_enabled, + generation_config=generation_config, + ignore_eos_flag_stop=True, + ) + + # cache audio tokens + gen_audio_codes[:, i] = code.squeeze(1) + + if incremental_audio_decoding: + audio_pred_i, audio_pred_i_len = self.decode_one_audio_step( + gen_audio_codes[:, : i + 1], + number_prev_tokens=self.cfg.get("inference_codec_decoding_prev_tokens_number", None), + ) + if audio_pred is None: + audio_pred = audio_pred_i + else: + audio_pred = torch.cat([audio_pred, audio_pred_i], dim=1) + audio_pred_len += audio_pred_i_len + + step_time = time.time() - step_start + logging.info(f"Autoregressive inference step: {i} of {max_steps} take around {step_time}s") + + if not incremental_audio_decoding: + gen_audio_codes_lens = torch.tensor([gen_audio_codes.shape[1]] * gen_audio_codes.shape[0]).to(self.device) + # decode audio. 
Note that it is not necessary because the prompt is removed, so no special token should be on the output, but lets do it for safety + gen_audio_codes = replace_control_speech_codes( + gen_audio_codes, self._control_codes, self.codec_silence_tokens + ) + with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad(): + audio_pred, audio_pred_len = self.audio_codec.decode(gen_audio_codes, gen_audio_codes_lens) + + return audio_pred.squeeze(1), audio_pred_len + + def backward(self, *args, **kwargs): + with loss_parallel(): + super().backward(*args, **kwargs) + + def configure_optimizers(self): + return configure_optimizers(self) + + @property + def oomptimizer_schema(self) -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. + """ + raise NotImplementedError + + def configure_model(self) -> None: + # TODO(pzelasko): refactor into separate module re-usable across models + device_mesh = self.device_mesh + if device_mesh is None: + return + + llm = self.tts_model.backbone + if isinstance(llm, PeftModel): + llm = llm.base_model.model + + if (tp_mesh := device_mesh["tensor_parallel"]).size() > 1: + self._use_tp = True + + plan = { + "layers.0": PrepareModuleInput( + input_layouts=(Replicate(),), # , None) + desired_input_layouts=(Shard(1),), # , None) + use_local_output=True, + ), + "norm": SequenceParallel(), + } + parallelize_module(llm, tp_mesh, plan) + + for transformer_block in llm.layers: + plan = { + "input_layernorm": SequenceParallel(), + "self_attn.q_proj": ColwiseParallel(), + "self_attn.k_proj": ColwiseParallel(), + "self_attn.v_proj": ColwiseParallel(), + "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), + "post_attention_layernorm": SequenceParallel(), + "mlp": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "mlp.gate_proj": ColwiseParallel(), + "mlp.up_proj": ColwiseParallel(), + "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), + # "pre_feedforward_layernorm": SequenceParallel(), + # "post_feedforward_layernorm": SequenceParallel(), + } + + # Adjust attention module to use the local number of heads + attn_layer = transformer_block.self_attn + for attr in ("num_heads", "num_key_value_heads", "hidden_size"): + val = getattr(attn_layer, attr) + if val % tp_mesh.size() != 0: + logging.warning( + f"attn_layer.{attr}={val} is not divisible by {tp_mesh.size()=}: " + f"set a different tensor parallelism size to avoid errors." 
+ ) + setattr(attn_layer, attr, val // tp_mesh.size()) + + parallelize_module(transformer_block, tp_mesh, plan) + + for m in ( + self.tts_model.mog_head, + self.tts_model.embed_subword, + self.tts_model.embed_context, + self.tts_model.embed_code, + self.tts_model.null_emb, + self.tts_model.bos_emb, + self.tts_model.lm_head, + ): + parallelize_module( + m, + tp_mesh, + ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1), + use_local_output=False, + ), + ) + + if (dp_mesh := device_mesh["data_parallel"]).size() > 1: + assert dp_mesh.ndim == 1 + self._use_fsdp = True + + fsdp_config = {"mesh": dp_mesh} + + for idx, layer in enumerate(llm.layers): + llm.layers[idx] = fully_shard(layer, **fsdp_config) + + for idx in range(self.tts_model._num_codebooks): + self.tts_model.audio_embeddings[idx] = fully_shard(self.tts_model.audio_embeddings[idx], **fsdp_config) + + if self.tts_model.use_local_transformer: + self.tts_model.local_transformer = fully_shard(self.tts_model.local_transformer, **fsdp_config) + self.tts_model.local_transformer_in_projection = fully_shard( + self.tts_model.local_transformer_in_projection, **fsdp_config + ) + + self.embed_text_tokens = fully_shard(self.embed_text_tokens, **fsdp_config) + self.tts_model.mog_head = fully_shard(self.tts_model.mog_head, **fsdp_config) + self.tts_model.embed_subword = fully_shard(self.tts_model.embed_subword, **fsdp_config) + self.tts_model.embed_context = fully_shard(self.tts_model.embed_context, **fsdp_config) + self.tts_model.embed_code = fully_shard(self.tts_model.embed_code, **fsdp_config) + self.tts_model.lm_head = fully_shard(self.tts_model.lm_head, **fsdp_config) + self.llm = fully_shard(self.llm, **fsdp_config) + self.tts_model = fully_shard(self.tts_model, **fsdp_config) + + def load_state_dict(self, state_dict, strict: bool = True): + try: + return super().load_state_dict(state_dict, strict=strict) + except RuntimeError: + logging.info("Error loading model state_dict !! Retrying with partial initialization!") + model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) + return super().load_state_dict(model_dict, strict=False) + + +def load_audio_librosa(path, sr=None): + """ + Load audio using librosa with torchaudio-like behavior. + + Returns: + audio_tensor: torch.FloatTensor of shape [channels, time] + sr: sampling rate + """ + # Load with librosa (preserve original sampling rate) + audio, sr = librosa.load(path, sr=sr, mono=False) + + # Ensure shape is [channels, time] + if audio.ndim == 1: + # Mono: (time,) -> (1, time) + audio = audio[None, :] + + # Convert to torch float32 (torchaudio behavior) + audio_tensor = torch.from_numpy(audio).float() + return audio_tensor, sr + + +def maybe_to(x, dtype): + if x is None: + return None + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + +@contextmanager +def ensures_target_precision(target_dtype): + """ + Workaround for precision related issues when training with bf16-true PyTorch Lightning precision setting. + In bf16-true, PTL changes PyTorch's default dtype, which may break implicit assumptions for some models. + This context manager restores default float32 precision and runs the computation in float32 autocast context. 
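+
+    Example (illustrative; matches how the audio codec is invoked inside DuplexEARTTS):
+
+        with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
+            codes, code_lens = self.audio_codec.encode(audio.unsqueeze(1), audio_len)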
+ """ + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(target_dtype) + try: + with torch.amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=target_dtype): + yield + finally: + torch.set_default_dtype(default_dtype) + + +def generate_multiturn_speaking_mask(input_ids: torch.Tensor, bos_token_id: int = 0, eos_token_id: int = 1): + """ + Efficient, batched speaking mask generator that marks 1 between and pairs. + If is missing after a , mask continues to end. Handles multiple turns. + + Args: + input_ids (torch.Tensor): LongTensor of shape (B, T) + bos_token_id (int): Token ID for + eos_token_id (int): Token ID for + + Returns: + torch.Tensor: FloatTensor of shape (B, T), with 1.0 for speaking, 0.0 for silence. + + Note BOS is considered as speaking (1) and EOS as non speaking 0 + """ + device = input_ids.device + bos_mask = (input_ids == bos_token_id).to(torch.int32).to(device) + eos_mask = (input_ids == eos_token_id).to(torch.int32).to(device) + bos_cumsum = torch.cumsum(bos_mask, dim=1) + eos_cumsum = torch.cumsum(eos_mask, dim=1) + speaking_mask = (bos_cumsum > eos_cumsum).to(torch.float32) + return speaking_mask.long() + + +def replace_control_speech_codes( + speech_codes: torch.Tensor, control_codes: torch.Tensor, silence_tokens: torch.Tensor = None +) -> torch.Tensor: + """ + Replaces control codes (speech BOS, EOS, etc) in `speech_codes` with the first frame which is + assumed to consist of 'valid' codes representing silence. + """ + if silence_tokens is not None: + # Expand to [B, 1, 74] + silence_tokens_expanded = silence_tokens.unsqueeze(0).unsqueeze(1).expand(speech_codes.shape[0], 1, -1) + return torch.where(torch.isin(speech_codes, control_codes), silence_tokens_expanded, speech_codes) + + if torch.isin(speech_codes[:, :1], control_codes).any(): + return torch.where( + torch.isin(speech_codes, control_codes), torch.zeros_like(speech_codes[:, :1]), speech_codes + ) + else: + return torch.where(torch.isin(speech_codes, control_codes), speech_codes[:, :1], speech_codes) + + +def ensures_codec_target_dtype(model): + """ + Ensures the audio codec is instantiated with the target dtype. + + This function checks whether `model.audio_codec` exists and whether its + parameters match `model.audio_codec_run_dtype`. If the codec is missing + or is running with the wrong dtype (e.g., due to PTL auto-downcasting), + the codec is reloaded by calling `setup_audio_codec()`. + + Intended to be called at runtime boundaries such as: + - `on_train_epoch_start` + - `on_validation_epoch_start` + + Args: + model: Model instance of DuplexEARTTS + + """ + if hasattr(model, "audio_codec") and next(model.audio_codec.parameters()).dtype == model.audio_codec_run_dtype: + return # already correct precision → no-op + + setup_audio_codec(model) + + +def setup_audio_codec(model): + """ + Instantiates the RVQ audio codec and injects codec embeddings into the TTS model. + + This function is responsible only for: + - Instantiating the codec model (`RVQVAEModel`). + - Loading pretrained codec weights (if configured). + - Freezing codec parameters. + - Registering RVQ embeddings inside the TTS model via `set_rvq_embs`. 
+ + Args: + model: Model instance of DuplexEARTTS + """ + with ensures_target_precision(model.audio_codec_run_dtype): + model.audio_codec = RVQVAEModel(DictConfig(model.cfg.codec_config)) + # load pretrained codec checkpoint + if model.cfg.get("pretrained_codec_model", None): + checkpoint_state = load_checkpoint(model.cfg.pretrained_codec_model) + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, model.audio_codec.state_dict()) + model.audio_codec.load_state_dict(checkpoint_state, strict=True) + + for p in model.audio_codec.parameters(): + p.requires_grad = False + + assert callable(model.tts_model.set_rvq_embs) + + model.tts_model.set_rvq_embs(torch.stack([x.detach() for x in model.audio_codec.prvq.mus_list], 0)) + model.tts_model.rvq_embs = model.tts_model.rvq_embs.to(next(model.tts_model.parameters()).dtype) + # compute target fps + model.target_fps = model.target_sample_rate / model.audio_codec.config.wav_to_token_ratio + model.target_samples_per_frame = model.audio_codec.config.wav_to_token_ratio + + +def rescale_state_dict(state_dict, target_std=0.02, first_n_layers=None, layer_prefix="tts_model.backbone.layers."): + """ + Rescale trainable weights in a state_dict for BF16/FP16 stability. + + Args: + state_dict: PyTorch state_dict + target_std: desired target std for weights + first_n_layers: if not None, rescale only the first N transformer blocks + layer_prefix: prefix for layer names (default: "tts_model.backbone.layers.") + Returns: + new_state_dict + """ + weight_tensors = [] + + # Compute which prefixes to match if first_n_layers is set + prefixes_to_match = [] + if first_n_layers is not None: + prefixes_to_match = [f"{layer_prefix}{i}" for i in range(first_n_layers)] + + for name, param in state_dict.items(): + if not torch.is_tensor(param): + continue + + if "rvq_embs" in name: + continue + + # Skip biases & 1-dim params (norm weights/gates) + if param.ndim <= 1: + continue + + # Skip layers not in the first N + if first_n_layers is not None and not any(name.startswith(pfx) for pfx in prefixes_to_match): + continue + + weight_tensors.append(param.float()) + + if not weight_tensors: + if first_n_layers is not None: + logging.info(f"No weights found for first {first_n_layers} layers with prefix '{layer_prefix}'.") + else: + logging.info("No weights found to rescale in state_dict.") + return state_dict + + # Compute global std across selected weights (on CPU) + cpu_weights = [p.detach().cpu() for p in weight_tensors] + flat = torch.cat([p.flatten() for p in cpu_weights]) + current_std = float(torch.std(flat)) + scale = target_std / (current_std + 1e-8) + + logging.info( + f"Rescaling state_dict " + f"{'(first N layers)' if first_n_layers else '(all layers)'}: " + f"current std = {current_std:.6f}, target = {target_std}, scale = {scale:.6f}" + ) + + # Apply scaling + new_state_dict = {} + for name, param in state_dict.items(): + if ( + torch.is_tensor(param) + and param.ndim > 1 + and (first_n_layers is None or any(name.startswith(pfx) for pfx in prefixes_to_match)) + ): + new_state_dict[name] = param * scale + else: + new_state_dict[name] = param + + logging.info("Done: weights rescaled.") + return new_state_dict diff --git a/nemo/collections/speechlm2/modules/ear_tts_model.py b/nemo/collections/speechlm2/modules/ear_tts_model.py new file mode 100644 index 000000000000..f38e77fcb832 --- /dev/null +++ b/nemo/collections/speechlm2/modules/ear_tts_model.py @@ -0,0 +1,1615 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math +import os +from dataclasses import dataclass, fields +from typing import Any + +import torch +import transformers +from omegaconf import DictConfig, OmegaConf +from torch import Tensor, nn +from torch.nn import functional as F +from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, Cache +from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper + +from nemo.collections.common.tokenizers import AutoTokenizer +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init +from nemo.utils import logging + +# ============================================================================== +# MLP module and Norm +# ============================================================================== + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = nn.GELU(approximate="tanh") + + def forward(self, x: Tensor) -> Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MLPLayer(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + eps: float = 1e-6, + ): + super().__init__() + self.pre_norm = RMSNorm(hidden_size, eps=eps) + self.mlp = MLP(hidden_size, intermediate_size) + self.post_norm = RMSNorm(hidden_size, eps=eps) + + def forward(self, x: Tensor) -> Tensor: + y = self.pre_norm(x) + y = self.mlp(y) + y = self.post_norm(y) + x = x + y + return x + + +# ============================================================================== +# Triton-accelerated and Fallback Functions +# ============================================================================== + +TRITON_IMPORTED = False +try: + import triton + import triton.language as tl + + TRITON_IMPORTED = True +except ImportError: + TRITON_IMPORTED = False + +USE_TRITON = TRITON_IMPORTED and torch.cuda.is_available() + +if USE_TRITON: + logging.info("Triton available & CUDA detected. 
Using Triton kernel for batch_matmul.") + + @triton.jit + def batch_matmul_kernel( + x_ptr, + w_ptr, + y_ptr, + result_ptr, + b, + d_in, + d_out, + n, + BLOCK_SIZE_DIN: tl.constexpr, + BLOCK_SIZE_DOUT: tl.constexpr, + ): + batch_id = tl.program_id(axis=0) + dout_block_id = tl.program_id(axis=1) + + if batch_id >= b: + return + + idx = tl.load(y_ptr + batch_id) + + x_offset = x_ptr + batch_id * d_in + w_offset = w_ptr + idx * d_out * d_in + + dout_offsets = dout_block_id * BLOCK_SIZE_DOUT + tl.arange(0, BLOCK_SIZE_DOUT) + dout_mask = dout_offsets < d_out + + result_block = tl.zeros([BLOCK_SIZE_DOUT], dtype=tl.float32) + + for din_start in range(0, d_in, BLOCK_SIZE_DIN): + din_offsets = din_start + tl.arange(0, BLOCK_SIZE_DIN) + din_mask = din_offsets < d_in + + x_i = tl.load(x_offset + din_offsets, mask=din_mask, other=0.0) + + w_i_block = tl.load( + w_offset + dout_offsets[:, None] * d_in + din_offsets[None, :], + mask=(dout_mask[:, None] & din_mask[None, :]), + other=0.0, + ) + + result_block += tl.sum(w_i_block * x_i[None, :], axis=1) + + result_offset = result_ptr + batch_id * d_out + dout_offsets + tl.store(result_offset, result_block, mask=dout_mask) + + def batch_matmul_triton(x, w, y, BLOCK_SIZE_DIN: int = 16, BLOCK_SIZE_DOUT: int = 64): + assert x.is_contiguous() and w.is_contiguous() and y.is_contiguous() + + b, d_in = x.shape + n, d_out, _ = w.shape + result = torch.empty(b, d_out, device=x.device, dtype=torch.float32) + + batch_matmul_kernel[lambda meta: (b, triton.cdiv(d_out, meta["BLOCK_SIZE_DOUT"]))]( + x.float(), + w.float(), + y, + result, + b, + d_in, + d_out, + n, + BLOCK_SIZE_DIN=BLOCK_SIZE_DIN, + BLOCK_SIZE_DOUT=BLOCK_SIZE_DOUT, + ) + + return result.to(dtype=x.dtype) + + batch_matmul = batch_matmul_triton + +else: + logging.info("Using PyTorch fallback (Triton unavailable or no CUDA).") + + # Fallback to PyTorch implementation if Triton is not available + def batch_matmul_pytorch(x: Tensor, w: Tensor, y: Tensor, *args, **kwargs) -> Tensor: + """ + Performs a batched matrix multiplication using PyTorch's native functions. + + This function serves as a fallback when Triton is not available. It achieves + the same result by gathering the appropriate weight matrices and using `torch.bmm`. + + Args: + x (Tensor): The input tensor of shape `[batch_size, d_in]`. + w (Tensor): The weight tensor of shape `[num_weights, d_out, d_in]`. + y (Tensor): The index tensor of shape `[batch_size]`. + + Returns: + Tensor: The result of the multiplication, shape `[batch_size, d_out]`. + """ + # w[y] gathers the weight matrices for each item in the batch. + # x.unsqueeze(2) reshapes x to [batch_size, d_in, 1] for bmm. + # The result is squeezed to remove the trailing dimension of size 1. + return torch.bmm(w[y], x.unsqueeze(2)).squeeze(2) + + batch_matmul = batch_matmul_pytorch + + +# ============================================================================== +# Core Mathematical and Masking Functions +# ============================================================================== + + +def gumbel_like(tensor: Tensor, eps: float = 1e-8) -> Tensor: + """ + Generates a tensor of Gumbel noise with the same shape as the input tensor. + + This is used for the Gumbel-Max trick, a technique to sample from a categorical + distribution in a differentiable way (using a straight-through estimator). + + Args: + tensor (torch.Tensor): The input tensor to match the shape of. + eps (float): A small epsilon value for numerical stability. + + Returns: + torch.Tensor: A tensor containing Gumbel noise. 
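+
+    Example:
+        >>> g = gumbel_like(torch.zeros(4, 8))
+        >>> g.shape
+        torch.Size([4, 8])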
+ """ + # Sample from a uniform distribution + u = torch.rand_like(tensor) + # Apply the inverse CDF of the Gumbel distribution + return -torch.log(-torch.log(u + eps) + eps) + + +def sequence_mask(lengths: Tensor, max_length: Tensor | int | None = None) -> Tensor: + """ + Creates a boolean mask from a 1D tensor of sequence lengths. + + This function is useful for masking out padding in sequences. Given a tensor + of lengths, it produces a 2D boolean tensor where `mask[i, j]` is `True` if + `j < lengths[i]` and `False` otherwise. + + Args: + lengths (Long Tensor): A 1D tensor of integer lengths. Shape: `[batch_size]`. + max_length (Long Tensor | int | None, optional): The maximum length of the mask. If None, + it is inferred from the maximum value + in `lengths`. Defaults to None. + + Returns: + Tensor: The boolean mask. Shape: `[batch_size, max_length]`. + """ + if max_length is None: + max_length = lengths.max() + + # Create a range tensor from 0 to max_length - 1 + x = torch.arange(max_length, dtype=lengths.dtype, device=lengths.device) # type: ignore[arg-type] + + # Compare each length with the range tensor to create the mask via broadcasting + return x.unsqueeze(0) < lengths.unsqueeze(1) + + +def get_masking_rate(rate: Tensor, exponent: float = 3.0) -> Tensor: + """ + Converts a desired token keep rate to a masking rate using a power function. + + This function is part of a scheduling strategy for masking, where the effective + masking rate changes non-linearly with the desired keep rate. This function is + its own inverse. + + Args: + rate (Tensor): The desired rate of tokens to keep (0 to 1). + exponent (float, optional): The exponent for the transformation. Defaults to 3.0. + + Returns: + Tensor: The corresponding masking rate. + """ + return (1 - rate.pow(exponent)).pow(1 / exponent) + + +# Alias the function for clarity in the inverse context +get_rate = get_masking_rate + + +def get_mask( + code_mask: Tensor, + num_masking: Tensor, + unmasking: bool = False, + validate: bool = False, +) -> Tensor: + """ + Adjusts a boolean mask by masking or unmasking tokens from the end. + + This function operates on a `code_mask` where `True` values represent valid + tokens and are assumed to be contiguous at the start of the sequence. It + calculates a new mask by decreasing (masking) or increasing (unmasking) + the number of `True` values. + + Args: + code_mask (Tensor): The input boolean mask. Shape: `[..., depth]`. + num_masking (Tensor): The number of tokens to mask or unmask. + Shape matching `code_mask`'s batch dimensions. + unmasking (bool, optional): If `True`, increases the number of valid + tokens (unmasking). Defaults to `False`. + validate (bool, optional): If `True`, asserts that the input `code_mask` + is contiguous. This adds a slight overhead and + is mainly for debugging. Defaults to `False`. + + Returns: + Tensor: A new boolean mask with the adjusted length of valid tokens. + Shape is identical to `code_mask`. + """ + depth = code_mask.size(-1) + num_valid = code_mask.sum(dim=-1, dtype=torch.long) + + if validate: + # Reconstruct the expected contiguous mask and assert equality. + expected_mask = sequence_mask(num_valid.view(-1), depth).view_as(code_mask) + assert torch.equal( + code_mask, expected_mask + ), "Input `code_mask` must have contiguous `True` values at the beginning." + + # Calculate the target number of valid tokens. + if not unmasking: + # Masking: reduce the number of valid tokens, ensuring it's not negative. 
+ num_to_keep = (num_valid - num_masking).clamp_min(0) + else: + # Unmasking: increase the number of valid tokens, capped by total depth. + num_to_keep = (num_valid + num_masking).clamp_max(depth) + + # Generate the new mask using the final number of tokens to keep. + return sequence_mask(num_to_keep.view(-1), depth).view_as(code_mask) + + +# ============================================================================== +# Model and Vocabulary Utilities +# ============================================================================== + + +@dataclass +class RVQEARTTSOutput: + """ + Output type for the RVQEARTTSModel, providing a structured way to return model outputs. + This class allows accessing outputs by attribute, key, or index. + """ + + loss: Tensor | None = None + lm_loss: Tensor | None = None + c_loss: Tensor | None = None + k_loss: Tensor | None = None + + hidden_states: Tensor | None = None + past_key_values: Tensor | None = None + + codes: Tensor | None = None + lm_logits: Tensor | None = None + eos_flag: Tensor | None = None + + def __getitem__(self, item: str | int): + """Allows for accessing attributes by key or index.""" + if isinstance(item, str): + return getattr(self, item) + else: + # Access fields in the order they are defined in the dataclass + return getattr(self, fields(self)[item].name) + + +def find_and_delete_module(parent_module: nn.Module, target_module: nn.Module, parent_name: str) -> str | None: + """ + Recursively searches for a specific module instance and deletes it from its parent. + + This is useful for dynamically modifying a model's architecture, such as replacing + an existing embedding layer with a custom one. + + Args: + parent_module (nn.Module): The module to search within. + target_module (nn.Module): The exact module instance to find and delete. + parent_name (str): The initial name of the parent module for constructing the path. + + Returns: + str | None: The full dotted name of the deleted attribute if found, otherwise None. + """ + # Iterate over all direct children of the parent module + for name, module in parent_module.named_children(): + # Use the 'is' operator to check for object identity, not just value equality + if module is target_module: + # If found, delete the attribute from the parent and return its name + delattr(parent_module, name) + return f"{parent_name}.{name}" + + # If not found, recurse into the child module + found_path = find_and_delete_module(module, target_module, parent_name=f"{parent_name}.{name}") + if found_path: + return found_path + return None + + +def build_vocabs( + tokenizer: AutoTokenizer, vocab_dir: str | None = None +) -> tuple[dict[int, tuple[int, ...]], dict[str, int], int]: + """ + Builds or loads a character-level vocabulary derived from a subword tokenizer. + + This function creates a mapping from each subword in a pretrained tokenizer to a + sequence of character IDs. It follows a modern practice of using a directory + to save and load vocabulary files, making the process more robust and extensible. + + The primary source of truth is the `char_vocab.json` file. If it exists, it's + loaded. Otherwise, it's created from the pretrained tokenizer and saved. + + Args: + tokenizer (AutoTokenizer): The pretrained Hugging Face tokenizer class. + vocab_dir (str | None, optional): The directory to save or load the character + vocabulary from. Defaults to None. + + Returns: + tuple[dict[int, tuple[int, ...]], dict[str, int], int]: A tuple containing: + - A mapping from subword IDs to tuples of character IDs. 
+ - The character-to-ID vocabulary dictionary. + - The ID for the subword padding token. + """ + + def _build_char_vocab() -> dict[str, int]: + # Find all single-character tokens in the original tokenizer's vocabulary + single_chars = { + subword: subword_id for subword, subword_id in tokenizer.tokenizer.vocab.items() if len(subword) == 1 + } + # Create a new, dense character vocabulary sorted by the original token ID + sorted_chars = sorted(single_chars.keys(), key=lambda k: single_chars[k]) + char_vocab = {char: i for i, char in enumerate(sorted_chars)} + return char_vocab + + # 1. Load or build the character vocabulary + if vocab_dir: + from filelock import FileLock + + char_vocab_file = os.path.join(vocab_dir, "char_vocab.json") + os.makedirs(vocab_dir, exist_ok=True) + with FileLock(char_vocab_file + ".lock", timeout=60): + if not os.path.exists(char_vocab_file): + char_vocab = _build_char_vocab() + + logging.info(f"Saving character vocabulary to {char_vocab_file}") + with open(char_vocab_file, "w", encoding="utf-8") as f: + json.dump(char_vocab, f, ensure_ascii=False, indent=2) + + # All processes can now safely load the file. + logging.info(f"Loading character vocabulary from {char_vocab_file}") + with open(char_vocab_file, encoding="utf-8") as f: + char_vocab = json.load(f) + else: + # No cache directory provided, build in memory. + logging.info("Building character vocabulary from tokenizer.") + char_vocab = _build_char_vocab() + + # 2. Reconstruct the subword-to-character mapping on the fly + subword_id_to_char_ids = { + subword_id: tuple(char_vocab[char] for char in subword if char in char_vocab) + for subword, subword_id in tokenizer.tokenizer.vocab.items() + } + # Filter out subwords that contain characters not in our character vocabulary + subword_id_to_char_ids = {k: v for k, v in subword_id_to_char_ids.items() if v} + + # 3. Define a padding index for subwords + subword_padding_idx = len(tokenizer.vocab) + # The padding subword maps to a new character padding ID + subword_id_to_char_ids[subword_padding_idx] = (len(char_vocab),) + return subword_id_to_char_ids, char_vocab, subword_padding_idx + + +@torch.compile +def depthsum_encoding_step( + embs: Tensor, + r: Tensor, + code: Tensor, + depth_str: int = 0, + k: int = 72, +) -> Tensor: + for i in range(depth_str, depth_str + k): + idx_sel = ( + embs[i].pow(2).sum(-1) # [g?, v] + - 2 + * (r.unsqueeze(-2) @ embs[i].transpose(-1, -2)).squeeze(-2) # [b, ?, g?, h] , [g?, h, v] -> [b, ?, g?, v] + ).argmin(-1) + + emb_i = F.embedding(idx_sel, embs[i]) + r = r - emb_i + + code[..., i] = idx_sel + + return code + + +class MoGHead(nn.Module): + """ + A Mixture of Gaussians (MoG) prediction head. + + This module takes a hidden state and predicts the parameters for a mixture of + Gaussian distributions. It's suitable for modeling continuous, multi-modal data. + + Args: + hidden_size (int): The dimensionality of the input hidden state. + intermediate_size (int): The dimensionality of the MLP layers. + out_size (int): The dimensionality of the output vectors (the mean of each Gaussian). + num_layers (int): The number of MLP layers in the stack. + num_predictions (int): The number of Gaussian components in the mixture. + low_rank (int | None): The dimensionality used for compressing the hidden states. + min_log_std (float): The minimum value for the logarithm of the standard deviation. + eps (float): A small epsilon value for the RMSNorm layers. 
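+
+    Example (illustrative sizes, not tied to any released checkpoint):
+        >>> head = MoGHead(
+        ...     hidden_size=1024,
+        ...     intermediate_size=4096,
+        ...     out_size=512,
+        ...     num_layers=2,
+        ...     num_predictions=256,
+        ...     low_rank=64,
+        ... )
+        >>> logits, mus, mu_res, logs = head(torch.randn(2, 10, 1024))
+        >>> logits.shape, mus.shape
+        (torch.Size([2, 10, 256]), torch.Size([2, 10, 256, 64]))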
+ """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + out_size: int, + num_layers: int, + num_predictions: int, + low_rank: int | None = 64, + min_log_std: float = -4.0, + eps: float = 1e-6, + ): + super().__init__() + self.out_size = out_size + self.low_rank = low_rank + self.num_predictions = num_predictions + self.min_log_std = min_log_std + + self.mlp_stack = nn.Sequential( + *[MLPLayer(hidden_size, intermediate_size, eps=eps) for _ in range(num_layers)], + RMSNorm(hidden_size, eps=eps), + ) + + if low_rank is None: + self.proj_logits = nn.Linear(hidden_size, num_predictions, bias=False) # Predicts mixture weights + self.proj_mus = nn.Linear(hidden_size, num_predictions * out_size, bias=False) # Predicts means + self.proj_logs = nn.Linear(hidden_size, 1, bias=False) # Predicts log standard deviations + else: + assert low_rank < out_size + self.proj_logits = nn.Linear(hidden_size, num_predictions, bias=False) # Predicts mixture weights + self.proj_mus = nn.Linear(hidden_size, num_predictions * low_rank, bias=False) # Predicts means + self.proj_logs = nn.Linear(hidden_size, 1, bias=False) # Predicts log standard deviations + self.proj_else = nn.Linear(hidden_size, out_size, bias=False) + self.low_mat = nn.Parameter(torch.randn(num_predictions, out_size, low_rank) * (low_rank**-0.5)) + + def infer(self, x: Tensor, guidance_scale: float = 0.0, top_p_or_k: float | int = 1.0) -> tuple[Tensor, Tensor]: + """ + Performs inference by sampling from the predicted mixture distribution. + + Args: + x (Tensor): The input hidden state. + guidance_scale (float): The weight for classifier-free guidance. + top_p_or_k (float | int): The value for top-p (nucleus) or top-k sampling of the mixture components. + + Returns: + tuple[Tensor, Tensor]: A tuple containing the mean of the chosen component, + and the log standard deviations. + """ + b, t, _ = x.size() + n, d = self.num_predictions, self.low_rank or self.out_size + x = self.mlp_stack(x) + if guidance_scale > 0: + b //= 2 + x_cond, x_uncond = x.chunk(2, dim=0) + x = x_cond + guidance_scale * (x_cond - x_uncond) + + logits = self.proj_logits(x) + + # Apply top-p or top-k filtering to the mixture logits + if top_p_or_k is not None: + + logits = ( + TopPLogitsWarper(top_p_or_k)( + None, + logits.view(-1, n), + ).view_as(logits) + if isinstance(top_p_or_k, float) + else TopKLogitsWarper(top_p_or_k)( + None, + logits.view(-1, n), + ).view_as(logits) + ) + + # Sample a mixture component using the Gumbel-Max trick + mixture_indices = (F.log_softmax(logits, dim=-1) + gumbel_like(logits)).argmax(-1) + + # Select the mean corresponding to the sampled component + mu = batch_matmul( + x.view(b * t, -1), + self.proj_mus.weight.detach().view(n, d, -1), + mixture_indices.view(b * t), + ).view(b, t, d) + if self.proj_mus.bias is not None: + mu += self.proj_mus.bias.detach().view(n, d)[mixture_indices] + + if self.low_rank: + assert math.log2(d).is_integer() and math.log2(self.out_size).is_integer() + mu = batch_matmul( + mu.view(b * t, -1), + self.low_mat.detach().view(n, self.out_size, -1), + mixture_indices.view(b * t), + BLOCK_SIZE_DIN=d, + BLOCK_SIZE_DOUT=self.out_size, + ).view(b, t, self.out_size) + + mu_res = self.proj_else(x) + else: + mu_res = torch.zeros((b, t, d), device=x.device) + + logs = self.proj_logs(x).clamp_min(self.min_log_std) + return mu * torch.exp(logs) + mu_res, logs + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """ + Performs a forward pass for training. 
+
+        Args:
+            x (Tensor): The input hidden state.
+
+        Returns:
+            tuple[Tensor, Tensor, Tensor, Tensor]: A tuple containing the mixture logits,
+                                                   the means for all components, the low-rank
+                                                   mean residual, and the log standard deviations.
+        """
+        b, t, _ = x.size()
+        d = self.low_rank or self.out_size
+        x = self.mlp_stack(x)
+        logits = self.proj_logits(x)
+        mus = self.proj_mus(x).view(b, t, self.num_predictions, d)
+        logs = self.proj_logs(x).clamp_min(self.min_log_std)
+
+        if self.low_rank:
+            mu_res = self.proj_else(x)
+        else:
+            mu_res = torch.zeros((b, t, d), device=x.device)
+        return logits, mus, mu_res, logs
+
+    def dist(self, mus: Tensor, mu: Tensor) -> Tensor:
+        """
+        mus: [b, t, n, d]
+        mu: [b, t, d]
+
+        return: [b, t, n]
+        """
+        if self.low_rank is None:
+            return (mus - mu.unsqueeze(-2)).pow(2).sum(-1)
+        else:
+            low_mat_sq = self.low_mat.transpose(-1, -2) @ self.low_mat
+            x, y = mus, mu
+            wx_sq = (
+                x
+                * torch.einsum(
+                    "btni,nij->btnj",
+                    x,
+                    low_mat_sq.to(x),
+                )
+            ).sum(
+                -1
+            )  # [b, t, n]
+            y_sq = y.pow(2).sum(-1, keepdim=True)  # [b, t, 1]
+            xwy = (x * torch.einsum("bti,nij->btnj", y, self.low_mat.to(y))).sum(
+                -1
+            )  # [b, t, n, d_l], [n, d_i, d_l], [b, t, d_i] -> [b, t, n]
+
+            dist = wx_sq + y_sq - 2 * xwy
+            return torch.abs(dist)
+
+
+class NeMoSubwordFlagEmbedding(nn.Module):
+    """
+    Adds a tiny embedding table for continuation tokens
+    (subwords that do NOT start with the word-boundary markers Ġ or ▁).
+    Compatible with NeMo AutoTokenizer.
+    """
+
+    def __init__(self, tokenizer: AutoTokenizer, d_model: int):
+        super().__init__()
+
+        self.tokenizer = tokenizer
+        self.vocab_size = self.tokenizer.vocab_size
+        self.d_model = d_model
+
+        # Precompute continuation flags
+        tokens = [self.tokenizer.ids_to_tokens(i) for i in range(self.vocab_size)]
+        self.register_buffer(
+            'is_continuation',
+            torch.tensor(
+                [1 if not (tok.startswith("Ġ") or tok.startswith("▁")) else 0 for tok in tokens], dtype=torch.long
+            ),
+        )
+
+        # Tiny embedding table: 0 = word-start, 1 = continuation
+        init_std = self.d_model**-0.5
+        self.cont_emb = nn.Embedding(2, self.d_model)
+        nn.init.normal_(self.cont_emb.weight, mean=0.0, std=init_std)
+
+        # Force word-start embedding to zero so only continuation tokens get shifted
+        self.cont_emb.weight.data[0].zero_()
+
+    def forward(self, subword_embeds: torch.Tensor, token_ids: torch.LongTensor):
+        # Continuation flags
+        cont_flags = self.is_continuation[token_ids]
+
+        # Add continuation embedding
+        cont_emb = self.cont_emb(cont_flags)
+        return subword_embeds + cont_emb
+
+
+class SubwordFlagEmbedding(nn.Module):
+    """
+    Adds a small continuation embedding for subwords (tokens without a word-boundary marker).
+    Automatically adds a custom padding token at index vocab_size.
+    Ignores special tokens (starting with '<') when computing continuation flags.
+ """ + + def __init__(self, tokenizer: AutoTokenizer, d_model: int): + super().__init__() + + self.tokenizer = tokenizer + self.vocab_size = self.tokenizer.vocab_size + self.d_model = d_model + + # Custom pad token at vocab_size + self.pad_id = self.vocab_size + # register pad_id as a tensor buffer to avoid device issues + self.register_buffer("pad_tensor", torch.tensor(self.pad_id, dtype=torch.long)) + + # Precompute continuation flags + tokens = [self.tokenizer.ids_to_tokens(i) for i in range(self.vocab_size)] + cont_flags = [ + 1 if not (tok.startswith("Ġ") or tok.startswith("▁") or tok.startswith("<")) else 0 for tok in tokens + ] + cont_flags.append(0) # for the custom pad token + self.register_buffer("is_continuation", torch.tensor(cont_flags, dtype=torch.long)) + + # Continuation embedding + init_std = self.d_model**-0.5 + self.cont_emb = nn.Embedding(2, self.d_model) + nn.init.normal_(self.cont_emb.weight, mean=0.0, std=init_std) + self.cont_emb.weight.data[0].zero_() + + def forward(self, subword_embeds: torch.Tensor, token_ids: torch.LongTensor): + # Replace OOV token IDs with pad_id safely + token_ids_clamped = torch.where(token_ids >= self.vocab_size, self.pad_tensor, token_ids) + # Continuation flags + cont_flags = self.is_continuation[token_ids_clamped] + # Add continuation embedding + cont_emb = self.cont_emb(cont_flags) + return subword_embeds + cont_emb + + +class BOSEOSEmbedding(nn.Module): + """ + Adds independent embeddings for BOS and EOS tokens using a single embedding table. + Index 0 = regular token (ignored), 1 = BOS, 2 = EOS. + Compatible with Hugging Face tokenizers that may or may not have BOS/EOS. + """ + + def __init__(self, tokenizer: AutoTokenizer, d_model: int): + super().__init__() + + self.tokenizer = tokenizer + # vocab size that includes special tokens + vocab_dict = self.tokenizer.tokenizer.get_vocab() + self.vocab_size = max(vocab_dict.values()) + self.d_model = d_model + + # Custom pad token for OOVs + self.pad_id = self.vocab_size + self.register_buffer("pad_tensor", torch.tensor(self.pad_id, dtype=torch.long)) + + # Identify BOS and EOS tokens (may be None) + tokens = [self.tokenizer.ids_to_tokens(i) for i in range(self.vocab_size)] + + special_flags = [] + for tok in tokens: + if self.tokenizer.bos_token is not None and tok == self.tokenizer.bos_token: + special_flags.append(1) + elif self.tokenizer.eos_token is not None and tok == self.tokenizer.eos_token: + special_flags.append(2) + else: + special_flags.append(0) + special_flags.append(0) # for custom pad token + self.register_buffer("special_flags", torch.tensor(special_flags, dtype=torch.long)) + # Embedding table: 0 = regular, 1 = BOS, 2 = EOS + init_std = self.d_model**-0.5 + self.special_emb = nn.Embedding(3, d_model) + nn.init.normal_(self.special_emb.weight, mean=0.0, std=init_std) + self.special_emb.weight.data[0].zero_() # regular tokens ignored + + def forward(self, token_embeds: torch.Tensor, token_ids: torch.LongTensor): + """ + token_embeds: (B, T, d_model) + token_ids: (B, T) + """ + # Clamp OOVs to custom pad token + safe_ids = torch.where(token_ids >= self.vocab_size, self.pad_tensor, token_ids) + + # Lookup flags (0=regular, 1=BOS, 2=EOS) + flags = self.special_flags[safe_ids] + return token_embeds + self.special_emb(flags) + + +class SubwordEmbedding(nn.Module): + """ + Produces subword embeddings from a Hugging Face tokenizer vocabulary. + No special handling for OOVs or padding — assumes token_ids are valid. 
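+
+    The `subword_mask` argument of `forward` is accepted for interface compatibility
+    but is currently ignored.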
+ """ + + def __init__(self, tokenizer: AutoTokenizer, d_model: int): + super().__init__() + self.tokenizer = tokenizer + + # Get vocab size from tokenizer + vocab_dict = self.tokenizer.tokenizer.get_vocab() + self.vocab_size = max(vocab_dict.values()) + 1 # +1 for safety + self.d_model = d_model + + # Subword embedding table + init_std = d_model**-0.5 + self.subword_emb = nn.Embedding(self.vocab_size, d_model) + nn.init.normal_(self.subword_emb.weight, mean=0.0, std=init_std) + + def forward(self, token_ids: torch.LongTensor, subword_mask: torch.tensor = None): + """ + token_ids: (B, T) + subword_mask: (B, T) + Returns: + subword_embeds: (B, T, d_model) + """ + return self.subword_emb(token_ids) + + +class CharAwareSubwordEncoder(nn.Module): + """ + An encoder that creates subword embeddings from character-level embeddings. + + This module replaces a standard subword embedding layer. It breaks down each + subword into its constituent characters, embeds the characters, and then + aggregates these character embeddings (e.g., via mean pooling) to form the + final subword representation. This allows the model to handle rare or out-of-vocabulary + subwords more gracefully. + + Args: + out_size (int): The dimensionality of the output embedding vectors. + tokenizer (AutoTokenizer): The Hugging Face tokenizer class. + vocab_dir (str | None): Directory to save/load the character vocabulary. + backbone_type (str | None): The type of backbone model from Hugging Face (e.g., "t5gemma"). + backbone_model_class (str | None): The class name of the backbone model if not using AutoModel. + backbone_config_class (str | None): The class name of the backbone config. + backbone_config (DictConfig | None): A configuration for the backbone model. + """ + + def __init__( + self, + out_size: int, + tokenizer: AutoTokenizer, + vocab_dir: str | None = None, + backbone_type: str | None = "t5gemma", + backbone_model_class: str | None = None, + backbone_config_class: str | None = None, + backbone_config: DictConfig | None = None, + use_subword_flag_emb: bool = True, + use_bos_eos_emb: bool = True, + use_cumulative_word_emb: bool = False, + ): + super().__init__() + + # 1. Build or load the character vocabulary + self.subword_id_to_char_ids, self.char_vocab, self.subword_padding_idx = build_vocabs( + tokenizer, + vocab_dir, + ) + + self.char_padding_idx = len(self.char_vocab) + self.use_subword_flag_emb = use_subword_flag_emb + self.use_bos_eos_emb = use_bos_eos_emb + + # 2. Initialize the backbone model + if backbone_type: + config = AutoConfig.for_model( + backbone_type, **(OmegaConf.to_container(backbone_config, resolve=True) if backbone_config else {}) + ) + self.backbone = AutoModelForTextEncoding.from_config(config) + else: + assert backbone_model_class and backbone_config_class + config_class = getattr(transformers, backbone_config_class) + model_class = getattr(transformers, backbone_model_class) + config = config_class(**(OmegaConf.to_container(backbone_config, resolve=True) if backbone_config else {})) + self.backbone = model_class(config) + + self.hidden_size = self.backbone.get_input_embeddings().weight.size(-1) + + # 3. 
Delete the original subword embedding layer and replace it with our character embedding layer + find_and_delete_module(self.backbone, self.backbone.get_input_embeddings(), "backbone") + self.embed_tokens = nn.Embedding(len(self.char_vocab) + 1, self.hidden_size, padding_idx=self.char_padding_idx) + self.proj_embedding = nn.Linear(self.hidden_size, out_size, bias=False) + + if self.use_subword_flag_emb: + self.subword_flag_emb = SubwordFlagEmbedding(tokenizer, self.hidden_size) + + if self.use_bos_eos_emb: + self.bos_eos_emb = BOSEOSEmbedding(tokenizer, self.hidden_size) + + def prepare_inputs(self, subword_ids: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]: + """ + Converts a batch of subword IDs into a padded batch of character IDs. + + Args: + subword_ids (Tensor): A tensor of subword IDs. Shape: `[batch, seq_len]`. + padding_mask (Tensor): A boolean mask indicating valid (non-padding) subwords. + + Returns: + tuple[Tensor, Tensor]: A tuple containing: + - Padded character IDs. Shape: `[num_valid_subwords, max_char_len]`. + - Lengths of each character sequence. Shape: `[num_valid_subwords]`. + """ + device = subword_ids.device + # Select only the valid subword IDs + subword_id_list = torch.masked_select(subword_ids, padding_mask).cpu().tolist() + # Map each subword ID to its sequence of character IDs + char_id_list = [list(self.subword_id_to_char_ids.get(x, ())) for x in subword_id_list] + + char_lengths = torch.tensor([len(x) for x in char_id_list], dtype=torch.long, device=device) + batch_size = char_lengths.size(0) + max_len = int(char_lengths.max().item()) if batch_size > 0 else 0 + + # Create a padded tensor for the character IDs + char_ids = torch.full((batch_size, max_len), self.char_padding_idx, dtype=torch.long, device=device) + for i, char_seq in enumerate(char_id_list): + char_ids[i, : len(char_seq)] = torch.tensor(char_seq, dtype=torch.long, device=device) + + return char_ids, char_lengths + + def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Tensor: + """ + Performs the forward pass to get character-aware subword embeddings. + + Args: + subword_ids (Tensor): A tensor of subword IDs. Shape: `[batch, seq_len]`. + subword_mask (Tensor | None): A boolean mask for padding. Defaults to None. + + Returns: + Tensor: The final subword embeddings. Shape: `[batch, seq_len, hidden_size]`. + """ + if subword_mask is None: + subword_mask = torch.ones_like(subword_ids, dtype=torch.bool) + + # 1. Convert subword IDs to character IDs + char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) + + # char_mask = sequence_mask(char_lengths).float() + char_mask = sequence_mask(char_lengths) + + # 2. Get character embeddings and pass them through the backbone + char_embeds = self.embed_tokens(char_ids) + + # The backbone model should be able to accept `inputs_embeds` + char_hidden_states = self.backbone(inputs_embeds=char_embeds, attention_mask=char_mask).last_hidden_state + + # 3. Aggregate character embeddings to form subword embeddings (mean pooling) + # We mask the padding characters before summing to get a correct mean. + masked_sum = (char_hidden_states * char_mask.unsqueeze(-1)).sum(dim=1) + # Avoid division by zero for empty sequences + mean_emb = masked_sum / (char_lengths.unsqueeze(-1).clamp(min=1)) + + # 4. 
Scatter the aggregated embeddings back to the original subword sequence shape + out_emb = self.proj_embedding(mean_emb) + subword_embeds = torch.zeros( + subword_ids.shape + (out_emb.size(-1),), device=subword_ids.device, dtype=out_emb.dtype + ) + subword_embeds[subword_mask] = out_emb + + if self.use_subword_flag_emb: + subword_embeds = self.subword_flag_emb(subword_embeds, subword_ids) + + if self.use_bos_eos_emb: + subword_embeds = self.bos_eos_emb(subword_embeds, subword_ids) + + return subword_embeds + + +class GatedProjectedSumRMSNorm(nn.Module): + def __init__(self, audio_dim, text_dim, hidden_dim, final_norm=True, num_codebooks=31, init_residual_scale=0.5): + super().__init__() + self.num_codebooks = num_codebooks + + self.audio_proj = nn.Linear(audio_dim, hidden_dim) + self.text_proj = nn.Linear(text_dim, hidden_dim) + + nn.init.normal_(self.audio_proj.weight, mean=0.0, std=0.015) + nn.init.zeros_(self.audio_proj.bias) + nn.init.normal_(self.text_proj.weight, mean=0.0, std=0.015) + nn.init.zeros_(self.text_proj.bias) + + # FP32 gate params + self.gate = nn.Parameter(torch.zeros(hidden_dim, dtype=torch.float32)) + self.residual_scale = nn.Parameter(torch.tensor(init_residual_scale, dtype=torch.float32)) + + self.final_norm = RMSNorm(hidden_dim) if final_norm else nn.Identity() + + def forward(self, audio_emb, text_emb): + audio_emb = audio_emb / self.num_codebooks + + # projections run in model dtype (BF16) + audio_h = self.audio_proj(audio_emb) + text_h = self.text_proj(text_emb) + + dtype = audio_h.dtype + + with fp32_precision(): + gate = torch.sigmoid(self.gate) # FP32 + res = torch.sigmoid(self.residual_scale) # FP32 + + h = gate.to(dtype) * audio_h + (1 - gate).to(dtype) * text_h + h = res.to(dtype) * h + h = self.final_norm(h.float()).to(dtype) + + return h + + +class RVQEARTTSModel(nn.Module): + """ + Main RVQEARTTS model for training and inference. + + This model integrates a character-aware text encoder with a transformer backbone + and a Mixture-of-Gaussians (MoG) prediction head. + + The architecture is based on the Streaming TTS model proposed in + "Audio Flamingo 3" (https://arxiv.org/abs/2507.08128), with several improvements: + + 1. Gated fusion of text and audio representations + (`GatedProjectedSumRMSNorm`). + + 2. Subword-aware embeddings for improved pronunciation of multi-token words + (`SubwordFlagEmbedding`). + + 3. Custom BOS and EOS embeddings for duplex interaction support, + enabling interruption-aware generation (`BOSEOSEmbedding`). + + Args: + config (DictConfig | dict[str, Any]): The configuration object for the model. 
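+        tokenizer (AutoTokenizer): Tokenizer forwarded to the character-aware subword
+            encoder (`CharAwareSubwordEncoder`) when `cas_config` is set. Defaults to None.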
+ """ + + config_class = DictConfig + rvq_embs: Tensor + + def __init__(self, config: DictConfig | dict[str, Any], tokenizer: AutoTokenizer = None): + super().__init__() + self.config = config + + # Backbone module + if self.config.get("pretrained_text_name", None): + # Load pretrained backbone from huggingface + from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf + + llm = load_pretrained_hf(self.config.pretrained_text_name, pretrained_weights=True).train() + self.backbone = llm.model # fetch PretrainedBaseModel from model "ForCausalLM" + else: + if self.config.get("backbone_type", None) is None: + assert ( + self.config.get("backbone_model_class", None) is not None + and self.config.get("backbone_config_class", None) is not None + ) + backbone_config = getattr(transformers, self.config.backbone_config_class)( + **( + OmegaConf.to_container(self.config.backbone_config, resolve=True) + if self.config.backbone_config + else {} + ), + ) + self.backbone = getattr(transformers, self.config.backbone_model_class)(backbone_config) + else: + backbone_config = AutoConfig.for_model( + self.config.backbone_type, + **( + OmegaConf.to_container(self.config.backbone_config, resolve=True) + if self.config.backbone_config + else {} + ), + ) + self.backbone = AutoModel.from_config(backbone_config) + + self.hidden_size = self.backbone.get_input_embeddings().weight.size(-1) + find_and_delete_module(self.backbone, self.backbone.get_input_embeddings(), "backbone") + + # Embedding and projection layers + self.bos_emb = nn.Parameter(torch.randn(self.hidden_size)) + self.null_emb = nn.Parameter(torch.randn(self.hidden_size)) + if self.config.random_target_masking: + self.embed_target_mask = nn.Embedding(self.config.num_quantizers, self.hidden_size) + + self.embed_code = nn.Linear(self.config.latent_size, self.hidden_size, bias=False) + + self.embed_context = ( + nn.Linear(self.config.context_hidden_size, self.hidden_size, bias=False) + if self.config.context_hidden_size + else None + ) + + self.embed_subword = ( + CharAwareSubwordEncoder( + tokenizer=tokenizer, + out_size=self.hidden_size, + use_subword_flag_emb=self.config.use_subword_flag_emb, + use_bos_eos_emb=self.config.use_bos_eos_emb, + **self.config.cas_config, + ) + if self.config.cas_config + else None + ) + + if self.config.use_gated_fusion_for_text_audio: + self.gated_fusion_audio_text = GatedProjectedSumRMSNorm( + self.hidden_size, self.hidden_size, self.hidden_size, self.config.num_quantizers + ) + + # Prediction Heads + if not self.config.disable_eos_prediction: + self.lm_head = nn.Linear(self.hidden_size, 2, bias=False) + + self.mog_head = MoGHead( + hidden_size=self.hidden_size, + out_size=self.config.latent_size, + **self.config.mog_head_config, + ) + + def set_rvq_embs(self, rvq_embs: Tensor): + self.register_buffer("rvq_embs", rvq_embs.detach().clone()) + + def depthsum_embedding(self, code: Tensor) -> Tensor: + """ + code: [b, t, d] + rvq_embs: [d, v, h] + + ret: [b, t, h] + """ + b, t, d = code.size() + _, v, h = self.rvq_embs.size() + device = code.device + + ret = torch.zeros((b, t, h), device=device) + embs = F.pad(self.rvq_embs, [0, 0, 0, 1]) + for i in range(d): + emb = embs[i] + ret = ret + F.embedding(code[..., i], emb) + return ret + + def prepare_training_inputs(self, code: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Prepares masked and dropped-out versions of the code for training.""" + b, t, d = code.size() + device = code.device + + src_rate = torch.rand((b, t), device=device) * 
self.config.max_training_rate + src_masking_rate = get_masking_rate(src_rate, self.config.exponent) + src_num_masking = torch.ceil(src_masking_rate * self.config.num_quantizers).long() + + src_code_mask = torch.ones((b, t, d), dtype=torch.bool, device=device) + src_code_mask = get_mask(src_code_mask, src_num_masking) + src_masked_code = code * src_code_mask + (torch.zeros_like(code) + self.config.codebook_size) * ( + ~src_code_mask + ) + + if self.config.random_target_masking: + tgt_rate = src_rate + (1.0 - src_rate) * torch.rand((b, t), device=device) + tgt_masking_rate = get_masking_rate(tgt_rate, self.config.exponent) + tgt_num_masking = torch.floor(tgt_masking_rate * self.config.num_quantizers).long() + + tgt_code_mask = torch.ones((b, t, d), dtype=torch.bool, device=device) + tgt_code_mask = get_mask(tgt_code_mask, tgt_num_masking) + tgt_masked_code = code * tgt_code_mask + (torch.zeros_like(code) + self.config.codebook_size) * ( + ~tgt_code_mask + ) + else: + tgt_code_mask = torch.ones((b, t, d), dtype=torch.bool, device=device) + tgt_masked_code = code + + dropout_mask = torch.where( + torch.rand((b, t, 1), device=device) < self.config.quantizer_dropout, + (torch.randint(0, self.config.num_quantizers + 1, (b, t, 1), device=device)), + self.config.num_quantizers, + ) > torch.arange(d, dtype=torch.long, device=device) + dropped_code = code * dropout_mask + (torch.zeros_like(code) + self.config.codebook_size) * (~dropout_mask) + + return src_masked_code, src_code_mask, tgt_masked_code, tgt_code_mask, dropped_code + + def _prepare_conditioning( + self, + context_hidden_state: Tensor | None, + subword_ids: Tensor | None, + subword_mask: Tensor | None, + uncond_dec_flag: Tensor, + asr_speech_tokens_emb: Tensor | None, + ) -> Tensor: + """Computes the final conditioning tensor by combining all sources.""" + cond = torch.zeros((1, 1, self.hidden_size), device=uncond_dec_flag.device) + + if self.embed_context is not None and context_hidden_state is not None: + cond = cond + self.embed_context(context_hidden_state) + + if self.embed_subword is not None and subword_ids is not None: + # Infer subword mask from context if not provided + if subword_mask is None and context_hidden_state is not None: + subword_mask = torch.any(context_hidden_state != 0, dim=-1) + # at least one value should be true, otherwise we can completly skip it to avoid errors + if subword_mask is not None and subword_mask.any(): + cond = cond + self.embed_subword(subword_ids, subword_mask) + + if asr_speech_tokens_emb is not None: + cond = cond + asr_speech_tokens_emb + + # Replace with null embedding for unconditional generation + cond = torch.where(uncond_dec_flag, self.null_emb, cond) + return cond + + def _compute_losses( + self, + code: Tensor, + lm_logits: Tensor, + mog_logits: Tensor, + mog_mus: Tensor, + mog_mu_res: Tensor, + mog_logs: Tensor, + src_code_mask: Tensor, + tgt_code_mask: Tensor, + audio_mask: Tensor, + ) -> tuple[Tensor, Tensor, Tensor]: + """Helper to compute all losses for the training step.""" + with torch.autocast(code.device.type, enabled=False): + # 1. LM Loss (predicting discrete tokens) + if not self.config.disable_eos_prediction: + eos_mask = (~audio_mask) & F.pad(audio_mask[:, :-1], [1, 0]) + lm_mask = eos_mask | audio_mask + lm_target = torch.where(eos_mask, 1, 0) + lm_loss = ( + F.cross_entropy(lm_logits.transpose(1, 2), lm_target, reduction="none") * lm_mask + ).sum() / lm_mask.sum().clamp_min(1) + else: + lm_loss = 0.0 + + # 2. 
Continuous & KL Losses (for the MoG head) + target_mask = (~src_code_mask & tgt_code_mask) & audio_mask.unsqueeze(-1) + reduced_target_mask = target_mask.any(dim=-1) + + cont_code_target = self.depthsum_embedding( + code * target_mask + (torch.zeros_like(code) + self.config.codebook_size) * (~target_mask) + ) + mog_logits = mog_logits.float() + mog_mus = mog_mus.float() + mog_mu_res = mog_mu_res.float() + mog_logs = mog_logs.float() + + # Log probability of the true code under each Gaussian component + logp_code = (-0.5 * math.log(2 * math.pi) - mog_logs) * self.config.latent_size - 0.5 * self.mog_head.dist( + mog_mus, (cont_code_target - mog_mu_res) * torch.exp(-mog_logs) + ) + + # Compute posterior q(k|c) + q_kc = ( + torch.softmax( + logp_code, + -1, + ) + * (1 - self.config.label_smoothing) + + self.config.label_smoothing / self.mog_head.num_predictions + ).detach() + log_q_kc = torch.log(q_kc + 1e-8).detach() + + # Continuous Loss (negative log-likelihood) + c_loss = (-(q_kc * logp_code).sum(-1) * reduced_target_mask).sum() / target_mask.sum().clamp_min(1) + + # KL Divergence Loss + k_loss = ( + (q_kc * (log_q_kc - F.log_softmax(mog_logits, -1))).sum(-1) * reduced_target_mask + ).sum() / target_mask.sum().clamp_min(1) + + return lm_loss, c_loss, k_loss + + def forward( + self, + code: Tensor, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + context_hidden_state: Tensor | None = None, + subword_ids: Tensor | None = None, + subword_mask: Tensor | None = None, + audio_mask: Tensor | None = None, + non_prompt_mask: Tensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool = False, + training: bool | None = None, + guidance_enabled: bool = False, + generation_config: dict[str, Any] | None = None, + teacher_forcing_inference: bool = False, + ignore_eos_flag_stop: bool = False, + asr_speech_tokens_emb: Tensor | None = None, + ) -> RVQEARTTSOutput: + """ + Performs a forward pass handling training, generation, or single-step inference. + + Args: + code (Tensor): Input audio codes. For training, this is the ground truth. + For generation, this is the previously generated code token. + attention_mask (Tensor | None): Attention mask for the backbone transformer. + position_ids (Tensor | None): Position ids for the backbone transformer. + context_hidden_state (Tensor | None): Conditioning from a language model. + subword_ids (Tensor | None): Subword token IDs for conditioning. + subword_mask (Tensor | None): Mask for subword IDs. + audio_mask (Tensor | None): Mask for valid audio positions (for training and inference initialization). + past_key_values (Cache | None): Cache for past key-values for fast decoding. + use_cache (bool): If True, returns the updated `past_key_values`. + training (bool | None): Explicitly set training mode. If `None`, uses `self.training`. + guidance_enabled (bool): If True, duplicates inputs internally to run both + conditional and unconditional passes + generation_config (dict[str, Any] | None): If provided, triggers an iterative code generation. + + Returns: + RVQEARTTSOutput: A dataclass containing losses (for training) or generated outputs + and the cache (for inference). + """ + + # Determine operating mode. 
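+        # Three paths follow:
+        #   1) training (requires `audio_mask`): build masked/dropped code variants and compute losses;
+        #   2) inference without `generation_config`: return hidden states and the KV cache only;
+        #   3) inference with `generation_config`: run iterative code generation via `generate_step`,
+        #      or `generate_teacher_forcing` when `teacher_forcing_inference` is True.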
+ if training is None: + training = self.training + + if audio_mask is not None: + if training: + (src_masked_code, src_code_mask, tgt_masked_code, tgt_code_mask, dropped_code) = ( + self.prepare_training_inputs(code) + ) + uncond_dec_flag = torch.rand(code.size(0), 1, 1, device=code.device) < self.config.p_uncond + else: + dropped_code = code + uncond_dec_flag = torch.zeros(code.size(0), 1, 1, device=code.device, dtype=torch.bool) + + code_embeds = ( + self.embed_code(self.depthsum_embedding(F.pad(dropped_code[:, :-1], [0, 0, 1, 0]))) + + (audio_mask & (~F.pad(audio_mask[:, :-1], [1, 0]))).unsqueeze(-1) * self.bos_emb + ) + + else: # Inference + code_embeds = self.embed_code(self.depthsum_embedding(code)) + uncond_dec_flag = torch.zeros(code.size(0), 1, 1, device=code.device, dtype=torch.bool) + + if guidance_enabled: + assert not training, "Classifier-free guidance can only be used when `training` is False." + code_embeds = torch.cat([code_embeds] * 2, 0) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask] * 2, 0) + if position_ids is not None: + position_ids = torch.cat([position_ids] * 2, 0) + if context_hidden_state is not None: + context_hidden_state = torch.cat([context_hidden_state] * 2, 0) + if subword_ids is not None: + subword_ids = torch.cat([subword_ids] * 2, 0) + if subword_mask is not None: + subword_mask = torch.cat([subword_mask] * 2, 0) + if asr_speech_tokens_emb is not None: + asr_speech_tokens_emb = torch.cat([asr_speech_tokens_emb] * 2, 0) + + uncond_dec_flag = torch.cat([uncond_dec_flag, torch.ones_like(uncond_dec_flag)], 0) + + # Prepare conditioning + cond = self._prepare_conditioning( + context_hidden_state, + subword_ids, + subword_mask, + uncond_dec_flag, + asr_speech_tokens_emb=asr_speech_tokens_emb, + ) + + if self.config.use_gated_fusion_for_text_audio: + inputs_embeds = self.gated_fusion_audio_text(code_embeds, cond) + else: + inputs_embeds = code_embeds + cond + + # Main backbone pass + backbone_outputs = self.backbone( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + ) + hidden_states = backbone_outputs.last_hidden_state + + if audio_mask is not None and training: + # --- Training-specific loss computation --- + if not self.config.disable_eos_prediction: + lm_logits = self.lm_head(hidden_states) + else: + lm_logits = None + + mog_input_embeds = self.embed_code(self.depthsum_embedding(src_masked_code)) + if self.config.random_target_masking: + mog_input_embeds = mog_input_embeds + self.embed_target_mask((tgt_code_mask.sum(-1) - 1).clamp_min(0)) + mog_input_embeds = mog_input_embeds + hidden_states + mog_logits, mog_mus, mog_mu_res, mog_logs = self.mog_head(mog_input_embeds) + + lm_loss, c_loss, k_loss = self._compute_losses( + code, lm_logits, mog_logits, mog_mus, mog_mu_res, mog_logs, src_code_mask, tgt_code_mask, audio_mask + ) + total_loss = lm_loss + c_loss + k_loss + + return RVQEARTTSOutput( + loss=total_loss, lm_loss=lm_loss, c_loss=c_loss, k_loss=k_loss, hidden_states=hidden_states + ) + else: # Inference + if not generation_config: + return RVQEARTTSOutput( + hidden_states=hidden_states, + past_key_values=backbone_outputs.past_key_values, + ) + else: + if teacher_forcing_inference: + generated_codes, lm_logits, eos_flag = self.generate_teacher_forcing( + hidden_states, generation_config + ) + else: + generated_codes, lm_logits, eos_flag = self.generate_step( + hidden_states, ignore_eos_flag_stop=ignore_eos_flag_stop, 
**generation_config
+                    )
+                return RVQEARTTSOutput(
+                    past_key_values=backbone_outputs.past_key_values,
+                    codes=generated_codes,
+                    lm_logits=lm_logits,
+                    eos_flag=eos_flag,
+                    hidden_states=hidden_states,
+                )
+
+    @torch.no_grad()
+    def generate_teacher_forcing(self, hidden_states: Tensor, generation_config: dict):
+        """
+        Teacher-forcing wrapper around `generate_step` that decodes every frame of the
+        provided hidden states, looping over time steps internally.
+
+        Args:
+            hidden_states: [B, T, H] hidden states
+            generation_config: kwargs for self.generate_step()
+
+        Returns:
+            generated_codes: [B, T, ...] generated codes per frame
+            lm_logits: [B, T, vocab_size] language model logits
+            eos_flag: [B, T] boolean tensor indicating EOS
+        """
+        B, T, H = hidden_states.shape
+
+        # Preallocate caches
+        generated_codes_cache = []
+        lm_logits_cache = []
+        eos_flag_cache = []
+
+        # Iterate over time steps (frames)
+        for t in range(T):
+            # extract one frame (as the original generate_step expects)
+            frame_hidden = hidden_states[:, t, :]  # [B, H]
+
+            # call original generate_step
+            generated_codes, lm_logits, eos_flag = self.generate_step(
+                frame_hidden.unsqueeze(1), **generation_config  # keep batch dim + frame dim
+            )
+            if generated_codes is not None:
+                # store in cache
+                generated_codes_cache.append(generated_codes)
+                lm_logits_cache.append(lm_logits)
+                eos_flag_cache.append(eos_flag)
+
+        # Stack results along time dimension
+        generated_codes = torch.stack(generated_codes_cache, dim=1)  # [B, T, ...]
+        if not self.config.disable_eos_prediction:
+            lm_logits = torch.stack(lm_logits_cache, dim=1)  # [B, T, vocab_size]
+            eos_flag = torch.stack(eos_flag_cache, dim=1)  # [B, T]
+        else:
+            lm_logits = None
+            eos_flag = None
+
+        return generated_codes, lm_logits, eos_flag
+
+    @torch.no_grad()
+    def generate_step(
+        self,
+        hidden_states: Tensor,
+        num_iter: int,
+        guidance_scale: list[float] | float | None = None,
+        top_p_or_k: list[float | int] | float | int | None = None,
+        noise_scale: list[float] | float | None = None,
+        exponent: float | None = None,
+        eos_threshold: float | None = None,
+        ignore_eos_flag_stop: bool = False,
+    ) -> tuple[Tensor | None, Tensor, Tensor]:
+        """
+        Performs the iterative unmasking process for a single generation step.
+
+        This function takes the hidden state from the backbone transformer and generates
+        codes through an iterative unmasking process.
+
+        Args:
+            hidden_states (Tensor): The hidden states from the backbone. If using CFG,
+                                    this should be the combined [cond, uncond] tensor.
+            num_iter (int): The number of unmasking iterations.
+            guidance_scale (list[float] | float | None): The scale for Classifier-Free Guidance.
+            top_p_or_k (list[float | int] | float | int | None): The value for top-p or top-k sampling.
+            noise_scale (list[float] | float | None): The scale of noise to add during MoG sampling.
+            exponent (float | None): The exponent for the masking schedule.
+            eos_threshold (float | None): The threshold for EOS prediction.
+            ignore_eos_flag_stop (bool): If True, keep generating even when every sequence predicts EOS.
+
+        Returns:
+            tuple[Tensor | None, Tensor, Tensor]: A tuple containing:
+                - The generated codes (or None if all sequences predicted EOS and generation stops).
+                - The logits from `lm_head`.
+                - The EOS flag.
+        """
+        # 1.
Preparation + if guidance_scale is not None: + if not isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * (1 + num_iter) # includes one step for `lm_head` + assert len(guidance_scale) == 1 + num_iter + if top_p_or_k is not None: + if not isinstance(top_p_or_k, list): + top_p_or_k = [top_p_or_k] * (1 + num_iter) # includes one step for `lm_head` + assert len(top_p_or_k) == 1 + num_iter + if noise_scale is not None: + if not isinstance(noise_scale, list): + noise_scale = [noise_scale] * num_iter + assert len(noise_scale) == num_iter + if exponent is None: + exponent = self.config.exponent + + if guidance_scale is not None: + # The effective batch size is halved + hidden_states, uncond_hidden_states = hidden_states.chunk(2, dim=0) + else: + uncond_hidden_states = hidden_states[:0, :0, :0] + + b, t, _ = hidden_states.size() + d = self.config.num_quantizers + device = hidden_states.device + + # 2. Predict the discrete part of the code + if not self.config.disable_eos_prediction: + if guidance_scale is not None: + lm_logits = self.lm_head(hidden_states + guidance_scale[0] * (hidden_states - uncond_hidden_states)) + else: + lm_logits = self.lm_head(hidden_states) + if top_p_or_k is not None: + lm_logits = ( + TopPLogitsWarper(top_p_or_k[0])( + None, + lm_logits.view(-1, lm_logits.size(-1)), + ).view_as(lm_logits) + if isinstance(top_p_or_k[0], float) + else TopKLogitsWarper(top_p_or_k[0])( + None, + lm_logits.view(-1, lm_logits.size(-1)), + ).view_as(lm_logits) + ) + lm_logits = F.log_softmax(lm_logits, -1) + if eos_threshold is not None: + eos_flag = lm_logits[..., -1] > eos_threshold + else: + eos_flag = lm_logits.argmax(-1) == 1 + + if torch.all(eos_flag) and not ignore_eos_flag_stop: + return None, lm_logits, eos_flag + else: + lm_logits = None + eos_flag = None + + # Initialize the full code tensor + code = torch.zeros((b, t, d), dtype=torch.long, device=device) + self.config.codebook_size + + # 3. Set up the iterative denoising schedule for the continuous part + rates = torch.linspace(0.0, 1.0, num_iter + 1, device=device)[:-1].unsqueeze(-1) + masking_rates = get_masking_rate(rates, exponent=exponent) + num_maskings = torch.ceil(masking_rates * self.config.num_quantizers).long() + + ks = num_maskings - F.pad(num_maskings[1:], [0, 0, 0, 1]) + + # 4. 
Iteratively unmask the continuous part of the code + cnt = 0 + for i, k in enumerate(ks): + if torch.all(k == 0): + continue + + # Prepare input for the MoG head + guidance_scale_i = guidance_scale[i] if guidance_scale is not None else 0.0 + top_p_or_k_i = top_p_or_k[i] if top_p_or_k is not None else 1.0 + noise_scale_i = noise_scale[i] if noise_scale is not None else 1.0 + + mog_input_embeds = self.embed_code(self.depthsum_embedding(code)) + if self.config.random_target_masking: + mog_input_embeds += self.embed_target_mask(cnt + k - 1) + if guidance_scale_i > 0.0: + mog_input_embeds = torch.cat( + [mog_input_embeds + hidden_states, mog_input_embeds + uncond_hidden_states], 0 + ) + else: + mog_input_embeds += hidden_states + + mog_mu, mog_logs = self.mog_head.infer( + mog_input_embeds, + guidance_scale=guidance_scale_i, + top_p_or_k=top_p_or_k_i, + ) + z = mog_mu + torch.exp(mog_logs) * torch.randn_like(mog_mu) * noise_scale_i + code = depthsum_encoding_step(self.rvq_embs, z, code, cnt, k[0].item()) + cnt += k[0].item() + return code, lm_logits, eos_flag + + def load_state_dict(self, state_dict, strict: bool = True): + try: + super().load_state_dict(state_dict, strict=strict) + except RuntimeError: + logging.info("Error loading model state_dict !! Retrying with partial initialization!") + model_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) + super().load_state_dict(model_dict, strict=False) diff --git a/nemo/collections/speechlm2/modules/ear_tts_vae_codec.py b/nemo/collections/speechlm2/modules/ear_tts_vae_codec.py new file mode 100644 index 000000000000..1d2aaa8ad44c --- /dev/null +++ b/nemo/collections/speechlm2/modules/ear_tts_vae_codec.py @@ -0,0 +1,1063 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import math +from collections.abc import Callable +from contextlib import contextmanager +from typing import Any, Concatenate + +import librosa +import torch +from omegaconf import DictConfig +from torch import Tensor, nn +from torch.nn import functional as F + + +@contextmanager +def disable_tf32(): + prev = torch.backends.cudnn.allow_tf32 + torch.backends.cudnn.allow_tf32 = False + try: + yield + finally: + torch.backends.cudnn.allow_tf32 = prev + + +# ============================================================================== +# Utility Functions +# ============================================================================== + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zeros out the parameters of a PyTorch module in-place. + + This is a utility function that iterates through all parameters of a given + `nn.Module` and sets their values to zero. This is often used for specific + initialization strategies, for example in diffusion models where some layers + are initialized to zero. 
+ + From: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L68 + + Args: + module (nn.Module): The PyTorch module to be zeroed. + + Returns: + nn.Module: The same module with its parameters zeroed. + """ + for p in module.parameters(): + # p.detach().zero_() performs the operation in-place without tracking it in autograd + p.detach().zero_() + return module + + +def sequence_mask(lengths: Tensor, max_length: int | None = None) -> Tensor: + """ + Creates a boolean mask from a 1D tensor of sequence lengths. + + This function is useful for masking out padding in sequences. Given a tensor + of lengths, it produces a 2D boolean tensor where `mask[i, j]` is `True` if + `j < lengths[i]` and `False` otherwise. + + Example: + >>> lengths = torch.tensor([1, 3, 2]) + >>> sequence_mask(lengths) + tensor([[ True, False, False], + [ True, True, True], + [ True, True, False]]) + + Args: + lengths (Tensor): A 1D tensor of integer lengths. Shape: `[batch_size]`. + max_length (int | None, optional): The maximum length of the mask. If None, + it is inferred from the maximum value + in `lengths`. Defaults to None. + + Returns: + Tensor: The boolean mask. Shape: `[batch_size, max_length]`. + """ + if max_length is None: + # If max_length is not provided, use the longest sequence length in the batch. + max_length = int(lengths.max().item()) + + # Create a range tensor from 0 to max_length - 1 + x = torch.arange(max_length, dtype=lengths.dtype, device=lengths.device) + + # Compare each length with the range tensor to create the mask. + # `x.unsqueeze(0)` is `[1, max_length]` + # `lengths.unsqueeze(1)` is `[batch_size, 1]` + # Broadcasting takes care of the comparison. + return x.unsqueeze(0) < lengths.unsqueeze(1) + + +# ============================================================================== +# Signal Processing Functions +# ============================================================================== + + +def spectrogram( + wav: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + window_fn: Callable[Concatenate[int, ...], Tensor] = torch.hann_window, +) -> Tensor: + """ + Computes the Short-Time Fourier Transform (STFT) of a waveform with manual padding. + + This implementation manually applies zero padding before computing the STFT. + This is done to center the analysis window at the beginning of the signal + without using the `center=True` argument in `torch.stft`, giving more control. + + Args: + wav (Tensor): The input audio waveform. + Shape: [batch_size?, time_steps], where batch_size? is an + optional batch dimension. + n_fft (int): The size of the FFT. + hop_length (int): The number of samples between adjacent STFT columns. + win_length (int): The size of the window function. + window_fn (function, optional): The window function to apply. + Defaults to torch.hann_window. + + Returns: + Tensor: The complex-valued spectrogram. + Shape: [batch_size?, n_fft // 2 + 1, num_frames] + """ + # Calculate the padding required on the left and right sides to center the frames. + pad_size_l = (n_fft - hop_length) // 2 + pad_size_r = (n_fft - hop_length) - pad_size_l + + # Use a torch.autocast context to perform STFT in float32 for precision. + with torch.autocast(device_type=wav.device.type, enabled=False): + # Apply reflection padding to the waveform. + wav = F.pad(wav.float(), (pad_size_l, pad_size_r)) + + # Create the window tensor on the same device as the waveform. 
+ window = window_fn(win_length, dtype=torch.float, device=wav.device) + + # Compute the STFT. + # `center=False` because we have already manually padded the signal. + spec = torch.stft( + wav, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=False, + normalized=False, + onesided=True, + return_complex=True, + ) + return spec + + +def spec_to_wav( + spec: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + window_fn: Callable[Concatenate[int, ...], Tensor] = torch.hann_window, + constrain_value_range: bool = False, +) -> Tensor: + """ + Converts a spectrogram back into a waveform using the overlap-add method. + This function is an approximate inverse of the `spectrogram` function. + + Args: + spec (Tensor): The input complex-valued spectrogram. + Shape: [batch_size?, dim, time_steps], where batch_size? + is an optional batch dimension. + n_fft (int): The size of the FFT used to create the spectrogram. + hop_length (int): The number of samples between frames in the original signal. + win_length (int): The size of the window function used in the original signal. + window_fn (function, optional): The window function used. Currently only + `torch.hann_window` is supported. + constrain_value_range (bool, optional): If True, constrains the IFFT values + to be within the range of the window. + This ensures that the output values + remain within the range of -1.0 to 1.0. + Defaults to False. + + Returns: + Tensor: The reconstructed waveform. + Shape: [batch_size?, time_steps] + + Raises: + ValueError: If a window function other than `torch.hann_window` is provided. + """ + with torch.autocast(device_type=spec.device.type, enabled=False): + if window_fn != torch.hann_window: + raise ValueError(f"`window_fn` should be 'torch.hann_window', but got '{window_fn}'.") + + # Calculate padding and number of frames + pad = (win_length - hop_length) // 2 + T = spec.size(-1) + window = window_fn(win_length, device=spec.device) + + # 1. Inverse FFT + # Convert from frequency domain back to time domain for each frame. + ifft = torch.fft.irfft(spec, n=n_fft, dim=-2, norm="backward") + window_unsqz = window.unsqueeze(-1) + + # 2. Optionally constrain values + if constrain_value_range: + ifft = torch.where( + ifft >= 0, + torch.minimum(ifft, window_unsqz), + torch.maximum(ifft, -window_unsqz), + ) + + # 3. Apply window to the IFFT result + ifft = ifft * window_unsqz + + # 4. Overlap and Add + # Use `torch.nn.functional.fold` to perform the overlap-add operation efficiently. + # This reconstructs the continuous signal from the windowed frames. + output_size = (T - 1) * hop_length + win_length + wav = F.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, win_length), + stride=(1, hop_length), + )[..., 0, 0, pad:-pad] + + # 5. Calculate the window envelope for normalization + # This is necessary to correct for the energy added by overlapping windows. + window_sq = window.square().expand(T, -1).transpose(0, 1) + window_envelope = F.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, win_length), + stride=(1, hop_length), + ).squeeze()[pad:-pad] + + # 6. Normalize the waveform + # Divide by the window envelope to get the final reconstructed signal. + assert (window_envelope > 1e-11).all(), "Window envelope has zero values, cannot normalize." 
+ wav = wav / window_envelope + + return wav + + +def spectrogram_mag( + wav: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + window_fn: Callable[Concatenate[int, ...], Tensor] = torch.hann_window, + power: float = 1.0, +) -> Tensor: + """ + Computes the magnitude spectrogram from an audio waveform. + + This function first calculates the complex-valued spectrogram using the + Short-Time Fourier Transform (STFT), then computes the magnitude of the + resulting complex numbers. An optional power can be applied to the + magnitude spectrogram. + + Args: + wav (Tensor): The input audio waveform. + Shape: [batch_size?, time_steps], where batch_size? is + an optional batch dimension. + n_fft (int): The size of the Fast Fourier Transform (FFT) to use. + hop_length (int): The number of audio samples between adjacent STFT columns. + win_length (int): The size of the window function for each frame. + window_fn (function, optional): The windowing function to apply to each + frame. Defaults to torch.hann_window. + power (float, optional): The exponent to apply to the magnitude spectrogram. + A value of 2.0 yields a power spectrogram. + Defaults to 1.0 (magnitude). + + Returns: + Tensor: The resulting magnitude spectrogram. + Shape: [batch_size?, n_fft // 2 + 1, num_frames] + """ + # Calculate the complex spectrogram + spec = spectrogram(wav, n_fft, hop_length, win_length, window_fn) + + # Compute the magnitude by taking the absolute value + spec = spec.abs() + + # Apply power if it's not the default value of 1.0 + if power != 1.0: + spec = spec.pow(power) + + return spec + + +@functools.cache +def get_fbanks( + sample_rate: int, + n_fft: int, + n_mels: int, + f_min: float, + f_max: float, + norm: str = "slaney", + mel_scale: str = "slaney", +) -> Tensor: + """ + Creates and caches Mel filterbanks. + + This function generates a set of triangular filters on the Mel scale. + The `@functools.cache` decorator memoizes the result, so the filterbanks + are only computed once for a given set of parameters, improving efficiency + when the function is called multiple times with the same arguments. + + Note: This implementation only supports Mel filterbanks via librosa. + + Args: + sample_rate (int): The sample rate of the audio. + n_fft (int): The size of the FFT used to compute the spectrogram. + n_mels (int): The number of Mel bands to generate. + f_min (float): The lowest frequency (in Hz) for the filterbanks. + f_max (float): The highest frequency (in Hz) for the filterbanks. + norm (str, optional): The normalization method to use for the triangles. + 'slaney' normalizes to unit area. None applies no norm. + Defaults to "slaney". + mel_scale (str, optional): The Mel scale to use, "htk" or "slaney". + Defaults to "slaney". + + Returns: + Tensor: The Mel filterbank matrix. + Shape: [n_mels, n_fft // 2 + 1] + """ + # Generate Mel filterbanks using librosa's functional API + fb = librosa.filters.mel( + sr=sample_rate, + n_fft=n_fft, + n_mels=n_mels, + fmin=f_min, + fmax=f_max, + norm=norm, + htk=(mel_scale == "htk"), + ) # [n_mels, n_freqs] + fb = torch.from_numpy(fb).float() + return fb + + +def mel_spectrogram( + wav: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + sample_rate: int, + n_mels: int, + f_min: float, + f_max: float | None = None, + window_fn: Callable[Concatenate[int, ...], Tensor] = torch.hann_window, + power: float = 1.0, + log_scale: str | None = "natural", +) -> Tensor: + """ + Computes a Mel-scaled spectrogram from an audio waveform. 
+ + This function transforms a standard spectrogram into a Mel spectrogram by + applying Mel-scaled filterbanks. It can optionally return the result on a + logarithmic scale. + + Args: + wav (Tensor): The input audio waveform. + Shape: [batch_size?, time_steps], where batch_size? is an + optional batch dimension. + n_fft (int): The size of the FFT. + hop_length (int): The number of samples between adjacent frames. + win_length (int): The size of the window function. + sample_rate (int): The sample rate of the audio. + n_mels (int): The number of Mel bands to generate. + f_min (float): The lowest frequency (in Hz) for the Mel scale. + f_max (float | None, optional): The highest frequency (in Hz). If None, + it defaults to sample_rate / 2 (Nyquist). + window_fn (function, optional): The windowing function. Defaults to torch.hann_window. + power (float, optional): The exponent for the magnitude spectrogram before + Mel conversion. Defaults to 1.0. + log_scale (str | None, optional): The type of logarithmic scaling to apply. + Can be "natural" (for `log`), "log10", or `None` + to return the linear-amplitude Mel spectrogram. + Defaults to "natural". + + Returns: + Tensor: The resulting Mel spectrogram. + Shape: [batch_size?, n_mels, num_frames] + + Raises: + ValueError: If an unsupported string is provided for `log_scale`. + """ + # If f_max is not provided, use the Nyquist frequency. + f_max = f_max or sample_rate / 2 + + # 1. Compute the magnitude spectrogram. + spec = spectrogram_mag(wav, n_fft, hop_length, win_length, window_fn=window_fn, power=power) + + # Use a torch.autocast context to ensure the following operations + # are performed in float32 precision for numerical stability, especially + # when the input `spec` might be in a lower precision format like float16. + with torch.autocast(device_type=spec.device.type, enabled=False): + # 2. Get the Mel filterbanks (cached for efficiency). + fb = ( + get_fbanks( + sample_rate, + n_fft, + n_mels, + f_min, + f_max, + ) + .float() + .to(device=spec.device) + ) # Ensure filterbank is float32 and on the correct device. + + # 3. Apply the filterbanks to the spectrogram via matrix multiplication. + # This maps the linear frequency scale to the Mel scale. + # (n_mels, n_freqs) @ (..., n_freqs, time) -> (..., n_mels, time) + mel = torch.matmul(fb, spec.float()) + + # 4. Optionally, apply a logarithmic function. + # A small value (epsilon) is added to prevent taking the log of zero. + if log_scale == "natural": + mel = torch.log(torch.clamp(mel, min=1e-6)) + elif log_scale == "log10": + mel = torch.log10(torch.clamp(mel, min=1e-6)) + elif log_scale is not None: + raise ValueError(f"Unsupported log_scale: '{log_scale}'. Choose from 'natural', 'log10', or None.") + + return mel + + +# ============================================================================== +# Basic Modules +# ============================================================================== + + +class CausalConv1dCache: + """ + A cache for managing states in causal 1D convolutions. + + This class is used during autoregressive inference to store and update the + tail of the input to a causal convolution, which is used as padding for the + next time step. This avoids re-computing the entire sequence at each step. 
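+
+    Example (illustrative; a single step with 2 frames of left padding):
+        >>> cache = CausalConv1dCache()
+        >>> x = torch.zeros(1, 8, 4)  # [batch, channels, time]
+        >>> cache.update(x, layer_id=0, padding=2).shape
+        torch.Size([1, 8, 6])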
+ """ + + def __init__(self) -> None: + self.cache: dict[int | str, Tensor] = {} + + def __getitem__(self, layer_id: int | str) -> Tensor: + """Retrieves the cached tensor for a given layer.""" + return self.cache[layer_id] + + def update( + self, + states: Tensor, + layer_id: int | str, + padding: int, + padding_value: int = 0, + flush: bool = False, + ) -> Tensor: + """ + Updates the cache for a specific layer and returns the padded input. + + Args: + states (Tensor): The new input tensor for the current time step. + layer_id (int | str): An identifier for the convolutional layer. + padding (int): The amount of left padding required by the convolution. + padding_value (int, optional): The value to use for initial padding. Defaults to 0. + flush (bool, optional): If True, the cache for this layer is deleted + after use. Defaults to False. + + Returns: + Tensor: The input states concatenated with the cached padding. + """ + device = states.device + dtype = states.dtype + b, c, t = states.size() + + if layer_id not in self.cache: + # Initialize cache with zero padding if it's the first time step + padding_tensor = torch.zeros((b, c, padding), dtype=dtype, device=device) + padding_value + else: + padding_tensor = self.cache[layer_id] + assert padding_tensor.size(2) == padding + + # Concatenate the cached padding with the new states + padded_states = torch.cat([padding_tensor, states], dim=2) + # Update the cache with the tail of the new padded states + self.cache[layer_id] = padded_states[:, :, -padding:] + + if flush: + del self.cache[layer_id] + + return padded_states + + +class LayerNormNd(nn.Module): + """ + A LayerNorm module that works for N-dimensional inputs. + + This implementation normalizes over the channel dimension (dim=1), which is + a common setup for convolutional networks. + + Args: + channels (int): The number of channels of the input tensor. + eps (float, optional): A value added to the denominator for numerical + stability. Defaults to 1e-6. + elementwise_affine (bool, optional): If True, this module has learnable + affine parameters (weight and bias). + Defaults to True. + bias (bool, optional): If True, this module has a learnable bias. + Defaults to True. + """ + + def __init__(self, channels: int, eps=1e-6, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + self.channels = channels + self.eps = eps + + self.weight = nn.Parameter(torch.ones((channels,)), requires_grad=elementwise_affine) + self.bias = nn.Parameter(torch.zeros((channels,)), requires_grad=elementwise_affine and bias) + + def forward(self, x: Tensor) -> Tensor: + # Calculate mean and reciprocal standard deviation over the channel dimension + mean = x.mean(1, keepdim=True) + x_shift = x - mean + # Using rsqrt for potentially better performance + x_rstd = torch.rsqrt(x_shift.pow(2).mean(1, keepdim=True) + self.eps) + + # Reshape weight and bias to be broadcastable with the input tensor + shape = [-1 if i == 1 else 1 for i in range(x.ndim)] + + # Apply normalization and affine transformation + return (x_shift * x_rstd) * self.weight.view(shape) + self.bias.view(shape) + + +class ConvNeXt1d(nn.Module): + """ + A 1D ConvNeXt block adapted for causal convolutions on audio signals. + + This block is a core component of modern convolutional architectures, featuring + a depthwise convolution, layer normalization, and pointwise convolutions to + expand and contract the channel dimension, similar to an inverted bottleneck. 
+ + Implementation adapted from: https://github.com/charactr-platform/vocos + + Args: + dim (int): Number of input and output channels. + intermediate_dim (int): Dimensionality of the intermediate (expanded) layer. + kernel_size (int): The kernel size for the causal depthwise convolution. + identity_init (bool, optional): If True, the final pointwise convolution + is initialized to zero, making the block + an identity function at the start of training. + Defaults to False. + layer_idx (int, optional): An index for this layer, used for caching. Defaults to 0. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + kernel_size: int, + identity_init: bool = False, + layer_idx: int = 0, + ): + super().__init__() + self.layer_idx = layer_idx + self.kernel_size = kernel_size + + # Depthwise convolution + self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, groups=dim) + + self.norm = LayerNormNd(dim) + self.pwconv1 = nn.Conv1d(dim, intermediate_dim, 1) # Pointwise/1x1 conv for expansion + self.act = nn.GELU() + + # Pointwise/1x1 conv for projection + if identity_init: + self.pwconv2 = zero_module(nn.Conv1d(intermediate_dim, dim, 1)) + else: + self.pwconv2 = nn.Conv1d(intermediate_dim, dim, 1) + + def forward(self, x: Tensor, cache: CausalConv1dCache | None = None, flush: bool = False) -> Tensor: + residual = x + + # Apply causal padding, either through a cache or manually + if cache is not None: + x = cache.update(x, self.layer_idx, self.kernel_size - 1, flush=flush) + else: + x = F.pad(x, [self.kernel_size - 1, 0]) # Left padding for causality + + # Main ConvNeXt path + x = self.dwconv(x) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + # Add residual connection + x = residual + x + return x + + +class PreTrainedEMAVariance(nn.Module): + """ + Exponential Moving Average of Variance + """ + + def __init__(self, initial_value: float = 1.0): + super().__init__() + self.variance = nn.Parameter( + torch.tensor(initial_value), + requires_grad=False, + ) + + def forward(self) -> Tensor: + return self.variance + + +class PreTrainedProbabilisticVQ(nn.Module): + def __init__( + self, + channels: int, + num_mixtures: int, + depth: int = 1, + ): + super().__init__() + self.channels = channels + self.num_mixtures = num_mixtures + self.depth = depth + + self.mus_list = nn.ParameterList( + [ + nn.Parameter( + F.normalize(torch.randn(num_mixtures, channels), p=2.0, dim=1) * ((depth - i) / depth), + requires_grad=False, + ) + for i in range(depth) + ] + ) + self._variance_list = nn.ModuleList([PreTrainedEMAVariance() for _ in range(depth)]) + + @property + def log_std(self) -> Tensor: + return torch.log(self._variance_list[-1]()) * 0.5 + + def encode(self, z: Tensor, return_z_q: bool = False) -> list[Tensor] | tuple[list[Tensor], Tensor]: + r = z + ids_sel = [] + for i in range(self.depth): + mus = self.mus_list[i] + idx_sel = self._dist_sq(r, mus).argmin(-1) # [b, ?, h], [v, h] -> [b, ?] 
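+            # Remove the selected codeword from the residual so the next (deeper) codebook quantizes what remains.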
+ r = r - F.embedding(idx_sel, mus) + ids_sel.append(idx_sel) + if return_z_q: + return ids_sel, z - r + return ids_sel + + def decode(self, ids_sel: list[Tensor]) -> Tensor: + z = torch.zeros((*ids_sel[0].size(), self.channels), device=ids_sel[0].device) + for i in range(len(ids_sel)): + mus = self.mus_list[i] + z = z + F.embedding(ids_sel[i], mus) + return z # [b, ?, h] + + def _dist_sq(self, z: Tensor, mus: Tensor) -> Tensor: + """ + z: [b, ?, d?, h] + mus: [d?, v, h] + """ + return ( + z.pow(2).sum(-1, keepdim=True) # [b, ?, d?, 1] + + mus.pow(2).sum(-1) # [d?, v] + - 2 * (z.unsqueeze(-2) @ mus.transpose(-1, -2)).squeeze(-2) # [b, ?, d?, h] , [d?, h, v] -> [b, ?, d?, v] + ) + + +class Wav2Latent(nn.Module): + """ + An encoder model that transforms a raw waveform into a latent representation. + + This model first converts the waveform to a spectrogram, then processes it + through a series of ConvNeXt blocks and downsampling convolutional layers + to produce a compressed latent tensor. + + Args: + latent_size (int): The number of channels in the final latent representation. + n_fft (int): The FFT size for the initial spectrogram transformation. + hop_length (int): The hop length for the STFT. + base_hidden_size (int): The base number of channels for the hidden layers. + channel_mult (tuple[int, ...]): A tuple of multipliers for the hidden + size at each stage of downsampling. + rates (tuple[int, ...]): A tuple of downsampling factors (strides) for + the convolutional layers. + num_blocks (int): The number of ConvNeXt blocks per stage. + kernel_size (int): The kernel size for the ConvNeXt blocks. + groups (int): The number of groups for the downsampling convolutions. + """ + + def __init__( + self, + latent_size: int = 1024, + n_fft: int = 32, + hop_length: int = 8, + base_hidden_size: int = 384, + channel_mult: tuple[int, ...] = (1, 2, 4), + rates: tuple[int, ...] 
= (8, 8, 8), + num_blocks: int = 3, + kernel_size: int = 7, + groups: int = 1, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + + # Initial projection from spectrogram to hidden size + layers: list[nn.Module] = [nn.Conv1d(n_fft + 2, base_hidden_size * channel_mult[0], 1, bias=False)] + + # Downsampling stages + for i in range(len(channel_mult)): + ch_mult, rate = channel_mult[i], rates[i] + hidden_size = base_hidden_size * ch_mult + # Add ConvNeXt blocks for this stage + for j in range(num_blocks): + layers.append( + ConvNeXt1d(hidden_size, hidden_size * 4, kernel_size, True, layer_idx=i * num_blocks + j) + ) + # Add downsampling convolution + next_hidden_size = base_hidden_size * channel_mult[i + 1] if i < len(channel_mult) - 1 else latent_size + layers.append( + nn.Conv1d(hidden_size, next_hidden_size, kernel_size=rate, stride=rate, bias=False, groups=groups) + ) + + self.layers = nn.ModuleList(layers) + + def forward(self, x: Tensor, cache=None, flush: bool = False) -> Tensor: + if cache is not None: + raise NotImplementedError("Caching is not implemented for the encoder.") + + # Convert waveform to spectrogram (magnitude and phase) + with torch.autocast(device_type=x.device.type, enabled=False): + spec = spectrogram(x.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.n_fft) + # Split complex spectrogram into real and imaginary, then treat as magnitude and phase + mag, ph = torch.view_as_real(spec).chunk(2, dim=-1) + x = torch.cat([mag, ph], 1).squeeze(-1) + + # Pass through the network + for layer in self.layers: + if isinstance(layer, ConvNeXt1d): + x = layer(x, cache=cache, flush=flush) + else: + x = layer(x) + + # Transpose to [batch, time, channels] for compatibility with transformers + x = x.transpose(-1, -2) + return x + + +class Latent2Wav(nn.Module): + """ + A decoder (vocoder) model that transforms a latent representation back into a raw waveform. + + This model processes a latent tensor through a series of ConvNeXt blocks and + upsampling transposed convolutional layers to produce a spectrogram, which is + then converted back to a waveform using an inverse STFT. + + Args: + latent_size (int): The number of channels in the input latent representation. + n_fft (int): The FFT size for the final spectrogram reconstruction. + hop_length (int): The hop length for the ISTFT. + base_hidden_size (int): The base number of channels for the hidden layers. + channel_mult (tuple[int, ...]): A tuple of multipliers for the hidden + size at each stage of upsampling. + rates (tuple[int, ...]): A tuple of upsampling factors (strides) for + the transposed convolutional layers. + num_blocks (int): The number of ConvNeXt blocks per stage. + kernel_size (int): The kernel size for the ConvNeXt blocks. + groups (int): The number of groups for the upsampling convolutions. + """ + + def __init__( + self, + latent_size: int = 1024, + n_fft: int = 32, + hop_length: int = 8, + base_hidden_size: int = 384, + channel_mult: tuple[int, ...] = (4, 2, 1), + rates: tuple[int, ...] 
= (8, 8, 8), + num_blocks: int = 3, + kernel_size: int = 7, + groups=1, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.spec_cache_idx = (len(channel_mult)) * num_blocks + + layers: list[nn.Module] = [] + + # Upsampling stages + for i in range(len(channel_mult)): + ch_mult, rate = channel_mult[i], rates[i] + hidden_size = base_hidden_size * ch_mult + # Add upsampling transposed convolution + in_size = base_hidden_size * channel_mult[i - 1] if i != 0 else latent_size + layers.append( + nn.ConvTranspose1d(in_size, hidden_size, kernel_size=rate, stride=rate, bias=False, groups=groups) + ) + # Add ConvNeXt blocks for this stage + for j in range(num_blocks): + layers.append( + ConvNeXt1d(hidden_size, hidden_size * 4, kernel_size, True, layer_idx=i * num_blocks + j) + ) + + # Final projection to spectrogram dimensions (magnitude + phase) + layers.append(nn.Conv1d(hidden_size, n_fft + 2, 1, bias=False)) + self.layers = nn.ModuleList(layers) + + def forward(self, x: Tensor, cache=None, flush: bool = False, constrain_value_range: bool = True) -> Tensor: + # Transpose input from [batch, time, channels] to [batch, channels, time] + x = x.transpose(-1, -2) + + # Pass through the network + for layer in self.layers: + if isinstance(layer, ConvNeXt1d): + x = layer(x, cache=cache, flush=flush) + else: + x = layer(x) + + # Convert network output to a complex spectrogram and then to a waveform + with torch.autocast(device_type=x.device.type, enabled=False): + max_mag = 100.0 + # Split output into magnitude and phase components + mag, ph = x.float().chunk(2, dim=1) + # Safeguard to prevent excessively large magnitudes + mag = max_mag * torch.exp(-F.softplus(-mag + math.log(max_mag))) + + # Reconstruct the complex spectrogram from magnitude and phase + # The DC and Nyquist components are real, so their phase is applied via cosine. + mag_dc, mag_mid, mag_nyquist = mag.split([1, mag.size(1) - 2, 1], dim=1) + ph_dc, ph_mid, ph_nyquist = torch.cos(ph).split([1, ph.size(1) - 2, 1], dim=1) + ph_imag = torch.sin(ph[:, 1:-1, :]) + + spec_real = mag_mid * ph_mid + spec_imag = mag_mid * ph_imag + + spec = torch.cat([mag_dc * ph_dc, spec_real + 1j * spec_imag, mag_nyquist * ph_nyquist], 1) + + # Handle caching for autoregressive generation of the spectrogram + if cache is not None: + half_spec_padding = math.ceil(((self.n_fft - self.hop_length) // 2) / self.hop_length) + spec = cache.update(spec, self.spec_cache_idx, padding=half_spec_padding * 2, flush=flush) + if flush: + spec = F.pad(spec, [0, half_spec_padding]) + + # Convert spectrogram to waveform + x = spec_to_wav( + spec, self.n_fft, self.hop_length, self.n_fft, constrain_value_range=constrain_value_range + ).unsqueeze(1) + + if cache is not None: + # Trim the output to remove the padded region from the start + half_wav_padding = half_spec_padding * self.hop_length + x = x[:, :, half_wav_padding:-half_wav_padding] + + return x + + +class RVQVAEModel(nn.Module): + """ + Residual Vector-Quantized Variational Autoencoder (RVQ-VAE) model. + + This model learns a discrete representation of audio by encoding a waveform + into a latent space and then quantizing the latents into discrete codes. + It consists of an encoder, a quantizer, and a decoder. + + Args: + config (DictConfig | dict[str, Any]): A configuration object with model hyperparameters. 
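+
+    Example (illustrative; the hyperparameter values below mirror the test configuration
+    in this collection and are not the only valid choice):
+        >>> cfg = DictConfig({
+        ...     "latent_size": 512, "n_fft": 16, "hop_length": 4, "base_hidden_size": 384,
+        ...     "channel_mult": [1, 2, 4], "rates": [7, 7, 9], "num_blocks": 3,
+        ...     "kernel_size": 7, "groups": 1, "codebook_size": 1024,
+        ...     "num_quantizers": 31, "wav_to_token_ratio": 1764,
+        ... })
+        >>> codec = RVQVAEModel(cfg)
+        >>> wav = torch.randn(1, 1, 4 * cfg.wav_to_token_ratio)
+        >>> codes, code_lens = codec.encode(wav, torch.tensor([wav.shape[-1]]))
+        >>> codes.shape  # [batch, frames, num_quantizers]
+        torch.Size([1, 4, 31])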
+ """ + + config_class: type[DictConfig] = DictConfig + + def __init__(self, config: DictConfig | dict[str, Any]): + super().__init__() + self.config = config + + self.encoder = Wav2Latent( + latent_size=self.config.latent_size, + n_fft=self.config.n_fft, + hop_length=self.config.hop_length, + base_hidden_size=self.config.base_hidden_size, + channel_mult=self.config.channel_mult, + rates=self.config.rates, + num_blocks=self.config.num_blocks, + kernel_size=self.config.kernel_size, + groups=self.config.groups, + ) + + # Layers for quantization + self.prvq = PreTrainedProbabilisticVQ( + channels=self.config.latent_size, + num_mixtures=self.config.codebook_size, + depth=self.config.num_quantizers, + ) + + self.decoder = Latent2Wav( + latent_size=self.config.latent_size, + n_fft=self.config.n_fft, + hop_length=self.config.hop_length, + base_hidden_size=self.config.base_hidden_size, + channel_mult=tuple(reversed(self.config.channel_mult)), + rates=tuple(reversed(self.config.rates)), + num_blocks=self.config.num_blocks, + kernel_size=self.config.kernel_size, + groups=self.config.groups, + ) + + for p in self.parameters(): + p.requires_grad = False + + def ae_encode(self, x: Tensor, cache: CausalConv1dCache | None = None, flush: bool = False) -> Tensor: + """ + Runs the encoder part of the autoencoder. + + Args: + x (Tensor): Input waveform. Shape: `[batch, 1, time]`. + cache (CausalConv1dCache | None): Not implemented for the encoder. + flush (bool): Not implemented for the encoder. + + Returns: + Tensor: The continuous latent representation. Shape: `[batch, time', channels]`. + """ + assert x.size(1) == 1 and x.dim() == 3, "Input must be a batch of mono audio." + assert x.size(2) % self.config.wav_to_token_ratio == 0, ( + f"Input audio length ({x.size(2)}) must be divisible by the model's " + f"wav_to_token_ratio ({self.config.wav_to_token_ratio}). " + f"Please pad the input to a compatible length." + ) + + if cache is not None: + raise NotImplementedError("Caching is not supported for the encoder.") + + return self.encoder(x, cache=cache, flush=flush) + + def ae_decode( + self, + x: Tensor, + constrain_value_range: bool = True, + cache: CausalConv1dCache | None = None, + flush: bool = False, + ) -> Tensor: + """ + Runs the decoder part of the autoencoder. + + Args: + x (Tensor): The (de-quantized) latent representation. Shape: `[batch, time', channels]`. + constrain_value_range (bool): If True, constrains the output of the ISTFT. + cache (CausalConv1dCache | None): Cache for autoregressive generation. + flush (bool): If True, flushes the cache. + + Returns: + Tensor: The reconstructed waveform. Shape: `[batch, 1, time]`. + """ + return self.decoder(x, constrain_value_range=constrain_value_range, cache=cache, flush=flush) + + def encode(self, x: Tensor, x_len: Tensor) -> tuple[Tensor, Tensor]: + """ + Encodes a waveform into discrete codes. + + Args: + x (Tensor): Input waveform. Shape: `[batch, 1, time]`. + x_len (Tensor): The original lengths of the waveforms in the batch. + + Returns: + tuple[Tensor, Tensor]: A tuple containing: + - The discrete codes. Shape: `[batch, time', n_quantizers]`. + - The lengths of the code sequences. 
+ """ + with disable_tf32(): + z_e = self.ae_encode(x) + code_len = x_len // self.config.wav_to_token_ratio + return self.quantize(z_e), code_len + + def decode( + self, + code: Tensor, + code_len: Tensor | None = None, + constrain_value_range: bool = True, + cache: CausalConv1dCache | None = None, + flush: bool = False, + ) -> tuple[Tensor, Tensor | None]: + """ + Decodes discrete codes back into a waveform. + + Args: + code (Tensor): The discrete codes. Shape: `[batch, time', n_quantizers]`. + code_len (Tensor | None): The lengths of the code sequences. + constrain_value_range (bool): If True, constrains the output of the ISTFT. + cache (CausalConv1dCache | None): Cache for autoregressive generation. + flush (bool): If True, flushes the cache. + + Returns: + tuple[Tensor, Tensor | None]: A tuple containing: + - The reconstructed waveform. Shape: `[batch, 1, time]`. + - The lengths of the reconstructed waveforms. + """ + with disable_tf32(): + z_q = self.dequantize(code) + x_hat = self.ae_decode(z_q, constrain_value_range=constrain_value_range, cache=cache, flush=flush) + wav_len = code_len * self.config.wav_to_token_ratio if code_len is not None else None + return x_hat, wav_len + + def quantize(self, z: Tensor) -> Tensor: + """ + Quantizes a continuous latent tensor into discrete codes. + + Args: + z (Tensor): The continuous latent tensor from the encoder. + Shape: `[batch, time, channels]`. + + Returns: + Tensor: The quantized codes. Shape: `[batch, time, n_quantizers]`. + """ + with disable_tf32(): + ids_sel = self.prvq.encode(z, return_z_q=False) + return torch.stack(ids_sel, -1) + + def dequantize(self, code: Tensor) -> Tensor: + """ + De-quantizes discrete codes back into a continuous latent tensor. + + Args: + code (Tensor): The quantized codes. Shape: `[batch, time, n_quantizers]`. + + Returns: + Tensor: The de-quantized continuous latent tensor. + Shape: `[batch, time, latent_size]`. + """ + ids_sel = [x.squeeze(-1) for x in torch.split(code, 1, -1)] + return self.prvq.decode(ids_sel) + + def forward(self, x: Tensor, constrain_value_range: bool = False) -> Tensor: + """ + Performs a full autoencoding pass: encode, quantize, dequantize, and decode. + + Args: + x (Tensor): The input waveform. Shape: `[batch, 1, time]`. + constrain_value_range (bool): If True, constrains the output of the ISTFT. + + Returns: + Tensor: The reconstructed waveform. Shape: `[batch, 1, time]`. + """ + + with torch.no_grad(): + z_e = self.ae_encode(x) + code = self.quantize(z_e) + z_d = self.dequantize(code) + x_hat = self.ae_decode(z_d, constrain_value_range=constrain_value_range) + return x_hat diff --git a/nemo/collections/speechlm2/parts/metrics/__init__.py b/nemo/collections/speechlm2/parts/metrics/__init__.py index aa628a2f76b3..07cff31b5dd7 100644 --- a/nemo/collections/speechlm2/parts/metrics/__init__.py +++ b/nemo/collections/speechlm2/parts/metrics/__init__.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from .asr_bleu import ASRBLEU +from .asr_cer_wer import Intelligibility from .bleu import BLEU +from .results_logger import ResultsLogger +from .token_accuracy import TokenAccuracy from .wer import WER __all__ = [ 'ASRBLEU', 'BLEU', 'WER', + 'TokenAccuracy', + 'ResultsLogger', + 'Intelligibility', ] diff --git a/nemo/collections/speechlm2/parts/metrics/asr_bleu.py b/nemo/collections/speechlm2/parts/metrics/asr_bleu.py index d7124b49ff01..3ffd489312dc 100644 --- a/nemo/collections/speechlm2/parts/metrics/asr_bleu.py +++ b/nemo/collections/speechlm2/parts/metrics/asr_bleu.py @@ -57,7 +57,7 @@ def reset(self): def update( self, name: str, refs: list[str], pred_audio: torch.Tensor, pred_audio_lens: torch.Tensor = None - ) -> None: + ) -> list[str]: if self.asr is None: self.reset() @@ -70,7 +70,7 @@ def update( batch_size=pred_audio.shape[0], verbose=False, ) - + asr_hyps_texts = [] for ref, asr_hyp in zip(refs, asr_hyps): asr_hyp = asr_hyp.text self._refs[name].append(self.normalizer(ref)) @@ -78,6 +78,9 @@ def update( if self.verbose: asrb = sacrebleu.sentence_bleu(asr_hyp, [ref]).score logging.info(f"[REF]\t{ref}\n[ASR]\t{asr_hyp} [{asrb:.2f}]") + asr_hyps_texts.append(asr_hyp) + + return asr_hyps_texts def compute(self) -> dict[str, torch.Tensor]: """Computes the final score and deallocates ASR and partial results.""" diff --git a/nemo/collections/speechlm2/parts/metrics/asr_cer_wer.py b/nemo/collections/speechlm2/parts/metrics/asr_cer_wer.py new file mode 100644 index 000000000000..15823dc86647 --- /dev/null +++ b/nemo/collections/speechlm2/parts/metrics/asr_cer_wer.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict + +import torch +from whisper_normalizer.english import EnglishTextNormalizer + +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.asr.models import ASRModel +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.speechlm2.parts.pretrained import load_pretrained_nemo + + +class Intelligibility: + """ + Computes CER on ASR predictions on generated audio with pretrained NeMo ASR. + By default, uses Whisper's EnglishTextNormalizer on hypotheses and references. 
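+
+    Example (illustrative sketch; `audio` and `audio_lens` are placeholder tensors of
+    generated speech, and the checkpoint name is only an example):
+        >>> metric = Intelligibility(pretrained_asr="stt_en_fastconformer_transducer_large")
+        >>> hyps = metric.update("val_set", refs=["hello world"], pred_audio=audio, pred_audio_lens=audio_lens)
+        >>> scores = metric.compute()  # {"cer": ..., "wer": ..., "cer_val_set": ..., "wer_val_set": ...}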
+ """ + + def __init__( + self, + pretrained_asr: str, + normalize: bool = True, + normalizer=None, + verbose: bool = True, + reuse_asr_hyps: bool = False, + ) -> None: + self.asr = None # load into memory on reset() + self.pretrained_asr_name = pretrained_asr + self.verbose = verbose + self.reuse_asr_hyps = reuse_asr_hyps + if normalize: + if normalizer is None: + self.normalizer = EnglishTextNormalizer() + else: + self.normalizer = normalizer + else: + self.normalizer = _identity + + self._refs = defaultdict(list) + self._hyps = defaultdict(list) + + def reset(self): + # Cleaning up GPU memory before we load ASRModel, because it may already + # be quite fragmented and close to the limit after observing many + # dynamic shapes during the training epoch. + torch.cuda.memory.empty_cache() + with fp32_precision(): # Some NeMo ASR models weren't trained with bfloat16. + if not self.reuse_asr_hyps: + self.asr = load_pretrained_nemo(ASRModel, self.pretrained_asr_name).eval() + WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self.asr, attribute_path="decoding.decoding") + return self + + def update( + self, + name: str, + refs: list[str], + pred_audio: torch.Tensor, + pred_audio_lens: torch.Tensor = None, + asr_hyps: list[str] = None, + ) -> list[str]: + if self.asr is None and not self.reuse_asr_hyps: + self.reset() + + if pred_audio_lens is None and pred_audio is not None: + pred_audio_lens = [pred_audio.shape[1]] * pred_audio.shape[0] + + with fp32_precision(): + if not self.reuse_asr_hyps: + asr_hyps = self.asr.transcribe( + [audio[:alen] for audio, alen in zip(pred_audio, pred_audio_lens)], + batch_size=pred_audio.shape[0], + verbose=False, + ) + asr_hyps = [asr_hyp.text for asr_hyp in asr_hyps] + + asr_hyps_texts = [] + for ref, asr_hyp in zip(refs, asr_hyps): + self._refs[name].append(self.normalizer(ref)) + self._hyps[name].append(self.normalizer(asr_hyp)) + asr_hyps_texts.append(asr_hyp) + + return asr_hyps_texts + + def compute(self) -> dict[str, torch.Tensor]: + """Computes the final score and deallocates ASR and partial results.""" + corpus_metric = {} + corpus_metric["cer"] = [] + corpus_metric["wer"] = [] + for name in self._refs.keys(): + cer = torch.tensor(word_error_rate(self._hyps[name], self._refs[name], use_cer=True)) + wer = torch.tensor(word_error_rate(self._hyps[name], self._refs[name], use_cer=False)) + corpus_metric[f"cer_{name}"] = cer + corpus_metric[f"wer_{name}"] = wer + corpus_metric["cer"].append(cer) + corpus_metric["wer"].append(wer) + + corpus_metric["cer"] = torch.stack(corpus_metric["cer"]).mean() + corpus_metric["wer"] = torch.stack(corpus_metric["wer"]).mean() + self._refs.clear() + self._hyps.clear() + self.asr = None # free up GPU memory + torch.cuda.memory.empty_cache() + return corpus_metric + + +def _identity(x): + return x diff --git a/nemo/collections/speechlm2/parts/metrics/results_logger.py b/nemo/collections/speechlm2/parts/metrics/results_logger.py new file mode 100644 index 000000000000..c9f66212c137 --- /dev/null +++ b/nemo/collections/speechlm2/parts/metrics/results_logger.py @@ -0,0 +1,223 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import shutil + +import soundfile as sf +import torch + +from nemo.collections.audio.parts.utils.resampling import resample +from nemo.utils import logging + +""" +Utilities for logging evaluation results of SpeechLM2 collection models. + +This file provides helper functionality for saving audio outputs and structured +metadata during evaluation or inference of Duplex speech-to-speech / TTS models. +It is primarily responsible for: + + - Writing predicted waveforms to disk. + - Merging user and model audio into multi-channel WAV files for analysis. + - Exporting metadata (reference text, predictions, ASR output) into JSONL format. + - Saving auxiliary debug artifacts such as: + * teacher-forced predictions, + * reference audio, + * trimmed outputs, + * end-of-utterance (EOU) probability signals. + +Unlike other files in this directory, which focus on metric evaluation, this module +is dedicated to persisting model outputs — including predicted audio samples and +their associated metadata — for later inspection and analysis. + +Key abstraction: + - `ResultsLogger`: A lightweight utility class that manages audio dumping + and metadata bookkeeping across inference batches. +""" + + +def safe_remove_path(path): + shutil.rmtree(path, ignore_errors=True) + + +class ResultsLogger: + """ + Saves audios and a json file with the model outputs. 
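+
+    Example (illustrative sketch; `refs`, `hyps`, `asr_hyps`, `ids` and `pred` are
+    placeholder values produced elsewhere in the evaluation loop):
+        >>> results_logger = ResultsLogger("/tmp/eval_results").reset()
+        >>> results_logger.update(
+        ...     name="val_set", refs=refs, hyps=hyps, asr_hyps=asr_hyps, samples_id=ids,
+        ...     pred_audio=pred, pred_audio_sr=22050, user_audio=None, user_audio_sr=22050,
+        ... )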
+ """ + + def __init__(self, save_path): + self.save_path = save_path + self.audio_save_path = os.path.join(save_path, "pred_wavs") + os.makedirs(self.audio_save_path, exist_ok=True) + self.matadata_save_path = os.path.join(save_path, "metadatas") + os.makedirs(self.matadata_save_path, exist_ok=True) + + def reset(self): + # ensures that we are cleaning the metadata files + metadata_files = os.listdir(self.matadata_save_path) + for f in metadata_files: + open(os.path.join(self.matadata_save_path, f), 'w').close() + + # clean out any existing .wav predictions safely + try: + audio_files = os.listdir(self.audio_save_path) + for f in audio_files: + if f.lower().endswith(".wav"): + try: + os.remove(os.path.join(self.audio_save_path, f)) + except FileNotFoundError: + pass # already gone + except Exception: + logging.warning(f"Failed to remove audio file {f} during reset.", stack_info=False) + except FileNotFoundError: + # directory somehow missing: recreate it + os.makedirs(self.audio_save_path, exist_ok=True) + + return self + + @staticmethod + def merge_and_save_audio( + out_audio_path: str, pred_audio: torch.Tensor, pred_audio_sr: int, user_audio: torch.Tensor, user_audio_sr: int + ) -> None: + # if user_audio is None ignore it + if user_audio is not None: + user_audio = resample(user_audio.float(), user_audio_sr, pred_audio_sr) + T1, T2 = pred_audio.shape[0], user_audio.shape[0] + max_len = max(T1, T2) + pred_audio_padded = torch.nn.functional.pad(pred_audio, (0, max_len - T1), mode='constant', value=0) + user_audio_padded = torch.nn.functional.pad(user_audio, (0, max_len - T2), mode='constant', value=0) + + # combine audio in a multichannel audio + combined_wav = torch.cat( + [ + user_audio_padded.squeeze().unsqueeze(0).detach().cpu(), + pred_audio_padded.squeeze().unsqueeze(0).detach().cpu(), + ], + dim=0, + ).squeeze() + + else: + combined_wav = pred_audio.unsqueeze(0).detach().cpu() + + # save audio + os.makedirs(os.path.dirname(out_audio_path), exist_ok=True) + sf.write(out_audio_path, combined_wav.numpy().astype('float32').T, pred_audio_sr) + logging.info(f"Audio saved at: {out_audio_path}") + + def update( + self, + name: str, + refs: list[str], + hyps: list[str], + asr_hyps: list[str], + samples_id: list[str], + pred_audio: torch.Tensor, + pred_audio_sr: int, + user_audio: torch.Tensor, + user_audio_sr: int, + target_audio: torch.Tensor = None, + pred_audio_tf: torch.Tensor = None, + pre_audio_trimmed: torch.Tensor = None, + eou_pred: torch.Tensor = None, + fps: float = None, + results=None, + tokenizer=None, + reference_audio: torch.Tensor = None, + ) -> None: + + out_json_path = os.path.join(self.matadata_save_path, f"{name}.json") + out_dicts = [] + for i in range(len(refs)): + # save audio + sample_id = samples_id[i][:150] # make sure that sample id is not too big + out_audio_path = os.path.join(self.audio_save_path, f"{name}_{sample_id}.wav") + self.merge_and_save_audio( + out_audio_path, + pred_audio[i], + pred_audio_sr, + user_audio[i] if user_audio is not None else None, + user_audio_sr, + ) + + if pred_audio_tf is not None: + out_audio_path_tf = out_audio_path.replace(".wav", "_tf.wav") + self.merge_and_save_audio( + out_audio_path_tf, + pred_audio_tf[i], + pred_audio_sr, + user_audio[i] if user_audio is not None else None, + user_audio_sr, + ) + + if target_audio is not None: + out_audio_path_gt = out_audio_path.replace(".wav", "_GT.wav") + self.merge_and_save_audio( + out_audio_path_gt, + target_audio[i], + pred_audio_sr, + user_audio[i] if user_audio is not None else 
None,
+                    user_audio_sr,
+                )
+
+            # create a wav with the EOU prediction for debugging purposes
+            if eou_pred is not None:
+                out_audio_path_eou = os.path.join(self.audio_save_path, f"{name}_{sample_id}_eou.wav")
+                repeat_factor = int(pred_audio_sr / fps)
+                eou_pred_wav = (
+                    eou_pred[i].unsqueeze(0).unsqueeze(-1).repeat(1, 1, repeat_factor)
+                )  # (B, T, repeat_factor)
+                eou_pred_wav = eou_pred_wav.view(1, -1)  # (B, T * repeat_factor)
+                eou_pred_wav = eou_pred_wav.float() * 0.8  # make 1 audible and keep 0 as total silence
+                sf.write(
+                    out_audio_path_eou,
+                    eou_pred_wav.squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T,
+                    pred_audio_sr,
+                )
+
+            if pre_audio_trimmed is not None:
+                out_audio_path_trimmed = os.path.join(self.audio_save_path, f"{name}_{sample_id}_pred_trimmed.wav")
+                sf.write(
+                    out_audio_path_trimmed,
+                    pre_audio_trimmed[i].squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T,
+                    pred_audio_sr,
+                )
+
+            if reference_audio is not None:
+                out_audio_path_ref = os.path.join(self.audio_save_path, f"{name}_{sample_id}_spk_reference.wav")
+                sf.write(
+                    out_audio_path_ref,
+                    reference_audio[i].squeeze().unsqueeze(0).detach().cpu().numpy().astype('float32').T,
+                    pred_audio_sr,
+                )
+
+            # cache metadata
+            out_dict = {
+                "target_text": refs[i],
+                "pred_text": hyps[i],
+                "speech_pred_transcribed": asr_hyps[i],
+                "audio_path": os.path.relpath(out_audio_path, self.save_path),
+            }
+            if results is not None:
+                if tokenizer is not None:
+                    out_dict['tokens_text'] = " ".join(tokenizer.ids_to_tokens(results['tokens_text'][i]))
+                else:
+                    out_dict['tokens_text'] = results['tokens_text'][i].tolist()
+            out_dicts.append(out_dict)
+        # open in append mode so results can be written incrementally without caching them all in memory
+        with open(out_json_path, 'a+', encoding='utf-8') as fout:
+            for out_dict in out_dicts:
+                fout.write(json.dumps(out_dict, ensure_ascii=False, indent=4) + '\n')
+
+        logging.info(f"Metadata file for {name} dataset updated at: {out_json_path}")
diff --git a/nemo/collections/speechlm2/parts/metrics/secs.py b/nemo/collections/speechlm2/parts/metrics/secs.py
new file mode 100644
index 000000000000..dfb9379ae2a2
--- /dev/null
+++ b/nemo/collections/speechlm2/parts/metrics/secs.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import defaultdict
+
+import torch
+
+from nemo.collections.asr.models import EncDecSpeakerLabelModel
+from nemo.collections.speechlm2.parts.precision import fp32_precision
+
+
+class SECS:
+    """
+    Computes Speaker Encoder Cosine Similarity (SECS) on generated audio with a pretrained speaker encoder model.
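+
+    Example (illustrative; the speaker-verification checkpoint name and the audio
+    tensors are placeholders):
+        >>> metric = SECS(pretrained_se_name="titanet_large")
+        >>> metric.update("val_set", target_audio, target_audio_lens, pred_audio, pred_audio_lens)
+        >>> scores = metric.compute()  # {"secs_val_set": ..., "secs": ...}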
+    """
+
+    def __init__(self, pretrained_se_name: str) -> None:
+        self.speaker_encoder = None  # load into memory on reset()
+        self.pretrained_se_name = pretrained_se_name
+        self._secs = defaultdict(list)
+
+    def reset(self):
+        # Cleaning up GPU memory before we load the speaker encoder, because it may already
+        # be quite fragmented and close to the limit after observing many
+        # dynamic shapes during the training epoch.
+        torch.cuda.memory.empty_cache()
+        with fp32_precision():
+            self.speaker_encoder = EncDecSpeakerLabelModel.from_pretrained(model_name=self.pretrained_se_name).eval()
+
+        return self
+
+    def update(
+        self,
+        name: str,
+        target_audio: torch.Tensor,
+        target_audio_lens: torch.Tensor,
+        pred_audio: torch.Tensor,
+        pred_audio_lens: torch.Tensor,
+    ) -> None:
+        if self.speaker_encoder is None:
+            self.reset()
+
+        with fp32_precision():
+            with torch.no_grad():
+                _, t_g = self.speaker_encoder(input_signal=target_audio, input_signal_length=target_audio_lens.long())
+                _, s_g = self.speaker_encoder(input_signal=pred_audio, input_signal_length=pred_audio_lens.long())
+                secs = torch.nn.functional.cosine_similarity(t_g, s_g, dim=-1).mean()
+
+        self._secs[name].append(secs)
+
+    def compute(self) -> dict[str, torch.Tensor]:
+        """Computes the final score and deallocates the speaker encoder and partial results."""
+        corpus_metric = {}
+        avg_secs = []
+        for name in self._secs.keys():
+            secs = torch.stack(self._secs[name]).mean()
+            corpus_metric[f"secs_{name}"] = secs
+            avg_secs.append(secs)
+
+        corpus_metric["secs"] = torch.stack(avg_secs).mean()
+        self._secs.clear()
+        self.speaker_encoder = None  # free up GPU memory
+        torch.cuda.memory.empty_cache()
+        return corpus_metric
diff --git a/nemo/collections/speechlm2/parts/metrics/token_accuracy.py b/nemo/collections/speechlm2/parts/metrics/token_accuracy.py
new file mode 100644
index 000000000000..a10300dd5d14
--- /dev/null
+++ b/nemo/collections/speechlm2/parts/metrics/token_accuracy.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import defaultdict
+import torch
+
+
+def compute_token_accuracy_with_tolerance(target, pred, token, tolerance=1):
+    """
+    Computes the accuracy of `token` in `target` vs `pred` within a ±`tolerance` window.
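+
+    Example (illustrative):
+        >>> target = torch.tensor([[5, 0, 0, 7, 0]])
+        >>> pred = torch.tensor([[0, 5, 0, 0, 7]])
+        >>> compute_token_accuracy_with_tolerance(target, pred, token=7, tolerance=1)
+        1.0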
+ + Args: + target (torch.Tensor): Batch of target sequences (batch_size, seq_len) + pred (torch.Tensor): Batch of predicted sequences (batch_size, seq_len) + token (int): The token to compute accuracy for + tolerance (int): Allowed index difference (window) for correct predictions + + Returns: + float: Accuracy as correct / total occurrences of `token` in target + """ + batch_size, seq_len = target.shape + + # Mask of positions where token appears + target_mask = target == token + pred_mask = pred == token + + correct = 0 + total = 0 + + # For each sequence in the batch + for b in range(batch_size): + # Get indices of token in target and pred + target_indices = target_mask[b].nonzero(as_tuple=True)[0] + pred_indices = pred_mask[b].nonzero(as_tuple=True)[0] + + total += len(target_indices) + if len(pred_indices) == 0: + continue # No token in pred, none can be correct + + # Compute pairwise distances + distances = torch.abs(target_indices[:, None] - pred_indices[None, :]) # shape (n_target, n_pred) + min_distances, _ = torch.min(distances, dim=1) # min distance for each target occurrence + + # Count how many target tokens have a pred within tolerance + correct += torch.sum(min_distances <= tolerance).item() + + accuracy = correct / total if total > 0 else 0.0 + return accuracy + + +class TokenAccuracy: + """ + Computes Token Accuracy scores. + """ + + def __init__(self, token_name: str, token_id: int, tolerance: int = 1, verbose: bool = True): + self.token_name = token_name + self.token_id = token_id + self.tolerance = tolerance + self.verbose = verbose + self.scores = defaultdict(list) + + def reset(self): + return self + + def update(self, name: str, refs: torch.Tensor, hyps: torch.Tensor) -> None: + token_acc = compute_token_accuracy_with_tolerance(refs, hyps, token=self.token_id, tolerance=self.tolerance) + self.scores[name].append(token_acc) + + def compute(self) -> dict[str, torch.Tensor]: + corpus_metric = {} + for name in self.scores.keys(): + metric = torch.tensor(self.scores[name]).mean() + corpus_metric[f"token_acc_{self.token_name}_{name}"] = metric + self.scores.clear() + return corpus_metric diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 7c5696457c65..c43da2a9747a 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -13,17 +13,19 @@ # limitations under the License. from contextlib import contextmanager from pathlib import Path +from typing import Dict import torch from omegaconf import open_dict from peft import PeftModel +from safetensors.torch import load_file from transformers import AutoConfig, AutoModelForCausalLM from nemo.collections.asr.models import ASRModel from nemo.collections.speechlm2.modules import AudioPerceptionModule - from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.models import AudioCodecModel +from nemo.utils import logging def load_pretrained_nemo(cls, model_path_or_name: str): @@ -99,3 +101,73 @@ def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True model.perception.load_state_dict(asr.state_dict(), strict=False) else: model.perception = AudioPerceptionModule(model.cfg.perception).train() + + +def set_model_dict_for_partial_init( + pretrained_dict: Dict[str, torch.Tensor], model_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Partially initialize a model's state dictionary with a pretrained state dictionary. 
+ This function safely copies compatible layers from a pretrained model into a new model, + ignoring layers with mismatched shapes or missing keys. + + Steps: + 1. Remove layers from the pretrained dictionary if their shape does not match the target model. + 2. Keep only keys that exist in the target model. + 3. Update the model dictionary with the filtered pretrained weights. + + Args: + pretrained_dict (Dict[str, torch.Tensor]): + The state dictionary of the pretrained model. + model_dict (Dict[str, torch.Tensor]): + The state dictionary of the target model to be partially initialized. + + Returns: + Dict[str, torch.Tensor]: + The updated model state dictionary with compatible layers loaded from the pretrained dictionary. + + Example: + >>> model_dict = model.state_dict() + >>> pretrained_dict = load_checkpoint("pretrained_model.ckpt") + >>> model_dict = set_model_dict_for_partial_init(pretrained_dict, model_dict) + >>> model.load_state_dict(model_dict) + """ + # 1. Remove layers where pretrained shape differs from model shape + for k, v in list(pretrained_dict.items()): + if k in model_dict and hasattr(model_dict[k], "numel") and v.numel() != model_dict[k].numel(): + del pretrained_dict[k] + logging.info(f" | > Layer with shape mismatch in the model definition: {k}") + + # 2. Keep only keys that exist in the target model + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + + # 3. Update model dictionary with filtered pretrained layers + model_dict.update(pretrained_dict) + logging.info(f" | > {len(pretrained_dict)} / {len(model_dict)} layers are restored.") + + return model_dict + + +def load_checkpoint(checkpoint_path): + """ + Load a model checkpoint from disk. + + Supports loading checkpoints stored in either PyTorch (`.ckpt`, `.pt`) or + SafeTensors (`.safetensors`) formats. All parameters are loaded onto CPU + regardless of the original device. + + Args: + checkpoint_path (str): + Path to the checkpoint file. If the filename contains `.safetensors`, + it is loaded using the SafeTensors backend; otherwise, it is assumed + to be a PyTorch checkpoint containing a `state_dict` field. + + Returns: + dict: + A state dictionary mapping parameter names to tensors. + """ + if ".safetensors" in checkpoint_path: + checkpoint_state = load_file(checkpoint_path, device="cpu") + else: + checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location="cpu")["state_dict"] + return checkpoint_state diff --git a/tests/collections/speechlm2/test_duplex_eartts.py b/tests/collections/speechlm2/test_duplex_eartts.py new file mode 100644 index 000000000000..2a56d2afe996 --- /dev/null +++ b/tests/collections/speechlm2/test_duplex_eartts.py @@ -0,0 +1,581 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import pytest +import torch +from lhotse import CutSet, SupervisionSegment +from lhotse.testing.dummies import dummy_cut, dummy_recording + +from nemo.collections.common.data.utils import move_data_to_device +from nemo.collections.speechlm2.data.duplex_ear_tts_dataset import ( + DuplexEARTTSDataset, + add_speech_delay, + sample_audio_segments_repeat, +) +from nemo.collections.speechlm2.models import DuplexEARTTS + + +if torch.cuda.is_available(): + torch.set_default_device('cuda') + + +test_eartts_config = { + "model": { + "pretrained_lm_name": "nvidia/NVIDIA-Nemotron-Nano-9B-v2", + "pretrained_ae_dir": None, + "pretrained_tts_model": None, + "scoring_asr": "stt_en_fastconformer_transducer_large", + "freeze_params": [ + r"^audio_codec\..+$", # Keep audio codec frozen as it only provides supervision for training. + r"^embed_tokens\..+$", # Keep embed_tokens frozen as done in eartts + ], + "bos_token": "", + "eos_token": "", + "pad_token": "", + "audio_codec_run_dtype": "float32", + "prevent_freeze_params": [], + "audio_save_path": "", + "inference_guidance_scale": 0.5, + "inference_noise_scale": 0.8, + "inference_top_p_or_k": 0.8, + "inference_guidance_enabled": False, + "subword_mask_exactly_as_eartts": False, + "context_hidden_mask_exactly_as_eartts": False, + "optimizer": { + "_target_": "torch.optim.AdamW", + "lr": 4e-5, + "betas": [0.9, 0.98], + "weight_decay": 0, + "foreach": True, + }, + "lr_scheduler": { + "_target_": "nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing", + "warmup_steps": 2500, + "min_lr": 1e-6, + "max_steps": 100_000_000, + }, + "codec_config": { + "latent_size": 512, + "n_fft": 16, + "hop_length": 4, + "base_hidden_size": 384, + "channel_mult": [1, 2, 4], + "rates": [7, 7, 9], + "num_blocks": 3, + "kernel_size": 7, + "groups": 1, + "codebook_size": 1024, + "num_quantizers": 31, + "wav_to_token_ratio": 1764, + }, + "tts_config": { + "use_gated_fusion_for_text_audio": True, + "disable_eos_prediction": True, + "use_bos_eos_emb": True, + "use_subword_flag_emb": True, + "num_delay_speech_tokens": 2, + "backbone_type": "gemma3_text", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "attention_dropout": 0.1, + "use_cache": False, + }, + "latent_size": 512, + "codebook_size": 1024, + "num_quantizers": 31, + "context_hidden_size": None, + "cas_config": { + "backbone_type": "t5gemma", + "backbone_model_class": None, + "backbone_config_class": None, + "backbone_config": { + "is_encoder_decoder": False, + "encoder": { + "hidden_size": 1152, + "intermediate_size": 4608, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "num_key_value_heads": 16, + "head_dim": 72, + "use_cache": False, + "attention_dropout": 0.1, + }, + }, + }, + "mog_head_config": { + "intermediate_size": 4608, + "num_layers": 3, + "low_rank": 64, + "num_predictions": 1024, + "min_log_std": -4.0, + "eps": 1e-6, + }, + "p_uncond": 0.1, + "label_smoothing": 0.01, + "max_training_rate": 0.8, + "quantizer_dropout": 0.5, + "random_target_masking": False, + "exponent": 3.0, + }, + }, + "trainer": { + "devices": -1, + "accelerator": "gpu", + "num_nodes": 1, + "precision": 32, + "logger": False, + "enable_checkpointing": False, + "use_distributed_sampler": False, + "max_steps": 100_000_000, + "val_check_interval": 1000, + "limit_train_batches": "${trainer.val_check_interval}", + "limit_val_batches": 
2, + "log_every_n_steps": 20, + "num_sanity_val_steps": 0, + "gradient_clip_val": 1.0, + "accumulate_grad_batches": 1, + "strategy": { + "_target_": "lightning.pytorch.strategies.DDPStrategy", + "gradient_as_bucket_view": True, + "find_unused_parameters": True, + }, + }, + "data": { + "add_text_bos_and_eos_in_each_turn": True, + "add_audio_prompt": True, + "audio_prompt_duration": 3.0, + "frame_length": 0.08, + "source_sample_rate": 22050, + "target_sample_rate": 22050, + "input_roles": ["user", "User"], + "output_roles": ["agent", "Assistant", "assistant", "Agent"], + }, + "exp_manager": { + "exp_dir": None, + "explicit_log_dir": "", + "name": "eartts", + "create_tensorboard_logger": False, + "create_checkpoint_callback": True, + "use_datetime_version": True, + "max_time_per_run": "00:03:50:00", + "resume_from_checkpoint": None, + "resume_if_exists": True, + "resume_ignore_no_checkpoint": True, + "create_wandb_logger": True, + "wandb_logger_kwargs": { + "name": "duplex_eartts_test", + "project": "duplex_eartts", + "resume": True, + }, + }, +} + +# set CI cached path +if os.path.exists("/home/TestData/"): + test_eartts_config["model"]["pretrained_lm_name"] = "/home/TestData/nvidia--NVIDIA-Nemotron-Nano-9B-v2/" + + +@pytest.fixture(scope="session") +def model(): + model = DuplexEARTTS(test_eartts_config) + if torch.cuda.is_available(): + model.to("cuda") + return model + + +@pytest.fixture(scope="session") +def dataset(model): + return DuplexEARTTSDataset( + model.tokenizer, + add_text_bos_and_eos_in_each_turn=True, + add_audio_prompt=True, + audio_prompt_duration=3.0, + frame_length=0.08, + source_sample_rate=22050, + target_sample_rate=22050, + input_roles=["user", "User"], + output_roles=["agent", "Assistant", "assistant", "Agent"], + ) + + +@pytest.fixture(scope="session") +def training_cutset_batch(): + cut = dummy_cut(0, recording=dummy_recording(0, with_data=True, duration=1.0, sampling_rate=22050)) + cut.target_audio = dummy_recording(1, with_data=True, duration=1.0, sampling_rate=22050) + cut.supervisions = [ + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=0.1, + text='hi', + speaker="user", + ), + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0.3, + duration=0.1, + text='hello', + speaker="assistant", + ), + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0.5, + duration=0.1, + text='ok', + speaker="user", + ), + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0.6, + duration=0.1, + text='okay', + speaker="assistant", + ), + ] + return CutSet([cut]) + + +def test_eartts_dataset(dataset, training_cutset_batch): + batch = dataset[training_cutset_batch] + expected_keys = { + "sample_id", + "non_prompt_mask", + "prompt_lens", + "aligned_attention_mask", + "aligned_position_ids", + "source_audio", + "source_audio_lens", + "target_audio", + "target_audio_lens", + "target_text_tokens", + "target_token_lens", + "source_tokens", + "source_token_lens", + "target_texts", + "audio_prompt", + "audio_prompt_lens", + "formatter", + } + + for key in expected_keys: + assert key in batch, f"Missing key: {key}" + + tensor_keys = [ + "non_prompt_mask", + "aligned_attention_mask", + "aligned_position_ids", + "source_audio", + "source_audio_lens", + "target_audio", + "target_audio_lens", + "target_text_tokens", + "target_token_lens", + "source_tokens", + "source_token_lens", + "audio_prompt", + "audio_prompt_lens", + ] + + for key in tensor_keys: + assert 
torch.is_tensor(batch[key]), f"{key} must be a tensor" + + # Check target text consistency + assert batch["target_texts"] == ["hello okay"] + assert batch["source_tokens"].tolist() == [ + [ + 2, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 2, + 1, + 2, + 12, + 12, + 12, + 12, + 1, + 1662, + 2, + 12, + 12, + 12, + 12, + ] + ] + + assert batch["target_text_tokens"].tolist() == [ + [ + 2, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 2, + 12, + 12, + 12, + 12, + 1, + 2, + 12, + 12, + 1, + 2, + 1417, + 12, + 12, + ] + ] + + # Check formatter + assert batch["formatter"] == ["s2s_duplex"] + + +# test extra functions inside of eartts dataset +def test_add_speech_delay(): + source_audio = torch.ones(1, 16000) + target_audio = torch.ones(1, 22050) + + source_lens = torch.tensor([16000]) + target_lens = torch.tensor([22050]) + + num_delays = 2 + + # samples per frame (float → int handled explicitly) + target_samples_per_frame = source_audio.size(1) / 12.5 + source_samples_per_frame = target_audio.size(1) / 12.5 + + expected_extra_src_size = int(source_samples_per_frame * num_delays) + expected_extra_tgt_size = int(target_samples_per_frame * num_delays) + + out_src, out_src_lens, out_tgt, out_tgt_lens = add_speech_delay( + source_audio=source_audio, + source_audio_lens=source_lens, + target_audio=target_audio, + target_audio_lens=target_lens, + num_delay_speech_tokens=num_delays, + target_samples_per_frame=target_samples_per_frame, + source_samples_per_frame=source_samples_per_frame, + ) + + # -------------------------------------------------- + # Shape & length bookkeeping + # -------------------------------------------------- + assert out_src.shape == (1, source_audio.size(1) + expected_extra_src_size) + assert out_tgt.shape == (1, target_audio.size(1) + expected_extra_tgt_size) + assert out_src_lens.item() == source_lens.item() + expected_extra_src_size + assert out_tgt_lens.item() == target_lens.item() + expected_extra_tgt_size + + # -------------------------------------------------- + # Padding direction & content + # -------------------------------------------------- + # Target audio is left-padded + assert torch.all(out_tgt[:, :expected_extra_tgt_size] == 0) + assert torch.all(out_tgt[:, expected_extra_tgt_size:] == 1) + + # Source audio is right-padded + assert torch.all(out_src[:, : source_audio.size(1)] == 1) + assert torch.all(out_src[:, source_audio.size(1) :] == 0) + + +def test_sample_audio_segments_repeat(): + cases = [ + # (audio, lens, n_sample, expected_when_sample_false) + ( + torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]), + torch.tensor([5]), + 3, + torch.tensor([[1.0, 2.0, 3.0]]), + ), + ( + torch.tensor([[1.0, 2.0]]), + torch.tensor([2]), + 5, + torch.tensor([[1.0, 2.0, 1.0, 2.0, 1.0]]), + ), + ( + torch.zeros(1, 10), + torch.tensor([0]), + 4, + torch.zeros(1, 4), + ), + ] + + for prompt_audio, prompt_audio_lens, n_sample, expected in cases: + # -------------------------------------------------- + # sample=False → deterministic + sequence check + # -------------------------------------------------- + out = sample_audio_segments_repeat( + prompt_audio, + prompt_audio_lens, + n_sample=n_sample, + sample=False, + ) + + assert out.shape == 
expected.shape + assert torch.equal(out, expected) + + # -------------------------------------------------- + # sample=True → stochastic, shape only + # -------------------------------------------------- + out = sample_audio_segments_repeat( + prompt_audio, + prompt_audio_lens, + n_sample=n_sample, + sample=True, + ) + + assert out.shape == expected.shape + + +def test_eartts_training_step(model, dataset, training_cutset_batch): + model.train() + model.on_train_epoch_start() + batch = dataset[training_cutset_batch] + batch = move_data_to_device(batch, device=model.device) + results = model.training_step(batch, batch_idx=0) + assert torch.is_tensor(results["loss"]) + assert not torch.isnan(results["loss"]) + assert results["loss"] > 0 + + +def test_eartts_validation_step(model, dataset, training_cutset_batch): + model.eval() + model.on_validation_epoch_start() + batch = dataset[training_cutset_batch] + batch = move_data_to_device(batch, device=model.device) + results = model.validation_step({"dummy_val_set": batch}, batch_idx=0) + assert results is None # no return value + + +def test_eartts_offline_generation(model): + model.eval() + # generate random subword_ids + subword_ids = torch.ones(2, 10).long() + + # set init inputs and get it + model.set_init_inputs( + speaker_audio=torch.randn(1, 22050), + speaker_audio_lens=torch.tensor([22050]), + ) + init_inputs = model.get_init_inputs(B=subword_ids.size(0)) + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + + gen_audio, gen_audio_len = model.offline_inference( + next_subword_ids=subword_ids, + init_inputs=init_inputs, + ) + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + + gen_audio_inc, gen_audio_len_inc = model.offline_inference( + next_subword_ids=subword_ids, init_inputs=init_inputs, incremental_audio_decoding=True + ) + + assert torch.equal( + gen_audio_len, gen_audio_len_inc + ), "Audio lengths differ between incremental and non-incremental decoding." + + # compare waveform + torch.testing.assert_close( + gen_audio, + gen_audio_inc, + atol=1e-1, + rtol=0, + ) + + assert gen_audio.shape == (2, 17640) + assert gen_audio_len[0] == gen_audio.size(-1) + assert gen_audio.dtype == torch.float32 diff --git a/tests/collections/speechlm2/test_metrics.py b/tests/collections/speechlm2/test_metrics.py index 4972a7418fa2..9be859545587 100644 --- a/tests/collections/speechlm2/test_metrics.py +++ b/tests/collections/speechlm2/test_metrics.py @@ -13,7 +13,7 @@ # limitations under the License. import pytest -from nemo.collections.speechlm2.parts.metrics import BLEU, WER +from nemo.collections.speechlm2.parts.metrics import BLEU, WER, Intelligibility def test_bleu(): @@ -50,3 +50,27 @@ def test_wer(): assert ans["wer_dataset_1"] == 0.0 assert ans["wer_dataset_2"] == 1 / 3 assert ans["wer"] == 1 / 6 # average across datasets + + +def test_intelligibility(): + metric = Intelligibility(pretrained_asr=None, verbose=False, reuse_asr_hyps=True) + metric.update( + name="dataset_1", + refs=["a b c d e f g h i j k l", "m n o p r s t u v"], + asr_hyps=["a b c d e f g h i j k l", "m n o p r s t u v"], + pred_audio=None, + ) + metric.update( + name="dataset_2", + refs=["a b c"], + asr_hyps=["a b d"], + pred_audio=None, + ) + ans = metric.compute() + # wer + assert ans["wer_dataset_1"] == 0.0 + assert ans["wer_dataset_2"] == 1 / 3 + assert ans["wer"] == 1 / 6 # average across datasets + # cer + assert ans["cer_dataset_1"] == 0.0 + assert ans["cer_dataset_2"] == 1 / 5
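+    # cer_dataset_2: "a b c" vs "a b d" is 1 substituted character out of 5 (spaces count as characters)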