@@ -101,6 +101,7 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met
101101 metrics .get ('wer_gt_audio_cumulative' , '' ),
102102 metrics .get ('utmosv2_avg' , '' ),
103103 metrics .get ('total_gen_audio_seconds' , '' ),
104+ metrics .get ('frechet_codec_distance' , '' ),
104105 ]
105106 with open (csv_path , "a" ) as f :
106107 f .write ("," .join (str (v ) for v in values ) + "\n " )
@@ -181,7 +182,7 @@ def run_inference_and_evaluation(
181182 "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,"
182183 "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,"
183184 "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,"
184- "utmosv2_avg,total_gen_audio_seconds"
185+ "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance "
185186 )
186187
187188 for dataset in datasets :
@@ -222,7 +223,7 @@ def run_inference_and_evaluation(
222223 f"Dataset length mismatch: { len (test_dataset )} vs { len (manifest_records )} manifest records"
223224 )
224225
225- rtf_metrics_list , _ = runner .run_inference_on_dataset (
226+ rtf_metrics_list , _ , codec_file_paths = runner .run_inference_on_dataset (
226227 dataset = test_dataset ,
227228 output_dir = repeat_audio_dir ,
228229 manifest_records = manifest_records ,
@@ -246,6 +247,7 @@ def run_inference_and_evaluation(
246247 asr_model_name = eval_config .asr_model_name ,
247248 language = language ,
248249 with_utmosv2 = eval_config .with_utmosv2 ,
250+ codec_model_path = eval_config .codec_model_path ,
249251 )
250252
251253 metrics , filewise_metrics = evaluate_generated_audio_dir (
@@ -272,6 +274,10 @@ def run_inference_and_evaluation(
272274 violin_path = Path (eval_dir ) / f"{ dataset } _violin_{ repeat_idx } .png"
273275 create_violin_plot (filewise_metrics , violin_plot_metrics , violin_path )
274276
277+ # Delete temporary predicted codes files
278+ for codec_file_path in codec_file_paths :
279+ os .remove (codec_file_path )
280+
275281 if skip_evaluation or not metrics_all_repeats :
276282 continue
277283
@@ -463,6 +469,7 @@ def create_argument_parser() -> argparse.ArgumentParser:
463469 nargs = '*' ,
464470 default = ['cer' , 'pred_context_ssim' , 'utmosv2' ],
465471 )
472+ eval_group .add_argument ('--disable_fcd' , action = 'store_true' , help = "Disable Frechet Codec Distance computation" )
466473
467474 # Quality targets (for CI/CD)
468475 target_group = parser .add_argument_group ('Quality Targets' )
@@ -520,6 +527,7 @@ def main():
520527 sv_model = args .sv_model ,
521528 asr_model_name = args .asr_model_name ,
522529 with_utmosv2 = not args .disable_utmosv2 ,
530+ codec_model_path = args .codecmodel_path if not args .disable_fcd else None ,
523531 )
524532
525533 cer , ssim = None , None
0 commit comments