
Commit 8df83ac

[TTS] MagpieTTS Inference Refactoring (#15178)
* Modularize magpie inference code, move inference code from scripts to example (Signed-off-by: subhankar-ghosh <[email protected]>)
* Modify magpie CI with inference changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Renaming magpietts inference scripts from magpie to magpietts (Signed-off-by: subhankar-ghosh <[email protected]>)
* infer_batch returns dataclass object (Signed-off-by: subhankar-ghosh <[email protected]>)
* Fixed context embedding without context encoder (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Remove unnecessary configurations. Removed multiple long manifest configurations from evalset_config.py. (Signed-off-by: Subhankar Ghosh <[email protected]>)
* Removing unused imports (Signed-off-by: subhankar-ghosh <[email protected]>)
* Copilot suggested changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Copilot suggested changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Move inference helper modules from examples to tts collection (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Review changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Changes suggested in compute_mean_with_confidence_interval (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Linting issue (Signed-off-by: subhankar-ghosh <[email protected]>)
* support multiple voices - baked context embeddings (Signed-off-by: subhankar-ghosh <[email protected]>)
* move evalset_config to json (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Modularize magpie inference code, move inference code from scripts to example (Signed-off-by: subhankar-ghosh <[email protected]>)
* Modify magpie CI with inference changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Renaming magpietts inference scripts from magpie to magpietts (Signed-off-by: subhankar-ghosh <[email protected]>)
* infer_batch returns dataclass object (Signed-off-by: subhankar-ghosh <[email protected]>)
* Fixed context embedding without context encoder (Signed-off-by: subhankar-ghosh <[email protected]>)
* Removing unused imports (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Remove unnecessary configurations. Removed multiple long manifest configurations from evalset_config.py. (Signed-off-by: Subhankar Ghosh <[email protected]>)
* Copilot suggested changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Copilot suggested changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Move inference helper modules from examples to tts collection (Signed-off-by: subhankar-ghosh <[email protected]>)
* Review changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Changes suggested in compute_mean_with_confidence_interval (Signed-off-by: subhankar-ghosh <[email protected]>)
* Linting issue (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* support multiple voices - baked context embeddings (Signed-off-by: subhankar-ghosh <[email protected]>)
* move evalset_config to json (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* simplifying get_baked_context_embeddings_batch (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Fix logging (Signed-off-by: subhankar-ghosh <[email protected]>)
* Apply isort and black reformatting (Signed-off-by: subhankar-ghosh <[email protected]>)
* Remove unused imports (Signed-off-by: subhankar-ghosh <[email protected]>)
* logging changes (Signed-off-by: subhankar-ghosh <[email protected]>)
* Changed baked_context_embeddings from tensor to flattened embeddings, print->logging (Signed-off-by: subhankar-ghosh <[email protected]>)
* Removed comments and print statements (Signed-off-by: subhankar-ghosh <[email protected]>)

---------

Signed-off-by: subhankar-ghosh <[email protected]>
Signed-off-by: subhankar-ghosh <[email protected]>
Signed-off-by: Subhankar Ghosh <[email protected]>
Co-authored-by: subhankar-ghosh <[email protected]>
1 parent a599d89 commit 8df83ac

File tree

15 files changed (+1976 / -924 lines changed)

examples/tts/magpietts_inference.py

Lines changed: 612 additions & 0 deletions
Large diffs are not rendered by default.
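The new 612-line example script is not rendered in this view. As orientation only, here is a minimal sketch of how the refactored public API could be driven, based on the subpackage docstring added later in this commit; the file paths are placeholders, and only construction of the runner is shown, since the runner's entry-point methods do not appear in the rendered portion of the diff.

# Sketch based on the magpietts_inference subpackage docstring in this commit.
# Paths are placeholders; how the runner is invoked on a manifest is not shown here.
from nemo.collections.tts.modules.magpietts_inference import (
    InferenceConfig,
    MagpieInferenceRunner,
    ModelLoadConfig,
    load_magpie_model,
)

# Load the MagpieTTS checkpoint together with its codec model.
model_config = ModelLoadConfig(
    nemo_file="/path/to/model.nemo",
    codecmodel_path="/path/to/codec.nemo",
)
model, checkpoint_name = load_magpie_model(model_config)

# Configure sampling and build the inference runner.
inference_config = InferenceConfig(temperature=0.6, topk=80)
runner = MagpieInferenceRunner(model, inference_config)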

nemo/collections/tts/models/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,7 @@
 from nemo.collections.tts.models.fastpitch import FastPitchModel
 from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL
 from nemo.collections.tts.models.hifigan import HifiGanModel
-from nemo.collections.tts.models.magpietts import MagpieTTSModel
+from nemo.collections.tts.models.magpietts import InferBatchOutput, MagpieTTSModel
 from nemo.collections.tts.models.magpietts_preference_optimization import (
     MagpieTTSModelOfflinePO,
     MagpieTTSModelOfflinePODataGen,
@@ -41,6 +41,7 @@
     "SSLDisentangler",
     "GriffinLimModel",
     "HifiGanModel",
+    "InferBatchOutput",
     "MelPsuedoInverseModel",
     "MixerTTSModel",
     "RadTTSModel",

nemo/collections/tts/models/magpietts.py

Lines changed: 241 additions & 32 deletions
Large diffs are not rendered by default.
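The magpietts.py diff, which introduces the InferBatchOutput dataclass, is also not rendered. Below is a minimal sketch of the new calling convention, inferred from the changes to magpietts_preference_optimization.py shown next; the attribute names and infer_batch keyword arguments are taken from that diff, while the sampling values and the surrounding function are illustrative assumptions.

# Sketch only: infer_batch now returns a dataclass, so results are unpacked by
# attribute name instead of by tuple position. Names follow the diff below.
from nemo.collections.tts.models import MagpieTTSModel


def run_inference(model: MagpieTTSModel, batch):
    output = model.infer_batch(
        batch,
        max_decoder_steps=500,
        temperature=0.6,
        topk=80,
        use_cfg=False,
        cfg_scale=1.0,
    )
    # Previously: predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = ...
    return (
        output.predicted_audio,
        output.predicted_audio_lens,
        output.predicted_codes,
        output.predicted_codes_lens,
    )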

nemo/collections/tts/models/magpietts_preference_optimization.py

Lines changed: 10 additions & 2 deletions
@@ -98,14 +98,18 @@ def test_step(self, batch, batch_idx):
         topk = self.cfg.get('inference_topk', 80)
         use_cfg = self.cfg.get('inference_use_cfg', False)
         cfg_scale = self.cfg.get('inference_cfg_scale', 1.0)
-        predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch(
+        output = self.infer_batch(
             batch,
             max_decoder_steps=self.cfg.get('max_decoder_steps', 500),
             temperature=temperature,
             topk=topk,
             use_cfg=use_cfg,
             cfg_scale=cfg_scale,
         )
+        predicted_audio = output.predicted_audio
+        predicted_audio_lens = output.predicted_audio_lens
+        predicted_codes = output.predicted_codes
+        predicted_codes_lens = output.predicted_codes_lens
         predicted_audio_paths = []
         audio_durations = []
         batch_invalid = False
@@ -612,7 +616,7 @@ def generate_and_reward(
         use_cfg = random.random() < self.cfg.inference_cfg_prob
         cfg_scale = self.cfg.get('inference_cfg_scale', 1.0)

-        predicted_audio, predicted_audio_lens, predicted_codes, predicted_codes_lens, _ = self.infer_batch(
+        output = self.infer_batch(
             batch_repeated,
             max_decoder_steps=self.max_decoder_steps,
             temperature=temperature,
@@ -622,6 +626,10 @@
             use_local_transformer_for_inference=use_local_transformer_for_inference,
             use_LT_kv_cache=False,  # We don't use KV caching for local transformer in GRPO due to issues.
         )
+        predicted_audio = output.predicted_audio
+        predicted_audio_lens = output.predicted_audio_lens
+        predicted_codes = output.predicted_codes
+        predicted_codes_lens = output.predicted_codes_lens
         predicted_audio_paths = []
         audio_durations = []
         for idx in range(predicted_audio.size(0)):
nemo/collections/tts/modules/magpietts_inference/__init__.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+MagpieTTS inference and evaluation subpackage.
+
+This package provides modular components for:
+- Model loading and configuration (utils.py)
+- Batch inference (inference.py)
+- Audio quality evaluation (evaluation.py)
+- Metrics visualization (visualization.py)
+
+Example Usage:
+    from examples.tts.magpietts import (
+        InferenceConfig,
+        MagpieInferenceRunner,
+        load_magpie_model,
+        ModelLoadConfig,
+    )
+
+    # Load model
+    model_config = ModelLoadConfig(
+        nemo_file="/path/to/model.nemo",
+        codecmodel_path="/path/to/codec.nemo",
+    )
+    model, checkpoint_name = load_magpie_model(model_config)
+
+    # Create runner and run inference
+    inference_config = InferenceConfig(temperature=0.6, topk=80)
+    runner = MagpieInferenceRunner(model, inference_config)
+"""
+
+from nemo.collections.tts.modules.magpietts_inference.evaluation import (
+    DEFAULT_VIOLIN_METRICS,
+    STANDARD_METRIC_KEYS,
+    EvaluationConfig,
+    compute_mean_with_confidence_interval,
+    evaluate_generated_audio_dir,
+)
+from nemo.collections.tts.modules.magpietts_inference.inference import InferenceConfig, MagpieInferenceRunner
+from nemo.collections.tts.modules.magpietts_inference.utils import ModelLoadConfig, load_magpie_model
+from nemo.collections.tts.modules.magpietts_inference.visualization import create_combined_box_plot, create_violin_plot
+
+__all__ = [
+    # Utils
+    "ModelLoadConfig",
+    "load_magpie_model",
+    # Inference
+    "InferenceConfig",
+    "MagpieInferenceRunner",
+    # Evaluation
+    "EvaluationConfig",
+    "evaluate_generated_audio_dir",
+    "compute_mean_with_confidence_interval",
+    "STANDARD_METRIC_KEYS",
+    "DEFAULT_VIOLIN_METRICS",
+    # Visualization
+    "create_violin_plot",
+    "create_combined_box_plot",
+]
nemo/collections/tts/modules/magpietts_inference/evalset_config.json

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+{
+    "an4_val_ci": {
+        "manifest_path": "/home/TestData/an4_dataset/an4_val_context_v1.json",
+        "audio_dir": "/",
+        "feature_dir": null
+    }
+}
+

scripts/magpietts/evaluate_generated_audio.py renamed to nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py

Lines changed: 29 additions & 37 deletions
@@ -16,24 +16,36 @@
 """
 import argparse
 import json
-import logging
 import os
 import pprint
 import string
 import tempfile
 import time
-from contextlib import contextmanager
 from functools import partial
+from pathlib import Path

 import librosa
 import numpy as np
-import scripts.magpietts.evalset_config as evalset_config
 import soundfile as sf
 import torch
 from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector, WhisperForConditionalGeneration, WhisperProcessor

 import nemo.collections.asr as nemo_asr
 from nemo.collections.asr.metrics.wer import word_error_rate_detail
+from nemo.utils import logging
+
+# Path to evalset config JSON
+EVALSET_CONFIG_PATH = Path(__file__).parent / 'evalset_config.json'
+
+
+def load_evalset_config(config_path: str = None) -> dict:
+    """Load dataset meta info from JSON config file."""
+    if config_path is None:
+        config_path = EVALSET_CONFIG_PATH
+    with open(config_path, 'r') as f:
+        return json.load(f)
+
+
 from nemo.collections.tts.modules.utmosv2 import UTMOSv2Calculator

@@ -126,31 +138,12 @@ def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_second
     min_samples = round(min_seconds * sampling_rate)

     if n_samples < min_samples:
-        print(f"Padding audio from {n_samples/sampling_rate} seconds to {min_samples/sampling_rate} seconds")
+        logging.info(f"Padding audio from {n_samples/sampling_rate} seconds to {min_samples/sampling_rate} seconds")
         padding_needed = min_samples - n_samples
         audio_np = np.pad(audio_np, (0, padding_needed), mode='constant', constant_values=0)
     return audio_np


-@contextmanager
-def nemo_log_level(level):
-    """
-    A context manager that temporarily sets the logging level for the NeMo logger
-    and restores the original level when the context manager is exited.
-
-    Args:
-        level (int): The logging level to set.
-    """
-    logger = logging.getLogger("nemo_logger")
-    original_level = logger.level
-    logger.setLevel(level)
-    try:
-        yield
-    finally:
-        # restore the original level when the context manager is exited (even if an exception was raised)
-        logger.setLevel(original_level)
-
-
 def extract_embedding(model, extractor, audio_path, device, sv_model_type):
     speech_array, sampling_rate = librosa.load(audio_path, sr=16000)
     # pad to 0.5 seconds as the extractor may not be able to handle very short signals

@@ -170,14 +163,14 @@ def extract_embedding(model, extractor, audio_path, device, sv_model_type):


 def compute_utmosv2_scores(audio_dir, device):
-    print(f"\nComputing UTMOSv2 scores for files in {audio_dir}...")
+    logging.info(f"\nComputing UTMOSv2 scores for files in {audio_dir}...")
     start_time = time.time()
     utmosv2_calculator = UTMOSv2Calculator(device=device)
     utmosv2_scores = utmosv2_calculator.process_directory(audio_dir)
     # convert to to a dictionary indexed by file path
     utmosv2_scores_dict = {os.path.normpath(item['file_path']): item['predicted_mos'] for item in utmosv2_scores}
     end_time = time.time()
-    print(f"UTMOSv2 scores computed for {len(utmosv2_scores)} files in {end_time - start_time:.2f} seconds\n")
+    logging.info(f"UTMOSv2 scores computed for {len(utmosv2_scores)} files in {end_time - start_time:.2f} seconds\n")
     return utmosv2_scores_dict


@@ -221,12 +214,11 @@ def evaluate(
     )
     speaker_verification_model = speaker_verification_model.to(device)
     speaker_verification_model.eval()
-    with nemo_log_level(logging.ERROR):
-        # The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily
-        print("Loading `titanet_small` model...")
-        speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
-            model_name='titanet_small'
-        )
+    # The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily
+    logging.info("Loading `titanet_small` model...")
+    speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
+        model_name='titanet_small'
+    )
     speaker_verification_model_alternate = speaker_verification_model_alternate.to(device)
     speaker_verification_model_alternate.eval()

@@ -269,7 +261,7 @@
             )
             gt_audio_text = process_text(gt_audio_text)
         except Exception as e:
-            print("Error during ASR: {}".format(e))
+            logging.info("Error during ASR: {}".format(e))
             pred_text = ""
             gt_audio_text = ""

@@ -283,10 +275,10 @@
         detailed_cer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=True)
         detailed_wer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=False)

-        print("{} GT Text:".format(ridx), gt_text)
-        print("{} Pr Text:".format(ridx), pred_text)
+        logging.info(f"{ridx} GT Text: {gt_text}")
+        logging.info(f"{ridx} Pr Text: {pred_text}")
         # Format cer and wer to 2 decimal places
-        print("CER:", "{:.4f} | WER: {:.4f}".format(detailed_cer[0], detailed_wer[0]))
+        logging.info("CER:", "{:.4f} | WER: {:.4f}".format(detailed_cer[0], detailed_wer[0]))

         pred_texts.append(pred_text)
         gt_texts.append(gt_text)

@@ -431,8 +423,8 @@ def main():
     args = parser.parse_args()

     if args.evalset is not None:
-        dataset_meta_info = evalset_config.dataset_meta_info
-        assert args.evalset in dataset_meta_info
+        dataset_meta_info = load_evalset_config()
+        assert args.evalset in dataset_meta_info, f"Dataset '{args.evalset}' not found in evalset_config.json"
         args.manifest_path = dataset_meta_info[args.evalset]['manifest_path']
         args.audio_dir = dataset_meta_info[args.evalset]['audio_dir']
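Since evalset definitions now live in a JSON file next to the module instead of a Python config, here is a short usage sketch of the new lookup; load_evalset_config and the an4_val_ci entry come from this diff, and anything beyond that is illustrative.

# Sketch: dataset metadata is now read from evalset_config.json via load_evalset_config().
from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config

dataset_meta_info = load_evalset_config()   # defaults to EVALSET_CONFIG_PATH next to the module
info = dataset_meta_info["an4_val_ci"]      # one entry per evaluation set
print(info["manifest_path"], info["audio_dir"], info["feature_dir"])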
