@@ -246,9 +246,9 @@ def evaluate(
246246 gt_audio_texts = []
247247 total_generated_audio_seconds = 0.0
248248 for ridx , record in enumerate (records ):
249- gt_audio_filepath = record [ 'audio_filepath' ]
249+ gt_audio_filepath = record . get ( 'audio_filepath' , None )
250250 context_audio_filepath = record .get ('context_audio_filepath' , None )
251- if audio_dir is not None :
251+ if audio_dir is not None and gt_audio_filepath is not None :
252252 gt_audio_filepath = os .path .join (audio_dir , gt_audio_filepath )
253253 if context_audio_filepath is not None :
254254 context_audio_filepath = os .path .join (audio_dir , context_audio_filepath )
@@ -265,17 +265,25 @@ def evaluate(
265265 with torch .inference_mode ():
266266 pred_text = asr_model .transcribe ([pred_audio_filepath ], batch_size = 1 , use_lhotse = False )[0 ].text
267267 pred_text = process_text (pred_text )
268- gt_audio_text = asr_model .transcribe ([gt_audio_filepath ], batch_size = 1 , use_lhotse = False )[0 ].text
269- gt_audio_text = process_text (gt_audio_text )
268+ if gt_audio_filepath is not None :
269+ gt_audio_text = asr_model .transcribe ([gt_audio_filepath ], batch_size = 1 , use_lhotse = False )[
270+ 0
271+ ].text
272+ gt_audio_text = process_text (gt_audio_text )
273+ else :
274+ gt_audio_text = None
270275 else :
271276 pred_text = transcribe_with_whisper (
272277 whisper_model , whisper_processor , pred_audio_filepath , language , device
273278 )
274279 pred_text = process_text (pred_text )
275- gt_audio_text = transcribe_with_whisper (
276- whisper_model , whisper_processor , gt_audio_filepath , language , device
277- )
278- gt_audio_text = process_text (gt_audio_text )
280+ if gt_audio_filepath is not None :
281+ gt_audio_text = transcribe_with_whisper (
282+ whisper_model , whisper_processor , gt_audio_filepath , language , device
283+ )
284+ gt_audio_text = process_text (gt_audio_text )
285+ else :
286+ gt_audio_text = None
279287 except Exception as e :
280288 logging .info ("Error during ASR: {}" .format (e ))
281289 pred_text = ""
@@ -318,19 +326,29 @@ def evaluate(
318326 sv_model_type = sv_model_type ,
319327 )
320328
321- # Ground truth vs. predicted
322- gt_speaker_embedding = extract_embedding_fn (audio_path = gt_audio_filepath )
323- pred_speaker_embedding = extract_embedding_fn (audio_path = pred_audio_filepath )
324- pred_gt_ssim = torch .nn .functional .cosine_similarity (
325- gt_speaker_embedding , pred_speaker_embedding , dim = 0
326- ).item ()
329+ # Initialize SSIMs with a default since the context or ground truth audio
330+ # may be unavailable.
331+ pred_context_ssim = float ('NaN' )
332+ gt_context_ssim = float ('NaN' )
333+ pred_context_ssim_alternate = float ('NaN' )
334+ gt_context_ssim_alternate = float ('NaN' )
335+ pred_gt_ssim = float ('NaN' )
336+ pred_gt_ssim_alternate = float ('NaN' )
337+
338+ if gt_audio_filepath is not None :
339+ # Ground truth vs. predicted
340+ gt_speaker_embedding = extract_embedding_fn (audio_path = gt_audio_filepath )
341+ pred_speaker_embedding = extract_embedding_fn (audio_path = pred_audio_filepath )
342+ pred_gt_ssim = torch .nn .functional .cosine_similarity (
343+ gt_speaker_embedding , pred_speaker_embedding , dim = 0
344+ ).item ()
327345
328- # Ground truth vs. predicted (alternate model)
329- gt_speaker_embedding_alternate = extract_embedding_fn_alternate (audio_path = gt_audio_filepath )
330- pred_speaker_embedding_alternate = extract_embedding_fn_alternate (audio_path = pred_audio_filepath )
331- pred_gt_ssim_alternate = torch .nn .functional .cosine_similarity (
332- gt_speaker_embedding_alternate , pred_speaker_embedding_alternate , dim = 0
333- ).item ()
346+ # Ground truth vs. predicted (alternate model)
347+ gt_speaker_embedding_alternate = extract_embedding_fn_alternate (audio_path = gt_audio_filepath )
348+ pred_speaker_embedding_alternate = extract_embedding_fn_alternate (audio_path = pred_audio_filepath )
349+ pred_gt_ssim_alternate = torch .nn .functional .cosine_similarity (
350+ gt_speaker_embedding_alternate , pred_speaker_embedding_alternate , dim = 0
351+ ).item ()
334352
335353 if context_audio_filepath is not None :
336354 context_speaker_embedding = extract_embedding_fn (audio_path = context_audio_filepath )
@@ -341,18 +359,20 @@ def evaluate(
341359 pred_speaker_embedding , context_speaker_embedding , dim = 0
342360 ).item ()
343361 # Ground truth vs. context
344- gt_context_ssim = torch .nn .functional .cosine_similarity (
345- gt_speaker_embedding , context_speaker_embedding , dim = 0
346- ).item ()
362+ if gt_audio_filepath is not None :
363+ gt_context_ssim = torch .nn .functional .cosine_similarity (
364+ gt_speaker_embedding , context_speaker_embedding , dim = 0
365+ ).item ()
347366
348367 # Predicted vs. context (alternate model)
349368 pred_context_ssim_alternate = torch .nn .functional .cosine_similarity (
350369 pred_speaker_embedding_alternate , context_speaker_embedding_alternate , dim = 0
351370 ).item ()
352371 # Ground truth vs. context (alternate model)
353- gt_context_ssim_alternate = torch .nn .functional .cosine_similarity (
354- gt_speaker_embedding_alternate , context_speaker_embedding_alternate , dim = 0
355- ).item ()
372+ if gt_audio_filepath is not None :
373+ gt_context_ssim_alternate = torch .nn .functional .cosine_similarity (
374+ gt_speaker_embedding_alternate , context_speaker_embedding_alternate , dim = 0
375+ ).item ()
356376 total_generated_audio_seconds += get_wav_file_duration (pred_audio_filepath )
357377
358378 filewise_metrics .append (
@@ -415,12 +435,20 @@ def evaluate(
415435 avg_metrics ['ssim_gt_context_avg_alternate' ] = sum (
416436 [m ['gt_context_ssim_alternate' ] for m in filewise_metrics ]
417437 ) / len (filewise_metrics )
418- avg_metrics ["cer_gt_audio_cumulative" ] = word_error_rate_detail (
419- hypotheses = gt_audio_texts , references = gt_texts , use_cer = True
420- )[0 ]
421- avg_metrics ["wer_gt_audio_cumulative" ] = word_error_rate_detail (
422- hypotheses = gt_audio_texts , references = gt_texts , use_cer = False
423- )[0 ]
438+ if not None in gt_audio_texts :
439+ avg_metrics ["cer_gt_audio_cumulative" ] = word_error_rate_detail (
440+ hypotheses = gt_audio_texts , references = gt_texts , use_cer = True
441+ )[0 ]
442+ avg_metrics ["wer_gt_audio_cumulative" ] = word_error_rate_detail (
443+ hypotheses = gt_audio_texts , references = gt_texts , use_cer = False
444+ )[0 ]
445+ else :
446+ avg_metrics ["cer_gt_audio_cumulative" ] = float ('NaN' )
447+ avg_metrics ["wer_gt_audio_cumulative" ] = float ('NaN' )
448+ logging .warning (
449+ "Ground truth audio files are missing. Setting cumulative CER and WER for ground truth audio to NaN."
450+ )
451+
424452 avg_metrics ["utmosv2_avg" ] = sum ([m ['utmosv2' ] for m in filewise_metrics ]) / len (filewise_metrics )
425453 avg_metrics ["total_gen_audio_seconds" ] = total_generated_audio_seconds
426454 pprint .pprint (avg_metrics )
0 commit comments