4242import json
4343import os
4444import shutil
45+ from dataclasses import fields
4546from pathlib import Path
4647from typing import List , Optional , Tuple
4748
4849import numpy as np
4950
5051from nemo .collections .asr .parts .utils .manifest_utils import read_manifest
52+ from nemo .collections .tts .models .magpietts import ModelInferenceParameters
5153from nemo .collections .tts .modules .magpietts_inference .evaluate_generated_audio import load_evalset_config
5254
5355# Import the modular components
6567 load_magpie_model ,
6668)
6769from nemo .collections .tts .modules .magpietts_inference .visualization import create_combined_box_plot , create_violin_plot
70+ from nemo .collections .tts .modules .magpietts_modules import EOSDetectionMethod
6871from nemo .utils import logging
6972
7073
@@ -428,11 +431,20 @@ def create_argument_parser() -> argparse.ArgumentParser:
428431
429432 # Inference arguments
430433 infer_group = parser .add_argument_group ('Inference Parameters' )
431- infer_group .add_argument ('--temperature' , type = float , default = 0.6 )
432- infer_group .add_argument ('--topk' , type = int , default = 80 )
434+ # Add model specific parameters
435+ for field in fields (ModelInferenceParameters ):
436+ extra_args = {"type" : field .type }
437+ if field .type == bool :
438+ extra_args ["action" ] = "store_true"
439+ del extra_args ["type" ]
440+ if field .name == "estimate_alignment_from_layers" or field .name == "apply_prior_to_layers" :
441+ extra_args ["help" ] = "Must be a comma separate string. Not enclosed in brackets"
442+ extra_args ["type" ] = str
443+ elif field .name == "eos_detection_method" :
444+ extra_args ["choices" ] = [m .value for m in EOSDetectionMethod ]
445+ infer_group .add_argument (f"--{ field .name } " , ** extra_args )
433446 infer_group .add_argument ('--batch_size' , type = int , default = 32 )
434447 infer_group .add_argument ('--use_cfg' , action = 'store_true' , help = 'Enable classifier-free guidance' )
435- infer_group .add_argument ('--cfg_scale' , type = float , default = 2.5 )
436448 infer_group .add_argument (
437449 '--longform_mode' ,
438450 type = str ,
@@ -453,54 +465,17 @@ def create_argument_parser() -> argparse.ArgumentParser:
453465 help = 'Maximum decoder steps for longform inference' ,
454466 )
455467
456- # Attention prior arguments
457- prior_group = parser .add_argument_group ('Attention Prior' )
458- prior_group .add_argument ('--apply_attention_prior' , action = 'store_true' )
459- prior_group .add_argument ('--attention_prior_epsilon' , type = float , default = 0.1 )
460- prior_group .add_argument ('--attention_prior_lookahead_window' , type = int , default = 5 )
461- prior_group .add_argument (
462- '--estimate_alignment_from_layers' ,
463- type = str ,
464- default = None ,
465- help = 'Comma-separated layer indices for alignment estimation' ,
466- )
467- prior_group .add_argument (
468- '--apply_prior_to_layers' ,
469- type = str ,
470- default = None ,
471- help = 'Comma-separated layer indices to apply prior' ,
472- )
473- prior_group .add_argument ('--start_prior_after_n_audio_steps' , type = int , default = 0 )
474-
475468 # Local transformer / MaskGit arguments
476- lt_group = parser .add_argument_group ('Local Transformer / MaskGit' )
477- lt_group .add_argument ('--use_local_transformer' , action = 'store_true' )
478- lt_group .add_argument ('--maskgit_n_steps' , type = int , default = 3 )
479- lt_group .add_argument ('--maskgit_noise_scale' , type = float , default = 0.0 )
480- lt_group .add_argument ('--maskgit_fixed_schedule' , type = int , nargs = '+' , default = None )
481- lt_group .add_argument (
469+ infer_group .add_argument ('--use_local_transformer' , action = 'store_true' )
470+ infer_group .add_argument ('--maskgit_n_steps' , type = int , default = 3 )
471+ infer_group .add_argument ('--maskgit_noise_scale' , type = float , default = 0.0 )
472+ infer_group .add_argument ('--maskgit_fixed_schedule' , type = int , nargs = '+' , default = None )
473+ infer_group .add_argument (
482474 '--maskgit_sampling_type' ,
483475 default = None ,
484476 choices = ["default" , "causal" , "purity_causal" , "purity_default" ],
485477 )
486478
487- # EOS detection
488- eos_group = parser .add_argument_group ('EOS Detection' )
489- eos_group .add_argument (
490- '--eos_detection_method' ,
491- type = str ,
492- default = "argmax_or_multinomial_any" ,
493- choices = [
494- "argmax_any" ,
495- "argmax_or_multinomial_any" ,
496- "argmax_all" ,
497- "argmax_or_multinomial_all" ,
498- "argmax_zero_cb" ,
499- "argmax_or_multinomial_zero_cb" ,
500- ],
501- )
502- eos_group .add_argument ('--ignore_finished_sentence_tracking' , action = 'store_true' )
503-
504479 # Evaluation arguments
505480 eval_group = parser .add_argument_group ('Evaluation' )
506481 eval_group .add_argument (
@@ -549,7 +524,7 @@ def main():
549524 has_nemo_mode = args .nemo_files is not None and args .nemo_files != "null"
550525
551526 if not has_checkpoint_mode and not has_nemo_mode :
552- parser .error ("You must provide either:\n " " 1. --hparams_files and --checkpoint_files\n " " 2. --nemo_files" )
527+ parser .error ("You must provide either:\n 1. --hparams_files and --checkpoint_files\n 2. --nemo_files" )
553528
554529 # Build configurations
555530 # Use higher max_decoder_steps for longform inference when mode is 'always'
@@ -560,29 +535,31 @@ def main():
560535 max_decoder_steps = args .longform_max_decoder_steps
561536 else : # 'never'
562537 max_decoder_steps = 440
538+ model_inference_parameters = {}
539+ for field in fields (ModelInferenceParameters ):
540+ field = field .name
541+ if field == "max_decoder_steps" :
542+ model_inference_parameters [field ] = max_decoder_steps
543+ continue
544+ arg_from_cmdline = vars (args )[field ]
545+ if arg_from_cmdline is not None :
546+ if field in ["estimate_alignment_from_layers" , "apply_prior_to_layers" ]:
547+ model_inference_parameters [field ] = parse_layer_list (vars (args )[field ])
548+ else :
549+ model_inference_parameters [field ] = vars (args )[field ]
563550
564551 inference_config = InferenceConfig (
565- temperature = args .temperature ,
566- topk = args .topk ,
552+ model_inference_parameters = ModelInferenceParameters .from_dict (model_inference_parameters ),
567553 batch_size = args .batch_size ,
568554 use_cfg = args .use_cfg ,
569- cfg_scale = args .cfg_scale ,
570- max_decoder_steps = max_decoder_steps ,
571555 apply_attention_prior = args .apply_attention_prior ,
572- attention_prior_epsilon = args .attention_prior_epsilon ,
573- attention_prior_lookahead_window = args .attention_prior_lookahead_window ,
574- estimate_alignment_from_layers = parse_layer_list (args .estimate_alignment_from_layers ),
575- apply_prior_to_layers = parse_layer_list (args .apply_prior_to_layers ),
576- start_prior_after_n_audio_steps = args .start_prior_after_n_audio_steps ,
577556 use_local_transformer = args .use_local_transformer ,
578557 maskgit_n_steps = args .maskgit_n_steps ,
579- longform_mode = args .longform_mode ,
580- longform_word_threshold = args .longform_word_threshold ,
581558 maskgit_noise_scale = args .maskgit_noise_scale ,
582559 maskgit_fixed_schedule = args .maskgit_fixed_schedule ,
583560 maskgit_sampling_type = args .maskgit_sampling_type ,
584- eos_detection_method = args .eos_detection_method ,
585- ignore_finished_sentence_tracking = args .ignore_finished_sentence_tracking ,
561+ longform_mode = args .longform_mode ,
562+ longform_word_threshold = args .longform_word_threshold ,
586563 )
587564
588565 eval_config = EvaluationConfig (
0 commit comments