Skip to content

Commit 72e2bb0

Browse files
authored
Update MagpieTTS' Inference Parameter Configuration (#15254)
* move inference params to checkpoint and make do_tts apply prior Signed-off-by: Jason <jasoli@nvidia.com> * Apply isort and black reformatting Signed-off-by: blisc <blisc@users.noreply.github.com> * Enable LT in do_tts; add docstrings Signed-off-by: Jason <jasoli@nvidia.com> * update defaults; merge inference dataclasses Signed-off-by: Jason <jasoli@nvidia.com> * Apply isort and black reformatting Signed-off-by: blisc <blisc@users.noreply.github.com> * update epsilon value Signed-off-by: Jason <jasoli@nvidia.com> * add defaults for inference; fix longform mode; fix bug introduced in #15212 Signed-off-by: Jason <jasoli@nvidia.com> * Apply isort and black reformatting Signed-off-by: blisc <blisc@users.noreply.github.com> * fix field key usage Signed-off-by: Jason <jasoli@nvidia.com> --------- Signed-off-by: Jason <jasoli@nvidia.com> Signed-off-by: blisc <blisc@users.noreply.github.com> Co-authored-by: blisc <blisc@users.noreply.github.com>
1 parent c6bfb27 commit 72e2bb0

File tree

5 files changed

+189
-237
lines changed

5 files changed

+189
-237
lines changed

examples/tts/magpietts_inference.py

Lines changed: 36 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@
4242
import json
4343
import os
4444
import shutil
45+
from dataclasses import fields
4546
from pathlib import Path
4647
from typing import List, Optional, Tuple
4748

4849
import numpy as np
4950

5051
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest
52+
from nemo.collections.tts.models.magpietts import ModelInferenceParameters
5153
from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config
5254

5355
# Import the modular components
@@ -65,6 +67,7 @@
6567
load_magpie_model,
6668
)
6769
from nemo.collections.tts.modules.magpietts_inference.visualization import create_combined_box_plot, create_violin_plot
70+
from nemo.collections.tts.modules.magpietts_modules import EOSDetectionMethod
6871
from nemo.utils import logging
6972

7073

@@ -428,11 +431,20 @@ def create_argument_parser() -> argparse.ArgumentParser:
428431

429432
# Inference arguments
430433
infer_group = parser.add_argument_group('Inference Parameters')
431-
infer_group.add_argument('--temperature', type=float, default=0.6)
432-
infer_group.add_argument('--topk', type=int, default=80)
434+
# Add model specific parameters
435+
for field in fields(ModelInferenceParameters):
436+
extra_args = {"type": field.type}
437+
if field.type == bool:
438+
extra_args["action"] = "store_true"
439+
del extra_args["type"]
440+
if field.name == "estimate_alignment_from_layers" or field.name == "apply_prior_to_layers":
441+
extra_args["help"] = "Must be a comma separate string. Not enclosed in brackets"
442+
extra_args["type"] = str
443+
elif field.name == "eos_detection_method":
444+
extra_args["choices"] = [m.value for m in EOSDetectionMethod]
445+
infer_group.add_argument(f"--{field.name}", **extra_args)
433446
infer_group.add_argument('--batch_size', type=int, default=32)
434447
infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance')
435-
infer_group.add_argument('--cfg_scale', type=float, default=2.5)
436448
infer_group.add_argument(
437449
'--longform_mode',
438450
type=str,
@@ -453,54 +465,17 @@ def create_argument_parser() -> argparse.ArgumentParser:
453465
help='Maximum decoder steps for longform inference',
454466
)
455467

456-
# Attention prior arguments
457-
prior_group = parser.add_argument_group('Attention Prior')
458-
prior_group.add_argument('--apply_attention_prior', action='store_true')
459-
prior_group.add_argument('--attention_prior_epsilon', type=float, default=0.1)
460-
prior_group.add_argument('--attention_prior_lookahead_window', type=int, default=5)
461-
prior_group.add_argument(
462-
'--estimate_alignment_from_layers',
463-
type=str,
464-
default=None,
465-
help='Comma-separated layer indices for alignment estimation',
466-
)
467-
prior_group.add_argument(
468-
'--apply_prior_to_layers',
469-
type=str,
470-
default=None,
471-
help='Comma-separated layer indices to apply prior',
472-
)
473-
prior_group.add_argument('--start_prior_after_n_audio_steps', type=int, default=0)
474-
475468
# Local transformer / MaskGit arguments
476-
lt_group = parser.add_argument_group('Local Transformer / MaskGit')
477-
lt_group.add_argument('--use_local_transformer', action='store_true')
478-
lt_group.add_argument('--maskgit_n_steps', type=int, default=3)
479-
lt_group.add_argument('--maskgit_noise_scale', type=float, default=0.0)
480-
lt_group.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None)
481-
lt_group.add_argument(
469+
infer_group.add_argument('--use_local_transformer', action='store_true')
470+
infer_group.add_argument('--maskgit_n_steps', type=int, default=3)
471+
infer_group.add_argument('--maskgit_noise_scale', type=float, default=0.0)
472+
infer_group.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None)
473+
infer_group.add_argument(
482474
'--maskgit_sampling_type',
483475
default=None,
484476
choices=["default", "causal", "purity_causal", "purity_default"],
485477
)
486478

487-
# EOS detection
488-
eos_group = parser.add_argument_group('EOS Detection')
489-
eos_group.add_argument(
490-
'--eos_detection_method',
491-
type=str,
492-
default="argmax_or_multinomial_any",
493-
choices=[
494-
"argmax_any",
495-
"argmax_or_multinomial_any",
496-
"argmax_all",
497-
"argmax_or_multinomial_all",
498-
"argmax_zero_cb",
499-
"argmax_or_multinomial_zero_cb",
500-
],
501-
)
502-
eos_group.add_argument('--ignore_finished_sentence_tracking', action='store_true')
503-
504479
# Evaluation arguments
505480
eval_group = parser.add_argument_group('Evaluation')
506481
eval_group.add_argument(
@@ -549,7 +524,7 @@ def main():
549524
has_nemo_mode = args.nemo_files is not None and args.nemo_files != "null"
550525

551526
if not has_checkpoint_mode and not has_nemo_mode:
552-
parser.error("You must provide either:\n" " 1. --hparams_files and --checkpoint_files\n" " 2. --nemo_files")
527+
parser.error("You must provide either:\n 1. --hparams_files and --checkpoint_files\n 2. --nemo_files")
553528

554529
# Build configurations
555530
# Use higher max_decoder_steps for longform inference when mode is 'always'
@@ -560,29 +535,31 @@ def main():
560535
max_decoder_steps = args.longform_max_decoder_steps
561536
else: # 'never'
562537
max_decoder_steps = 440
538+
model_inference_parameters = {}
539+
for field in fields(ModelInferenceParameters):
540+
field = field.name
541+
if field == "max_decoder_steps":
542+
model_inference_parameters[field] = max_decoder_steps
543+
continue
544+
arg_from_cmdline = vars(args)[field]
545+
if arg_from_cmdline is not None:
546+
if field in ["estimate_alignment_from_layers", "apply_prior_to_layers"]:
547+
model_inference_parameters[field] = parse_layer_list(vars(args)[field])
548+
else:
549+
model_inference_parameters[field] = vars(args)[field]
563550

564551
inference_config = InferenceConfig(
565-
temperature=args.temperature,
566-
topk=args.topk,
552+
model_inference_parameters=ModelInferenceParameters.from_dict(model_inference_parameters),
567553
batch_size=args.batch_size,
568554
use_cfg=args.use_cfg,
569-
cfg_scale=args.cfg_scale,
570-
max_decoder_steps=max_decoder_steps,
571555
apply_attention_prior=args.apply_attention_prior,
572-
attention_prior_epsilon=args.attention_prior_epsilon,
573-
attention_prior_lookahead_window=args.attention_prior_lookahead_window,
574-
estimate_alignment_from_layers=parse_layer_list(args.estimate_alignment_from_layers),
575-
apply_prior_to_layers=parse_layer_list(args.apply_prior_to_layers),
576-
start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps,
577556
use_local_transformer=args.use_local_transformer,
578557
maskgit_n_steps=args.maskgit_n_steps,
579-
longform_mode=args.longform_mode,
580-
longform_word_threshold=args.longform_word_threshold,
581558
maskgit_noise_scale=args.maskgit_noise_scale,
582559
maskgit_fixed_schedule=args.maskgit_fixed_schedule,
583560
maskgit_sampling_type=args.maskgit_sampling_type,
584-
eos_detection_method=args.eos_detection_method,
585-
ignore_finished_sentence_tracking=args.ignore_finished_sentence_tracking,
561+
longform_mode=args.longform_mode,
562+
longform_word_threshold=args.longform_word_threshold,
586563
)
587564

588565
eval_config = EvaluationConfig(

0 commit comments

Comments
 (0)