Merged
102 commits
d48bd1c
Add Duplex EARTTS modules
Edresson Nov 11, 2025
e3de872
Add EARTTS codec and extra missing modules
Edresson Nov 11, 2025
3742e35
Apply isort and black reformatting
Edresson Nov 11, 2025
326afc3
Add set_init_inputs and get_init_input methods
Edresson Nov 12, 2025
34e0413
Apply isort and black reformatting
Edresson Nov 12, 2025
ec72055
Add from-config codec instantiation
Edresson Nov 13, 2025
3df84cc
Apply isort and black reformatting
Edresson Nov 13, 2025
81513bf
Remove unused imports
Edresson Nov 13, 2025
fc47181
Apply isort and black reformatting
Edresson Nov 13, 2025
fd9893f
Fix pylint issues
Edresson Nov 13, 2025
d4b61f0
Fix code scanning issues
Edresson Nov 13, 2025
2d10b4e
Add missing copyright
Edresson Nov 13, 2025
8dc232b
Apply isort and black reformatting
Edresson Nov 13, 2025
626b328
Fix code scanning issues
Edresson Nov 13, 2025
8929736
Fix merge issues
Edresson Nov 13, 2025
2f6ae33
Apply isort and black reformatting
Edresson Nov 13, 2025
3297d0c
Remove EARTTS configs and use DictConfig directly
Edresson Nov 13, 2025
c58ebd0
Apply isort and black reformatting
Edresson Nov 13, 2025
13a3998
Update
Edresson Nov 13, 2025
9e8a721
Implement EARTTS unit tests
Edresson Nov 14, 2025
1c6af8d
Apply isort and black reformatting
Edresson Nov 14, 2025
ea050e8
Add incremental decoding and unit test for it
Edresson Nov 14, 2025
814bc9e
Apply isort and black reformatting
Edresson Nov 14, 2025
5c4f33b
Add option to run codec in bf16 to speed up
Edresson Nov 14, 2025
e1cf144
Add docs
Edresson Nov 14, 2025
b30bbdb
Apply isort and black reformatting
Edresson Nov 14, 2025
433823e
Rename codec context manager precision function
Edresson Nov 14, 2025
d9e74ab
Rename init_model_from_another_checkpoint to restore_from_pretrained_…
Edresson Nov 14, 2025
89fcb32
Replace torchaudio with librosa
Edresson Nov 16, 2025
10c95ff
Apply isort and black reformatting
Edresson Nov 16, 2025
4a451d0
Replace torchaudio with librosa on codec
Edresson Nov 17, 2025
ff662ed
Apply isort and black reformatting
Edresson Nov 17, 2025
36b09f8
Add sensitive_layers parameter
Edresson Nov 17, 2025
54b5418
Apply isort and black reformatting
Edresson Nov 17, 2025
71be84a
Remove torchaudio from metrics
Edresson Nov 17, 2025
65c084d
Update lhotse formatters
Edresson Nov 17, 2025
4272968
Apply isort and black reformatting
Edresson Nov 17, 2025
7b71f41
Docs for set_model_dict_for_partial_init
Edresson Nov 17, 2025
e6d5977
Apply isort and black reformatting
Edresson Nov 17, 2025
41ce58a
Make codec sil tokens a buffer
Edresson Nov 17, 2025
2aa4e3b
Apply isort and black reformatting
Edresson Nov 17, 2025
1c3d1a6
Fix lint
Edresson Nov 17, 2025
fbf853f
Add CER/WER metrics unit test
Edresson Nov 17, 2025
4fc268d
Apply isort and black reformatting
Edresson Nov 17, 2025
a124eab
Disable Triton if CUDA is not available
Edresson Nov 17, 2025
31f13f9
Apply isort and black reformatting
Edresson Nov 17, 2025
d578125
Update codec run dtype
Edresson Nov 17, 2025
853d948
Add EOS dropout
Edresson Nov 21, 2025
53569ad
Apply isort and black reformatting
Edresson Nov 21, 2025
5452175
Fix Bleu
Edresson Nov 24, 2025
db2bdf9
Clean up useless comments in tests
Edresson Nov 24, 2025
a802460
Rename EARTTS files
Edresson Nov 24, 2025
da2cb92
Update EARTTS dataset docs
Edresson Nov 25, 2025
8ca4db4
Remove mixed precision fns
Edresson Nov 25, 2025
45659e4
Remove system prompt
Edresson Nov 25, 2025
544eee0
Apply isort and black reformatting
Edresson Nov 25, 2025
d32f4ef
Remove custom test sentence inference logic from the model
Edresson Nov 25, 2025
e12b192
Remove .nemo file loading support
Edresson Nov 25, 2025
1cfd4f3
Rename speaker_reference to audio_prompt
Edresson Nov 25, 2025
2caec09
Modularize dataloader get_item
Edresson Nov 25, 2025
c7d7440
Apply isort and black reformatting
Edresson Nov 25, 2025
96c3ddb
Update docs
Edresson Nov 26, 2025
f9e1066
Update EARTTS documentation
Edresson Nov 26, 2025
12625bc
Rename eval script
Edresson Nov 26, 2025
4c89963
Apply isort and black reformatting
Edresson Nov 26, 2025
0d67a58
Rename intelligibility metric
Edresson Nov 26, 2025
ddda743
Apply isort and black reformatting
Edresson Nov 26, 2025
9e2b376
Add WER metric back
Edresson Nov 26, 2025
6951357
Apply isort and black reformatting
Edresson Nov 26, 2025
fa33c5b
Do not share small embeddings
Edresson Nov 26, 2025
9142d70
Remove unused imports
Edresson Nov 26, 2025
fd517a0
Reuse TTS get_mask_from_lengths
Edresson Nov 27, 2025
dc1901e
Apply isort and black reformatting
Edresson Nov 27, 2025
10df97f
Update rescale_state_dict
Edresson Nov 27, 2025
8498ea7
Add missing dataset docs
Edresson Nov 27, 2025
b887a77
Remove Pretrained class and ear_tts_commons.py
Edresson Nov 27, 2025
926e179
Apply isort and black reformatting
Edresson Nov 27, 2025
be7f689
Add top-level comment on Logger
Edresson Nov 27, 2025
912ef24
Apply isort and black reformatting
Edresson Nov 27, 2025
7ef91db
Refactor checkpoint loading
Edresson Nov 28, 2025
be12a72
Apply isort and black reformatting
Edresson Nov 28, 2025
7e7e77b
Add EOS dropout and duplication
Edresson Dec 12, 2025
8beaf8b
Apply isort and black reformatting
Edresson Dec 12, 2025
bc3eb8c
Remove debug dataloader code
Edresson Dec 12, 2025
b71490a
Remove unnecessary operations in n_samples_per_frame computation
Edresson Dec 15, 2025
ca277ab
Ignore sample on sample_audio_segments_repeat when audio is shorter
Edresson Dec 16, 2025
bbdf250
Move data utils to the end of the file
Edresson Dec 17, 2025
6b9966c
Remove duplicated code
Edresson Dec 17, 2025
5efbadf
Rename input_text_tokens to target_text_tokens
Edresson Dec 17, 2025
d316cc9
Move the bos/eos/pad token definition to AutoTokenizer and reuse the …
Edresson Dec 17, 2025
b14ef66
Apply isort and black reformatting
chtruong814 Dec 17, 2025
8aa65ef
Add docs and config for duplex EARTTS evaluation and set bos eos pad …
Edresson Dec 17, 2025
3ca5e28
Update docs
Edresson Dec 17, 2025
dfa7f16
Add extra unit tests
Edresson Dec 17, 2025
6460486
Apply isort and black reformatting
Edresson Dec 17, 2025
19705d1
Use CI cached path for Duplex EARTTS tests
Edresson Dec 20, 2025
edf8e14
Apply isort and black reformatting
Edresson Dec 20, 2025
421d15e
Fix eartts tests
Edresson Jan 5, 2026
f55fad0
Update CI nanov2 path
Edresson Jan 5, 2026
e204d7c
Apply isort and black reformatting
Edresson Jan 5, 2026
b51de02
Fix eartts dataset unittest
Edresson Jan 6, 2026
befea32
Fix unit test
Edresson Jan 7, 2026
1 change: 1 addition & 0 deletions docs/source/speechlm2/datasets.rst
@@ -11,6 +11,7 @@ Duplex S2S models use the Lhotse framework for audio data management. The primar

1. **DuplexS2SDataset**: For general duplex speech-to-speech models
2. **SALMDataset**: Specifically for the Speech-Augmented Language Model (SALM), which processes speech+text and outputs text.
3. **DuplexEARTTSDataset**: Dataset for the Duplex EARTTS model, extending DuplexS2SDataset with additional output fields for TTS, including audio prompting. It optionally prepends an audio prompt (speaker reference) to target_audio, which initializes speaker conditioning in the EARTTS model. The dataset provides audio_prompt, audio_prompt_lens, non_prompt_mask, aligned_attention_mask, and aligned_position_ids, and supports custom speaker-reference audio through the context_audio field, while preserving full compatibility with the original data format (see the construction sketch below).

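For illustration, a DuplexEARTTSDataset can be constructed as in the evaluation script added by this PR. This is a minimal sketch: the argument values mirror examples/speechlm2/conf/duplex_eartts.yaml, and ``model`` is assumed to be an already loaded DuplexEARTTS.

from nemo.collections.speechlm2 import DuplexEARTTSDataset

dataset = DuplexEARTTSDataset(
    tokenizer=model.tokenizer,              # tokenizer of the loaded DuplexEARTTS model
    frame_length=0.08,                      # seconds of audio per frame
    source_sample_rate=22050,
    target_sample_rate=22050,
    input_roles=["user", "User"],
    output_roles=["agent", "Assistant", "assistant", "Agent"],
    add_text_bos_and_eos_in_each_turn=True,
    add_audio_prompt=True,                  # prepend the speaker reference to target_audio
    audio_prompt_duration=3.0,              # seconds
    num_delay_speech_tokens=2,
)
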
DuplexS2S Dataset Structure
^^^^^^^^^^^^^^^^^^^^^^^^^^^
28 changes: 28 additions & 0 deletions docs/source/speechlm2/models.rst
@@ -8,6 +8,30 @@ Core Model Architectures

The collection includes the following core model architectures:


DuplexEARTTS
^^^^^^^^^^^^

DuplexEARTTS is a streaming text-to-speech model designed for duplex speech-to-speech systems. It focuses on low-latency, fully streamable speech generation by converting text tokens into audio representations in real time.

The architecture is based on the Streaming TTS model proposed in `Audio Flamingo 3 <https://arxiv.org/abs/2507.08128>`_, with several extensions for duplex interaction:

* **Gated fusion of text and audio representations** (`GatedProjectedSumRMSNorm`), enabling better multimodal integration (see the sketch below).
* **Subword-aware embeddings** (`SubwordFlagEmbedding`) to improve pronunciation of words composed of multiple text tokens.
* **Custom BOS/EOS embeddings** (`BOSEOSEmbedding`) for interruption-aware, multi-turn duplex generation.

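As an intuition for the gated fusion extension, the sketch below is a hypothetical PyTorch rendition: only the module name `GatedProjectedSumRMSNorm` comes from this PR, and the actual implementation may differ.

import torch
import torch.nn as nn

class GatedProjectedSumRMSNorm(nn.Module):
    # Hypothetical sketch: project the text states, gate them against the
    # audio states, sum, and RMS-normalize. Not the PR's actual code.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.gate = nn.Linear(2 * hidden_size, hidden_size)
        self.norm = nn.RMSNorm(hidden_size)

    def forward(self, text_h: torch.Tensor, audio_h: torch.Tensor) -> torch.Tensor:
        # Element-wise gate decides how much projected text flows into the audio stream.
        g = torch.sigmoid(self.gate(torch.cat([text_h, audio_h], dim=-1)))
        return self.norm(audio_h + g * self.proj(text_h))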

Key components:

* **RVQVAEModel**: An RVQ-based neural audio codec that compresses speech into discrete acoustic tokens using a convolutional encoder and reconstructs high-quality audio via a convolutional decoder.
* **RVQEARTTSModel**: A streaming speech generation model that predicts multiple RVQ codebooks in parallel using a Mixture-of-Gaussians (MoG) prediction head. It produces audio tokens autoregressively from text representations with minimal latency (a generic sampling sketch follows below).

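For intuition about the MoG head, the following is a generic Mixture-of-Gaussians sampling sketch, not this PR's implementation (whose head is configured under ``tts_config.mog_head_config``); ``noise_scale`` here plays the role of ``inference_noise_scale`` in the example config.

import torch

def sample_mog(logits, means, log_std, noise_scale=0.8):
    # Generic MoG sampling, illustrative only.
    # logits: (B, K) mixture weights; means/log_std: (B, K, D) per-component stats.
    comp = torch.distributions.Categorical(logits=logits).sample()  # choose a component, (B,)
    idx = comp.view(-1, 1, 1).expand(-1, 1, means.size(-1))         # (B, 1, D) gather index
    mu = means.gather(1, idx).squeeze(1)                            # (B, D) selected means
    std = log_std.gather(1, idx).squeeze(1).exp()                   # (B, D) selected stds
    return mu + noise_scale * std * torch.randn_like(mu)            # sampled latent frame
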
DuplexEARTTS is particularly useful for:

* Duplex speech-to-speech systems requiring interruption-aware synthesis.
* Low-latency text-to-speech generation.
* Real-time conversational agents with streamed audio output.


SALM (Speech-Augmented Language Model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -78,6 +102,7 @@ Speech generation components convert text or token representations back into spe

1. **TransformerARSpeechDecoder**: An autoregressive transformer-based speech decoder
2. **Audio Codec Integration**: Works with audio codecs to generate natural speech from discrete tokens
3. **DuplexEARTTS**: A ready-to-use duplex text-to-speech model that supports user interruption via a special text interruption token. The model integrates an RVQ-based audio codec with a streaming speech generation module to enable low-latency, real-time synthesis.

Implementation Details
----------------------
@@ -200,6 +225,9 @@ All models in the speechlm2 collection can be instantiated from pretrained check
# Load DuplexS2SSpeechDecoderModel
decoder_model = slm.models.DuplexS2SSpeechDecoderModel.from_pretrained("path/to/checkpoint")

# Load DuplexEARTTS
eartts_model = slm.models.DuplexEARTTS.from_pretrained("path/to/checkpoint")

Model Configuration
-------------------

200 changes: 200 additions & 0 deletions examples/speechlm2/conf/duplex_eartts.yaml
@@ -0,0 +1,200 @@
model:
  pretrained_lm_name: "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
  pretrained_audio_codec: ??? # to be released
  pretrained_tts_model: null
  scoring_asr: stt_en_fastconformer_transducer_large # used only in validation/evaluation

  # Regexp (re.compile) patterns matching parameters to be frozen.
  freeze_params:
    - "^audio_codec\\..+$" # Keep audio codec frozen as it only provides supervision for training.
    - "^embed_tokens\\..+$" # Keep embed_tokens frozen as done in eartts

  prevent_freeze_params: [] # Use to make specific submodules trainable; overrides freeze_params

  # set custom text bos/eos/pad tokens
  bos_token: "<s>"
  eos_token: "</s>"
  pad_token: "<SPECIAL_12>"

  # inference params
  inference_guidance_scale: 0.5
  inference_noise_scale: 0.8
  inference_top_p_or_k: 0.8
  inference_guidance_enabled: true

  optimizer:
    _target_: torch.optim.AdamW
    lr: 4e-05
    betas: [0.9, 0.98]
    weight_decay: 0
    foreach: true # set to false if having issues with tensor-parallelism

  lr_scheduler:
    _target_: nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing
    warmup_steps: 2500
    min_lr: 1e-6
    max_steps: ${trainer.max_steps}

  codec_config:
    latent_size: 512
    n_fft: 16
    hop_length: 4
    base_hidden_size: 384
    channel_mult:
      - 1
      - 2
      - 4
    rates:
      - 7
      - 7
      - 9
    num_blocks: 3
    kernel_size: 7
    groups: 1
    codebook_size: 1024
    num_quantizers: 31
    wav_to_token_ratio: 1764
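    # Consistency note: prod(rates) * hop_length = 7 * 7 * 9 * 4 = 1764 = wav_to_token_ratio,
    # i.e., one codec token covers 1764 samples, which at 22050 Hz is 0.08 s (data.frame_length).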

  tts_config:
    # extra configs added
    use_gated_fusion_for_text_audio: true
    disable_eos_prediction: true # disable eos prediction
    use_bos_eos_emb: true
    use_subword_flag_emb: true
    num_delay_speech_tokens: 2
    # EAR-TTS configs
    backbone_type: gemma3_text
    backbone_model_class: null
    backbone_config_class: null
    backbone_config:
      hidden_size: 1152
      intermediate_size: 4608
      num_hidden_layers: 28
      num_attention_heads: 16
      num_key_value_heads: 16
      head_dim: 72
      attention_dropout: 0.1
      use_cache: false
    latent_size: 512
    codebook_size: 1024
    num_quantizers: 31
    context_hidden_size: null
    cas_config:
      backbone_type: t5gemma
      backbone_model_class: null
      backbone_config_class: null
      backbone_config:
        is_encoder_decoder: false
        encoder:
          hidden_size: 1152
          intermediate_size: 4608
          num_hidden_layers: 1
          num_attention_heads: 16
          num_key_value_heads: 16
          head_dim: 72
          use_cache: false
          attention_dropout: 0.1
    mog_head_config:
      intermediate_size: 4608
      num_layers: 3
      low_rank: 64
      num_predictions: 1024
      min_log_std: -4.0
      eps: 1e-06
    p_uncond: 0.1
    label_smoothing: 0.01
    max_training_rate: 0.8
    quantizer_dropout: 0.5
    random_target_masking: false
    exponent: 3.0

trainer:
  devices: -1
  accelerator: gpu
  num_nodes: 1
  precision: 32
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  max_steps: 1000000
  val_check_interval: 2000
  limit_train_batches: ${trainer.val_check_interval} # an "epoch"
  limit_val_batches: 2
  log_every_n_steps: 20
  num_sanity_val_steps: 0
  gradient_clip_val: 1.0
  accumulate_grad_batches: 1
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    gradient_as_bucket_view: true
    find_unused_parameters: true

data:
  # data loader configs
  add_text_bos_and_eos_in_each_turn: true
  add_audio_prompt_after_description: true
  audio_prompt_duration: 3.0
  frame_length: 0.08
  source_sample_rate: 22050
  target_sample_rate: 22050
  input_roles: ["user", "User"]
  output_roles: ["agent", "Assistant", "assistant", "Agent"]

  train_ds:
    sample_rate: ${data.target_sample_rate}
    input_cfg:
      - type: lhotse_shar
        shar_path: ???
    seed: 42
    shard_seed: "randomized"
    num_workers: 2
    batch_size: 4
    # Optional bucketing:
    # batch_size: null
    # batch_duration: 100
    # bucket_duration_bins: [8.94766,10.1551,11.64118,19.30376,42.85]
    # use_bucketing: true
    # num_buckets: 5
    # bucket_buffer_size: 5000

  validation_ds:
    # The entries under 'datasets' are a list of separate dataloaders.
    # The structure is <dataset-name>: {<dataloader-dict-config>}
    # They inherit all settings from validation_ds, but can individually override them.
    datasets:
      val_set_0: # rename to your dataset name, add more as needed
        shar_path: ???
    sample_rate: ${data.target_sample_rate}
    batch_size: 1
    seed: 42
    shard_seed: "randomized"

exp_manager:
  exp_dir: null
  explicit_log_dir: duplex_eartts_results/
  name: eartts
  create_tensorboard_logger: false
  create_checkpoint_callback: true
  use_datetime_version: true
  max_time_per_run: 00:03:50:00

  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  # you need to set these two to True to continue the training
  resume_if_exists: true
  resume_ignore_no_checkpoint: true

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: development-run
    project: duplex_eartts
    resume: true

  checkpoint_callback_params:
    filename: "{step}"
    monitor: val_asr_bleu
    mode: max
    every_n_train_steps: null
    every_n_epochs: 1
    save_top_k: 1
    always_save_nemo: false
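
The ``???`` fields (pretrained_audio_codec and the shar_path entries) are mandatory and must be filled in before training or evaluation. Below is a minimal sketch of overriding them programmatically with OmegaConf; the paths are hypothetical placeholders.

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/speechlm2/conf/duplex_eartts.yaml")
cfg.model.pretrained_audio_codec = "/path/to/audio_codec.ckpt"  # hypothetical path
cfg.data.train_ds.input_cfg[0].shar_path = "/path/to/train_shar"  # hypothetical path
cfg.data.validation_ds.datasets.val_set_0.shar_path = "/path/to/val_shar"  # hypothetical path
print(OmegaConf.to_yaml(cfg.model.codec_config))  # inspect the codec section
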
101 changes: 101 additions & 0 deletions examples/speechlm2/duplex_eartts_eval.py
@@ -0,0 +1,101 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Evaluation script for Duplex EARTTS models.

This script computes standard speech evaluation metrics for a given Duplex
EARTTS checkpoint, including Word Error Rate (WER), Character Error Rate (CER),
speaker encoder cosine similarity (SECS), and ASR BLEU score.

The configuration file must define a valid ``validation_ds`` based on a Lhotse
dataset using one of the following dataset formats:
- Duplex S2S standard format
- ``s2s_duplex_overlap_as_s2s_duplex``
- ``lhotse_magpietts_data_as_continuation``

During evaluation, the script saves generated audio samples to
``exp_manager.explicit_log_dir`` as specified in the configuration. For each
utterance, the following audio files may be produced:

- Autoregressive inference output (``*.wav``)
- Teacher-forced output (``*_tf.wav``)
- Ground-truth reference audio (``*_gt.wav``)

Args:
    config-path (str): Path to the directory containing the YAML configuration file.
    config-name (str): Name of the YAML configuration file.
    checkpoint_path (str): Path to the Duplex EARTTS checkpoint file.

Usage:
    python duplex_eartts_eval.py \
        --config-path=conf/ \
        --config-name=duplex_eartts.yaml \
        ++checkpoint_path=duplex_eartts_results/duplex_eartts/model.ckpt
"""

import os

import torch
from lightning.pytorch import Trainer
from omegaconf import OmegaConf

from nemo.collections.speechlm2 import DataModule, DuplexEARTTSDataset
from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager
from nemo.utils.trainer_utils import resolve_trainer_cfg

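# NOTE: LOCAL_RANK is set by distributed launchers such as torchrun; this
# module-level call assumes the script is started through one (or that the
# variable is exported manually).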
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


@hydra_runner(config_path="conf", config_name="duplex_eartts")
def inference(cfg):
    OmegaConf.resolve(cfg)
    torch.distributed.init_process_group(backend="nccl")
    torch.set_float32_matmul_precision("medium")
    torch.backends.cudnn.allow_tf32 = True
    trainer = Trainer(**resolve_trainer_cfg(cfg.trainer))
    log_dir = exp_manager(trainer, cfg.get("exp_manager", None))
    OmegaConf.save(cfg, log_dir / "exp_config.yaml")

    with trainer.init_module():
        if cfg.get("checkpoint_path", None):
            model = DuplexEARTTS.load_from_checkpoint(
                cfg.checkpoint_path,
                cfg=OmegaConf.to_container(cfg, resolve=True),
            )
        else:
            raise ValueError("For evaluation, you must provide `cfg.checkpoint_path`.")

    dataset = DuplexEARTTSDataset(
        tokenizer=model.tokenizer,
        frame_length=cfg.data.frame_length,
        source_sample_rate=cfg.data.source_sample_rate,
        target_sample_rate=cfg.data.target_sample_rate,
        input_roles=cfg.data.input_roles,
        output_roles=cfg.data.output_roles,
        add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True),
        add_audio_prompt=cfg.data.get("add_audio_prompt", True),
        audio_prompt_duration=cfg.data.get("audio_prompt_duration", 3),
        num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2),
    )
    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)

    trainer.validate(model, datamodule)


if __name__ == "__main__":
    inference()
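
Note that the script reads ``LOCAL_RANK`` and calls ``torch.distributed.init_process_group`` unconditionally, so it is most easily launched through a distributed launcher such as ``torchrun`` (even for a single process); a plain ``python`` invocation as in the usage example would otherwise need ``LOCAL_RANK`` and the rendezvous variables exported manually.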