Skip to content

Commit 3fc5f37

Browse files
committed
Cleanup and add missing files
* address some CI linting issues * include a file that was missed in last commit Signed-off-by: Fejgin, Roy <[email protected]>
1 parent 14a9a27 commit 3fc5f37

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

examples/tts/magpietts_inference.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met
101101
metrics.get('wer_gt_audio_cumulative', ''),
102102
metrics.get('utmosv2_avg', ''),
103103
metrics.get('total_gen_audio_seconds', ''),
104+
metrics.get('frechet_codec_distance', ''),
104105
]
105106
with open(csv_path, "a") as f:
106107
f.write(",".join(str(v) for v in values) + "\n")
@@ -181,7 +182,7 @@ def run_inference_and_evaluation(
181182
"wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,"
182183
"ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,"
183184
"ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,"
184-
"utmosv2_avg,total_gen_audio_seconds"
185+
"utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance"
185186
)
186187

187188
for dataset in datasets:
@@ -222,7 +223,7 @@ def run_inference_and_evaluation(
222223
f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records"
223224
)
224225

225-
rtf_metrics_list, _ = runner.run_inference_on_dataset(
226+
rtf_metrics_list, _, codec_file_paths = runner.run_inference_on_dataset(
226227
dataset=test_dataset,
227228
output_dir=repeat_audio_dir,
228229
manifest_records=manifest_records,
@@ -246,6 +247,7 @@ def run_inference_and_evaluation(
246247
asr_model_name=eval_config.asr_model_name,
247248
language=language,
248249
with_utmosv2=eval_config.with_utmosv2,
250+
codec_model_path=eval_config.codec_model_path,
249251
)
250252

251253
metrics, filewise_metrics = evaluate_generated_audio_dir(
@@ -272,6 +274,10 @@ def run_inference_and_evaluation(
272274
violin_path = Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png"
273275
create_violin_plot(filewise_metrics, violin_plot_metrics, violin_path)
274276

277+
# Delete temporary predicted codes files
278+
for codec_file_path in codec_file_paths:
279+
os.remove(codec_file_path)
280+
275281
if skip_evaluation or not metrics_all_repeats:
276282
continue
277283

@@ -463,6 +469,7 @@ def create_argument_parser() -> argparse.ArgumentParser:
463469
nargs='*',
464470
default=['cer', 'pred_context_ssim', 'utmosv2'],
465471
)
472+
eval_group.add_argument('--disable_fcd', action='store_true', help="Disable Frechet Codec Distance computation")
466473

467474
# Quality targets (for CI/CD)
468475
target_group = parser.add_argument_group('Quality Targets')
@@ -520,6 +527,7 @@ def main():
520527
sv_model=args.sv_model,
521528
asr_model_name=args.asr_model_name,
522529
with_utmosv2=not args.disable_utmosv2,
530+
codec_model_path=args.codecmodel_path if not args.disable_fcd else None,
523531
)
524532

525533
cer, ssim = None, None

nemo/collections/tts/metrics/frechet_codec_distance.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ def num_features(self) -> int:
5050
class FrechetCodecDistance(FrechetInceptionDistance):
5151
def __init__(self, codec_name: str):
5252
if codec_name.endswith(".nemo"):
53+
# Local .nemo file
5354
codec = AudioCodecModel.restore_from(codec_name, strict=False)
5455
elif codec_name.startswith("nvidia/"):
55-
# HuggingFace or NGC model name
56+
# Model on HuggingFace or NGC
5657
codec = AudioCodecModel.from_pretrained(codec_name)
5758
else:
5859
raise ValueError(
@@ -82,7 +83,7 @@ def encode_from_file(self, audio_path: str) -> Tensor:
8283

8384
def update(self, codes: Tensor, codes_len: Tensor, is_real: bool):
8485
if codes.numel() == 0:
85-
logging.warning(f"FCD: No valid codes to update, skipping update")
86+
logging.warning("FCD: No valid codes to update, skipping update")
8687
return
8788
if codes.shape[1] != self.codec.num_codebooks:
8889
logging.warning(
@@ -97,7 +98,7 @@ def update(self, codes: Tensor, codes_len: Tensor, is_real: bool):
9798
# combine into a single tensor. We treat each timestep independently so we can concatenate them all.
9899
codes_batch_all = torch.cat(codes_batch_all, dim=-1).permute(1, 0) # (B*T, C)
99100
if len(codes_batch_all) == 0:
100-
logging.warning(f"FCD: No valid codes to update, skipping update")
101+
logging.warning("FCD: No valid codes to update, skipping update")
101102
return
102103
# update
103104
super().update(codes_batch_all, real=is_real)
@@ -113,7 +114,7 @@ def update_from_audio_file(self, audio_path: str, is_real: bool):
113114

114115
def compute(self) -> Tensor:
115116
if not self.updated_since_last_reset:
116-
logging.warning(f"FCD: No updates since last reset, returning 0")
117+
logging.warning("FCD: No updates since last reset, returning 0")
117118
return torch.tensor(0.0, device=self.device)
118119
fcd = super().compute()
119120
min_allowed_fcd = -0.01 # a bit of tolerance for numerical issues

nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import nemo.collections.asr as nemo_asr
3333
from nemo.collections.asr.metrics.wer import word_error_rate_detail
3434
from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance
35-
from nemo.collections.tts.models import AudioCodecModel
3635
from nemo.utils import logging
3736

3837
# Optional import for UTMOSv2 (audio quality metric)

0 commit comments

Comments
 (0)