Skip to content

Commit 8f81e82

Browse files
rfejginAkCodes23
authored and committed
[TTS] MagpieTTS inference: Add command line option to select a subset of datasets to run inference on (NVIDIA-NeMo#15212)
* Added datasets filtering to the inference script New command line argument: --datasets <dataset1,dataset2,...> where dataset1, dataset2, ... are the names datasets to process in the datasets_json_path file. If not specified, all datasets in the datasets_json_path will be processed. If specified, only the datasets in the list will be processed. Signed-off-by: Fejgin, Roy <rfejgin@nvidia.com> * Refined datasets filtering in the inference script * Correctly handle comma-separated list of dataset names in the --datasets argument. * Help text Signed-off-by: Fejgin, Roy <rfejgin@nvidia.com> --------- Signed-off-by: Fejgin, Roy <rfejgin@nvidia.com> Signed-off-by: Akhil Varanasi <akhilvaranasi23@gmail.com>
1 parent 7e03538 commit 8f81e82

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

examples/tts/magpietts_inference.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,27 @@ def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict:
117117
return metrics_mean_ci
118118

119119

120+
def filter_datasets(dataset_meta_info: dict, datasets: Optional[List[str]]) -> List[str]:
    """Select the datasets to process from the dataset meta info.

    Args:
        dataset_meta_info: Dictionary whose keys are the available dataset names.
        datasets: Requested dataset names. May be given either as a single
            comma-separated string (the raw value of the ``--datasets`` command
            line argument) or as an already-split list of names. If None, no
            filtering is applied and every dataset is selected.

    Returns:
        The dataset names to process, in the order requested. Surrounding
        whitespace is stripped, empty entries (e.g. from a trailing comma) are
        ignored, and repeated names are returned only once.

    Raises:
        ValueError: If a filter was given but contains no dataset names, or if
            a requested name is not a key of ``dataset_meta_info``.
    """
    if datasets is None:
        # Dataset filtering not specified, return all datasets
        return list(dataset_meta_info.keys())

    # The command line hands us one comma-separated string, while the type
    # annotation promises a list; accept both so neither caller breaks.
    if isinstance(datasets, str):
        candidates = datasets.split(",")
    else:
        candidates = list(datasets)

    # Normalize: tolerate "a, b" and trailing commas, and keep only the first
    # occurrence of each name so a dataset is never processed twice.
    stripped = (name.strip() for name in candidates)
    requested = list(dict.fromkeys(name for name in stripped if name))

    if not requested:
        raise ValueError("No dataset names were specified in the datasets filter")

    # Check if datasets are valid
    for dataset in requested:
        if dataset not in dataset_meta_info:
            raise ValueError(f"Dataset {dataset} not found in dataset meta info")

    # Return all requested datasets
    return requested
133+
134+
120135
def run_inference_and_evaluation(
121136
model_config: ModelLoadConfig,
122137
inference_config: InferenceConfig,
123138
eval_config: EvaluationConfig,
124139
dataset_meta_info: dict,
140+
datasets: Optional[List[str]],
125141
out_dir: str,
126142
num_repeats: int = 1,
127143
confidence_level: float = 0.95,
@@ -141,6 +157,8 @@ def run_inference_and_evaluation(
141157
inference_config: Configuration for inference.
142158
eval_config: Configuration for evaluation.
143159
dataset_meta_info: Dictionary containing dataset metadata.
160+
datasets: List of dataset names to run inference and evaluation on. If None, all datasets in the
161+
dataset meta info will be processed.
144162
out_dir: Output directory for results.
145163
num_repeats: Number of times to repeat inference (for CI estimation).
146164
confidence_level: Confidence level for CI calculation.
@@ -175,7 +193,6 @@ def run_inference_and_evaluation(
175193
runner = MagpieInferenceRunner(model, inference_config)
176194

177195
# Tracking metrics across datasets
178-
datasets = list(dataset_meta_info.keys())
179196
ssim_per_dataset = []
180197
cer_per_dataset = []
181198
all_datasets_filewise_metrics = {}
@@ -374,8 +391,15 @@ def create_argument_parser() -> argparse.ArgumentParser:
374391
data_group.add_argument(
375392
'--datasets_json_path',
376393
type=str,
394+
required=True,
395+
default=None,
396+
help='Path to dataset configuration JSON file (will process all datasets in the file if --datasets is not specified)',
397+
)
398+
data_group.add_argument(
399+
'--datasets',
400+
type=str,
377401
default=None,
378-
help='Path to dataset configuration JSON file (will process all datasets in the file)',
402+
help='Comma-separated list of dataset names to process using names from the datasets_json_path file. If not specified, all datasets in the datasets_json_path will be processed.',
379403
)
380404
data_group.add_argument(
381405
'--out_dir',
@@ -502,7 +526,7 @@ def main():
502526
args = parser.parse_args()
503527

504528
dataset_meta_info = load_evalset_config(args.datasets_json_path)
505-
datasets = list(dataset_meta_info.keys())
529+
datasets = filter_datasets(dataset_meta_info, args.datasets)
506530

507531
logging.info(f"Loaded {len(datasets)} datasets: {', '.join(datasets)}")
508532

@@ -585,6 +609,7 @@ def main():
585609
inference_config=inference_config,
586610
eval_config=eval_config,
587611
dataset_meta_info=dataset_meta_info,
612+
datasets=datasets,
588613
out_dir=args.out_dir,
589614
num_repeats=args.num_repeats,
590615
confidence_level=args.confidence_level,

0 commit comments

Comments
 (0)