Implement Nemotron-VoiceChat Speech Decoder #15066
Merged
102 commits
d48bd1c Add Duplex EARTTS modules (Edresson)
e3de872 Add EARTTS codec and extra missing modules (Edresson)
3742e35 Apply isort and black reformatting (Edresson)
326afc3 Add set_init_inputs and get_init_input methods (Edresson)
34e0413 Apply isort and black reformatting (Edresson)
ec72055 Add from config codec instanciation (Edresson)
3df84cc Apply isort and black reformatting (Edresson)
81513bf Remove unused imports (Edresson)
fc47181 Apply isort and black reformatting (Edresson)
fd9893f Fix pylint issues (Edresson)
d4b61f0 Fix code scanning issues (Edresson)
2d10b4e Add missing copyright (Edresson)
8dc232b Apply isort and black reformatting (Edresson)
626b328 Fix code scanning issues (Edresson)
8929736 Fix merge issues (Edresson)
2f6ae33 Apply isort and black reformatting (Edresson)
3297d0c Remove EARTTS configs and use directly DictConfig (Edresson)
c58ebd0 Apply isort and black reformatting (Edresson)
13a3998 Update (Edresson)
9e8a721 Implement EARTTS unit tests (Edresson)
1c6af8d Apply isort and black reformatting (Edresson)
ea050e8 Add incremental decoding and unit test for it (Edresson)
814bc9e Apply isort and black reformatting (Edresson)
5c4f33b Add option to run codec in bf16 to speedup (Edresson)
e1cf144 Add docs (Edresson)
b30bbdb Apply isort and black reformatting (Edresson)
433823e rename codec context manager precision function (Edresson)
d9e74ab Rename init_model_from_another_checkpoint to restore_from_pretrained_… (Edresson)
89fcb32 Replace torchaudio with librosa (Edresson)
10c95ff Apply isort and black reformatting (Edresson)
4a451d0 Replace torchaudio with librosa on codec (Edresson)
ff662ed Apply isort and black reformatting (Edresson)
36b09f8 Add sensitive_layers parameter (Edresson)
54b5418 Apply isort and black reformatting (Edresson)
71be84a Remove torchaudio from metrics (Edresson)
65c084d Update lhotse formmaters (Edresson)
4272968 Apply isort and black reformatting (Edresson)
7b71f41 Docs for set_model_dict_for_partial_init (Edresson)
e6d5977 Apply isort and black reformatting (Edresson)
41ce58a Make codec sil tokens a buffer (Edresson)
2aa4e3b Apply isort and black reformatting (Edresson)
1c3d1a6 Fix lint (Edresson)
fbf853f Add CER/WER metrics unit test (Edresson)
4fc268d Apply isort and black reformatting (Edresson)
a124eab Disable triton if cuda is not available (Edresson)
31f13f9 Apply isort and black reformatting (Edresson)
d578125 Update codec run dtype (Edresson)
853d948 Add EOS dropout (Edresson)
53569ad Apply isort and black reformatting (Edresson)
5452175 Fix Bleu (Edresson)
db2bdf9 Cleanup tests useless comments (Edresson)
a802460 Rename EARTTS files (Edresson)
da2cb92 Update EARTTS dataset docs (Edresson)
8ca4db4 Remove mixed precision fns (Edresson)
45659e4 Remove system prompt (Edresson)
544eee0 Apply isort and black reformatting (Edresson)
d32f4ef Remove custom test sentence inference logic from the model (Edresson)
e12b192 Remove .nemo file loading support (Edresson)
1cfd4f3 Rename speaker_reference with audio_prompt (Edresson)
2caec09 Modularize dataloader get_item (Edresson)
c7d7440 Apply isort and black reformatting (Edresson)
96c3ddb Update docs (Edresson)
f9e1066 Update EARTTS documentation (Edresson)
12625bc Rename eval script (Edresson)
4c89963 Apply isort and black reformatting (Edresson)
0d67a58 Rename intellibility metric (Edresson)
ddda743 Apply isort and black reformatting (Edresson)
9e2b376 Add WER metric back (Edresson)
6951357 Apply isort and black reformatting (Edresson)
fa33c5b Do not share small embeddings (Edresson)
9142d70 Remove unused imports (Edresson)
fd517a0 Reuse TTS get_mask_from_lengths (Edresson)
dc1901e Apply isort and black reformatting (Edresson)
10df97f Update rescale_state_dict (Edresson)
8498ea7 Add missing dataset docs (Edresson)
b887a77 Remove Pretrained class and ear_tts_commons.py (Edresson)
926e179 Apply isort and black reformatting (Edresson)
be7f689 Add top-level comment on Logger (Edresson)
912ef24 Apply isort and black reformatting (Edresson)
7ef91db Refactor checkpoint loading (Edresson)
be12a72 Apply isort and black reformatting (Edresson)
7e7e77b Add EOS dropout and duplication (Edresson)
8beaf8b Apply isort and black reformatting (Edresson)
bc3eb8c Remove debug dataloader code (Edresson)
b71490a Remove unecessary operations on n samples_per_frame computation (Edresson)
ca277ab Ignore sample on sample_audio_segments_repeat when audio is shorter (Edresson)
bbdf250 Move data utils to the end of the file (Edresson)
6b9966c Remove duplicated code (Edresson)
5efbadf Rename input_text_tokens to target_text_tokens (Edresson)
d316cc9 Move the bos/eos/pad token definition to AutoTokenizer and reuse the … (Edresson)
b14ef66 Apply isort and black reformatting (chtruong814)
8aa65ef Add docs and config for duplex EARTTS evaluation and set bos eos pad … (Edresson)
3ca5e28 Update docs (Edresson)
dfa7f16 Add extra unit tests (Edresson)
6460486 Apply isort and black reformatting (Edresson)
19705d1 Use CI cached path for Duplex EARTTS tests (Edresson)
edf8e14 Apply isort and black reformatting (Edresson)
421d15e Fix eartts tests (Edresson)
f55fad0 Update CI nanov2 path (Edresson)
e204d7c Apply isort and black reformatting (Edresson)
b51de02 Fix eartts dataset unittest (Edresson)
befea32 Fix unit test (Edresson)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,200 @@ | ||
| model: | ||
| pretrained_lm_name: "nvidia/NVIDIA-Nemotron-Nano-9B-v2" | ||
| pretrained_audio_codec: ??? # to be released | ||
| pretrained_tts_model: null | ||
| scoring_asr: stt_en_fastconformer_transducer_large # used only in validation/evaluation | ||
|
|
||
| # Regexp (re.compile) patterns matching parameters to be frozen. | ||
| freeze_params: | ||
| - "^audio_codec\\..+$" # Keep audio codec frozen as it only provides supervision for training. | ||
| - "^embed_tokens\\..+$" # Keep embed_tokens frozen as done in eartts | ||
|
|
||
| prevent_freeze_params: [] # Use to make specific submodules trainable; overrides freeze_params | ||
|
|
||
| # set custom text eos/bos/pad tokens | ||
| bos_token: "<s>" | ||
| eos_token: "</s>" | ||
| pad_token: "<SPECIAL_12>" | ||
|
|
||
| # inference params | ||
| inference_guidance_scale: 0.5 | ||
| inference_noise_scale: 0.8 | ||
| inference_top_p_or_k: 0.8 | ||
| inference_guidance_enabled: true | ||
|
|
||
|
|
||
| optimizer: | ||
| _target_: torch.optim.AdamW | ||
| lr: 4e-05 | ||
| betas: [0.9, 0.98] | ||
| weight_decay: 0 | ||
| foreach: true # set to false if having issues with tensor-parallelism | ||
|
|
||
| lr_scheduler: | ||
| _target_: nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing | ||
| warmup_steps: 2500 | ||
| min_lr: 1e-6 | ||
| max_steps: ${trainer.max_steps} | ||
|
|
||
| codec_config: | ||
| latent_size: 512 | ||
| n_fft: 16 | ||
| hop_length: 4 | ||
| base_hidden_size: 384 | ||
| channel_mult: | ||
| - 1 | ||
| - 2 | ||
| - 4 | ||
| rates: | ||
| - 7 | ||
| - 7 | ||
| - 9 | ||
| num_blocks: 3 | ||
| kernel_size: 7 | ||
| groups: 1 | ||
| codebook_size: 1024 | ||
| num_quantizers: 31 | ||
| wav_to_token_ratio: 1764 | ||
|
|
||
| tts_config: | ||
| # extra configs added | ||
| use_gated_fusion_for_text_audio: true | ||
| disable_eos_prediction: true # disable eos prediction | ||
| use_bos_eos_emb: true | ||
| use_subword_flag_emb: true | ||
| num_delay_speech_tokens: 2 | ||
| # EAR-TTS configs | ||
| backbone_type: gemma3_text | ||
| backbone_model_class: null | ||
| backbone_config_class: null | ||
| backbone_config: | ||
| hidden_size: 1152 | ||
| intermediate_size: 4608 | ||
| num_hidden_layers: 28 | ||
| num_attention_heads: 16 | ||
| num_key_value_heads: 16 | ||
| head_dim: 72 | ||
| attention_dropout: 0.1 | ||
| use_cache: false | ||
| latent_size: 512 | ||
| codebook_size: 1024 | ||
| num_quantizers: 31 | ||
| context_hidden_size: null | ||
| cas_config: | ||
| backbone_type: t5gemma | ||
| backbone_model_class: null | ||
| backbone_config_class: null | ||
| backbone_config: | ||
| is_encoder_decoder: false | ||
| encoder: | ||
| hidden_size: 1152 | ||
| intermediate_size: 4608 | ||
| num_hidden_layers: 1 | ||
| num_attention_heads: 16 | ||
| num_key_value_heads: 16 | ||
| head_dim: 72 | ||
| use_cache: false | ||
| attention_dropout: 0.1 | ||
| mog_head_config: | ||
| intermediate_size: 4608 | ||
| num_layers: 3 | ||
| low_rank: 64 | ||
| num_predictions: 1024 | ||
| min_log_std: -4.0 | ||
| eps: 1e-06 | ||
| p_uncond: 0.1 | ||
| label_smoothing: 0.01 | ||
| max_training_rate: 0.8 | ||
| quantizer_dropout: 0.5 | ||
| random_target_masking: false | ||
| exponent: 3.0 | ||
| trainer: | ||
| devices: -1 | ||
| accelerator: gpu | ||
| num_nodes: 1 | ||
| precision: 32 | ||
| logger: False # logger provided by exp_manager | ||
| enable_checkpointing: False | ||
| use_distributed_sampler: False | ||
| max_steps: 1000000 | ||
| val_check_interval: 2000 | ||
| limit_train_batches: ${trainer.val_check_interval} # an "epoch" | ||
| limit_val_batches: 2 | ||
| log_every_n_steps: 20 | ||
| num_sanity_val_steps: 0 | ||
| gradient_clip_val: 1.0 | ||
| accumulate_grad_batches: 1 | ||
| strategy: | ||
| _target_: lightning.pytorch.strategies.DDPStrategy | ||
| gradient_as_bucket_view: true | ||
| find_unused_parameters: true | ||
|
|
||
| data: | ||
| # data loader configs | ||
| add_text_bos_and_eos_in_each_turn: true | ||
| add_audio_prompt_after_description: true | ||
| audio_prompt_duration: 3.0 | ||
| frame_length: 0.08 | ||
| source_sample_rate: 22050 | ||
| target_sample_rate: 22050 | ||
| input_roles: ["user", "User"] | ||
| output_roles: ["agent", "Assistant", "assistant","Agent"] | ||
|
|
||
| train_ds: | ||
| sample_rate: ${data.target_sample_rate} | ||
| input_cfg: | ||
| - type: lhotse_shar | ||
| shar_path: ??? | ||
| seed: 42 | ||
| shard_seed: "randomized" | ||
| num_workers: 2 | ||
| batch_size: 4 | ||
| # Optional bucketing: | ||
| # batch_size: null | ||
| # batch_duration: 100 | ||
| # bucket_duration_bins: [8.94766,10.1551,11.64118,19.30376,42.85] | ||
| # use_bucketing: true | ||
| # num_buckets: 5 | ||
| # bucket_buffer_size: 5000 | ||
|
|
||
| validation_ds: | ||
| # The entries under 'datasets' are a list of separate dataloaders. | ||
| # The structure is <dataset-name>: {<dataloader-dict-config>} | ||
| # They inherit all settings from validation_ds, but can individually override them. | ||
| datasets: | ||
| val_set_0: # rename to your dataset name, add more as needed | ||
| shar_path: ??? | ||
| sample_rate: ${data.target_sample_rate} | ||
| batch_size: 1 | ||
| seed: 42 | ||
| shard_seed: "randomized" | ||
|
|
||
| exp_manager: | ||
| exp_dir: null | ||
| explicit_log_dir: duplex_eartts_results/ | ||
| name: eartts | ||
| create_tensorboard_logger: false | ||
| create_checkpoint_callback: true | ||
| use_datetime_version: true | ||
| max_time_per_run: 00:03:50:00 | ||
|
|
||
| resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. | ||
| # you need to set these two to True to continue the training | ||
| resume_if_exists: true | ||
| resume_ignore_no_checkpoint: true | ||
|
|
||
| # You may use this section to create a W&B logger | ||
| create_wandb_logger: false | ||
| wandb_logger_kwargs: | ||
| name: development-run | ||
| project: duplex_eartts | ||
| resume: true | ||
|
|
||
| checkpoint_callback_params: | ||
| filename: "{step}" | ||
| monitor: val_asr_bleu | ||
| mode: max | ||
| every_n_train_steps: null | ||
| every_n_epochs: 1 | ||
| save_top_k: 1 | ||
| always_save_nemo: false |
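The `freeze_params` and `prevent_freeze_params` entries in the config above are described as `re.compile` patterns matched against parameter names, with `prevent_freeze_params` overriding `freeze_params`. A minimal sketch of how such matching could behave, assuming patterns are applied with `re.match` to dotted parameter names. The helper `select_frozen` and the sample parameter names are hypothetical and this is not NeMo's implementation:

```python
import re


def select_frozen(param_names, freeze_params, prevent_freeze_params=()):
    """Return the parameter names that would be frozen under the assumed semantics."""
    freeze = [re.compile(p) for p in freeze_params]
    keep_trainable = [re.compile(p) for p in prevent_freeze_params]
    frozen = []
    for name in param_names:
        # Assumption: a prevent_freeze_params match always wins over freeze_params.
        if any(p.match(name) for p in keep_trainable):
            continue
        if any(p.match(name) for p in freeze):
            frozen.append(name)
    return frozen


# Hypothetical parameter names, for illustration only.
names = [
    "audio_codec.encoder.weight",
    "embed_tokens.weight",
    "tts_model.backbone.layers.0.weight",
]

# The double-quoted YAML strings "^audio_codec\\..+$" unescape to these regexes.
patterns = [r"^audio_codec\..+$", r"^embed_tokens\..+$"]

frozen = select_frozen(names, patterns)
frozen_with_override = select_frozen(names, patterns, [r"^audio_codec\.encoder\..+$"])
print(frozen)
print(frozen_with_override)
```

Because each pattern is anchored with `^` and requires a literal dot after the prefix, a parameter named `audio_codec_extra.weight` would not be frozen by the first pattern.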
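Several numbers in the YAML config above appear to be tied together: `hop_length` times the product of `rates` equals `wav_to_token_ratio`, and that ratio divided by `target_sample_rate` equals `data.frame_length`. The check below is only arithmetic on the values shown in the config. That the codec derives its ratio this way is an assumption, not something the diff states:

```python
from fractions import Fraction
from math import prod

# Values copied from codec_config and data in the config above.
hop_length = 4
rates = [7, 7, 9]
wav_to_token_ratio = 1764
target_sample_rate = 22050
frame_length = Fraction("0.08")  # seconds per frame, from data.frame_length

# Assumed relation: total downsampling = hop_length * product of the stage rates.
samples_per_frame = hop_length * prod(rates)

# Exact rational arithmetic avoids floating-point noise in the comparison.
frame_duration = Fraction(samples_per_frame, target_sample_rate)
frames_per_second = Fraction(target_sample_rate, samples_per_frame)

print(samples_per_frame, float(frame_duration), float(frames_per_second))
```

If the sample rate or any of the `rates` were changed, `frame_length` and `wav_to_token_ratio` would presumably need to be updated together to keep these values consistent.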
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Evaluation script for Duplex EARTTS models. | ||
|
|
||
| This script computes standard speech evaluation metrics for a given Duplex | ||
| EARTTS checkpoint, including Word Error Rate (WER), Character Error Rate (CER), | ||
| speaker encoder cosine similarity (SECS), and ASR BLEU score. | ||
|
|
||
| The configuration file must define a valid ``validation_ds`` based on a Lhotse | ||
| dataset using one of the following dataset formats: | ||
| - Duplex S2S standard format | ||
| - ``s2s_duplex_overlap_as_s2s_duplex`` | ||
| - ``lhotse_magpietts_data_as_continuation`` | ||
|
|
||
| During evaluation, the script saves generated audio samples to | ||
| ``exp_manager.explicit_log_dir`` as specified in the configuration. For each | ||
| utterance, the following audio files may be produced: | ||
|
|
||
| - Autoregressive inference output (``*.wav``) | ||
| - Teacher-forced output (``*_tf.wav``) | ||
| - Ground-truth reference audio (``*_gt.wav``) | ||
|
|
||
| Args: | ||
| config-path (str): Path to the directory containing the YAML configuration file. | ||
| config-name (str): Name of the YAML configuration file. | ||
| checkpoint_path (str): Path to the Duplex EARTTS checkpoint file. | ||
|
|
||
| Usage: | ||
| python duplex_eartts_eval.py \ | ||
| --config-path=conf/ \ | ||
| --config-name=duplex_eartts.yaml \ | ||
| ++checkpoint_path=duplex_eartts_results/duplex_eartts/model.ckpt | ||
| """ | ||
|
|
||
| import os | ||
|
|
||
| import torch | ||
| from lightning.pytorch import Trainer | ||
| from omegaconf import OmegaConf | ||
|
|
||
| from nemo.collections.speechlm2 import DataModule, DuplexEARTTSDataset | ||
|
|
||
| from nemo.collections.speechlm2.models.duplex_ear_tts import DuplexEARTTS | ||
| from nemo.core.config import hydra_runner | ||
| from nemo.utils.exp_manager import exp_manager | ||
| from nemo.utils.trainer_utils import resolve_trainer_cfg | ||
|
|
||
| torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) | ||
|
||
|
|
||
|
|
||
| @hydra_runner(config_path="conf", config_name="duplex_eartts") | ||
| def inference(cfg): | ||
| OmegaConf.resolve(cfg) | ||
| torch.distributed.init_process_group(backend="nccl") | ||
| torch.set_float32_matmul_precision("medium") | ||
| torch.backends.cudnn.allow_tf32 = True | ||
|
||
| trainer = Trainer(**resolve_trainer_cfg(cfg.trainer)) | ||
| log_dir = exp_manager(trainer, cfg.get("exp_manager", None)) | ||
| OmegaConf.save(cfg, log_dir / "exp_config.yaml") | ||
|
|
||
| with trainer.init_module(): | ||
| if cfg.get("checkpoint_path", None): | ||
| model = DuplexEARTTS.load_from_checkpoint( | ||
| cfg.checkpoint_path, | ||
| cfg=OmegaConf.to_container(cfg, resolve=True), | ||
| ) | ||
| else: | ||
| raise ValueError("For evaluation, you must provide `cfg.checkpoint_path`.") | ||
|
|
||
| dataset = DuplexEARTTSDataset( | ||
| tokenizer=model.tokenizer, | ||
| frame_length=cfg.data.frame_length, | ||
| source_sample_rate=cfg.data.source_sample_rate, | ||
| target_sample_rate=cfg.data.target_sample_rate, | ||
| input_roles=cfg.data.input_roles, | ||
| output_roles=cfg.data.output_roles, | ||
| add_text_bos_and_eos_in_each_turn=cfg.data.get("add_text_bos_and_eos_in_each_turn", True), | ||
| add_audio_prompt=cfg.data.get("add_audio_prompt", True), | ||
| audio_prompt_duration=cfg.data.get("audio_prompt_duration", 3), | ||
| num_delay_speech_tokens=cfg.model.get("num_delay_speech_tokens", 2), | ||
| ) | ||
| datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset) | ||
|
|
||
| trainer.validate(model, datamodule) | ||
|
||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| inference() | ||
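The docstring of the evaluation script above lists WER and CER among the reported metrics. Both are edit-distance ratios, computed over words and characters respectively. A minimal sketch using a plain Levenshtein distance follows. The helper names `edit_distance`, `wer`, and `cer` are hypothetical, and this is not the metric code the PR uses, which may also normalize text before scoring:

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (unit cost per edit)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def wer(reference, hypothesis):
    """Word error rate: word-level edits divided by reference word count."""
    ref_words = reference.split()
    if not ref_words:
        raise ValueError("reference must contain at least one word")
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def cer(reference, hypothesis):
    """Character error rate: character-level edits divided by reference length."""
    if not reference:
        raise ValueError("reference must not be empty")
    return edit_distance(reference, hypothesis) / len(reference)


reference = "the cat sat on the mat"
hypothesis = "the cat sat on mat"
print(wer(reference, hypothesis), cer(reference, hypothesis))
```

In this example one of six reference words is dropped, and the same omission costs four of twenty-two reference characters, so the two rates differ even though they describe the same error.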
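SECS, named in the evaluation script's docstring as speaker encoder cosine similarity, compares speaker embeddings of two audio signals. A minimal sketch of the similarity step only, assuming the embeddings are already available as equal-length lists of floats. The speaker encoder that would produce them is out of scope here, and `cosine_similarity` is a hypothetical helper rather than the PR's code:

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("embeddings must be non-zero")
    return dot / (norm_a * norm_b)


# Toy vectors standing in for speaker embeddings (illustrative values only).
same_direction = cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
print(same_direction, orthogonal)
```

The score depends only on direction, not magnitude, so scaling an embedding leaves the similarity unchanged.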