Commit ace180b

[TTS] Allow inference without reference audio (#15213)
In some cases we need to run inference on manifests that do not include context audio and/or ground truth audio:

* Text context manifests may not have a context audio included (and it wouldn't be very relevant anyway).
* When generating arbitrary text, the ground truth audio cannot be assumed to be available.

Note that even without the context and ground truth audio, there are still useful metrics that can be calculated, such as WER, UTMOS, and inference speed. This change adapts the inference scripts and the data loader they use to allow the absence of context and/or ground truth audio. When either is missing, the metrics that depend on it are set to NaN (see the manifest sketch below).
1 parent 194f1d5 commit ace180b
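For concreteness, here is a minimal sketch of the manifest shapes this change is meant to accept. Only audio_filepath and context_audio_filepath appear in the diffs below; every other field name and path is an illustrative assumption, not taken from this commit:

import json

# Hypothetical evaluation manifest entries ('text' and the file paths are
# illustrative; only 'audio_filepath' and 'context_audio_filepath' come from
# the diffs below).
entries = [
    # Ground truth audio and context audio both present.
    {"text": "hello world", "audio_filepath": "wavs/utt1.wav",
     "context_audio_filepath": "wavs/spk1_context.wav"},
    # Text context only: no context audio.
    {"text": "good morning", "audio_filepath": "wavs/utt2.wav"},
    # Arbitrary generated text: no ground truth audio either.
    {"text": "this sentence has no recorded reference"},
]

with open("eval_manifest.json", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")

Before this change, entries shaped like the last two failed with KeyErrors: collate_fn indexed example["audio_filepath"] unconditionally, and the evaluation script indexed record['audio_filepath']. The diffs below make both lookups conditional.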

File tree

2 files changed: +68 −39 lines


nemo/collections/tts/data/text_to_speech_dataset.py

Lines changed: 8 additions & 7 deletions
@@ -459,7 +459,7 @@ def __getitem__(self, index):
         if 'audio_filepath' in data.manifest_entry:
             # If audio_filepath is available, then use the actual audio file path.
             example['audio_filepath'] = data.manifest_entry['audio_filepath']
-        else:
+        elif 'audio_filepath' in data.manifest_entry:
             # Only load audio if codes are not available
             audio_array, _, audio_filepath_rel = load_audio(
                 manifest_entry=data.manifest_entry,
@@ -661,13 +661,15 @@ def collate_fn(self, batch: List[dict]):
         speaker_indices_list = []
         for example in batch:
             dataset_name_list.append(example["dataset_name"])
-            audio_filepath_list.append(example["audio_filepath"])
             raw_text_list.append(example["raw_text"])
             language_list.append(example["language"])
 
             token_list.append(example["tokens"])
             token_len_list.append(example["text_len"])
 
+            if 'audio_filepath' in example:
+                audio_filepath_list.append(example["audio_filepath"])
+
             if 'audio' in example:
                 audio_list.append(example["audio"])
                 audio_len_list.append(example["audio_len"])
@@ -774,14 +776,13 @@ def collate_fn(self, batch: List[dict]):
         if len(speaker_indices_list) > 0:
             batch_dict['speaker_indices'] = torch.tensor(speaker_indices_list, dtype=torch.int64)
 
-        # Assert only ONE of context_audio or context_audio_codes in the batch
-        assert ('audio' in batch_dict) ^ ('audio_codes' in batch_dict)
+        # Assert no more than one of audio or audio_codes in the batch
+        if 'audio' in batch_dict:
+            assert 'audio_codes' not in batch_dict
 
-        # Assert only ONE of context_audio or context_audio_codes in the batch
+        # Assert no more than one of context_audio or context_audio_codes in the batch
         if 'context_audio' in batch_dict:
             assert 'context_audio_codes' not in batch_dict
-        if 'context_audio_codes' in batch_dict:
-            assert 'context_audio' not in batch_dict
 
         return batch_dict
 

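The collate change follows one pattern: gather an optional per-example field only when it is present, and replace the XOR assert (which rejected the valid "neither present" case) with a one-sided check. A self-contained sketch of that pattern, using illustrative names rather than the actual collate_fn:

from typing import List

import torch


def collate_optional(batch: List[dict]) -> dict:
    batch_dict = {}
    # Gather the optional 'audio' field only from examples that carry it.
    audio_list = [example["audio"] for example in batch if "audio" in example]
    if audio_list:
        batch_dict["audio"] = torch.nn.utils.rnn.pad_sequence(audio_list, batch_first=True)
        batch_dict["audio_lens"] = torch.tensor([a.shape[0] for a in audio_list])
    # "No more than one": unlike `a ^ b`, this passes when neither key is present.
    if "audio" in batch_dict:
        assert "audio_codes" not in batch_dict
    return batch_dict


# A mixed batch where only one example carries audio still collates cleanly.
batch = [{"audio": torch.randn(16000)}, {"text": "no audio here"}]
print(collate_optional(batch)["audio_lens"])  # tensor([16000])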
nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py

Lines changed: 60 additions & 32 deletions
@@ -246,9 +246,9 @@ def evaluate(
     gt_audio_texts = []
     total_generated_audio_seconds = 0.0
     for ridx, record in enumerate(records):
-        gt_audio_filepath = record['audio_filepath']
+        gt_audio_filepath = record.get('audio_filepath', None)
         context_audio_filepath = record.get('context_audio_filepath', None)
-        if audio_dir is not None:
+        if audio_dir is not None and gt_audio_filepath is not None:
             gt_audio_filepath = os.path.join(audio_dir, gt_audio_filepath)
         if context_audio_filepath is not None:
             context_audio_filepath = os.path.join(audio_dir, context_audio_filepath)
@@ -265,17 +265,25 @@ def evaluate(
             with torch.inference_mode():
                 pred_text = asr_model.transcribe([pred_audio_filepath], batch_size=1, use_lhotse=False)[0].text
                 pred_text = process_text(pred_text)
-                gt_audio_text = asr_model.transcribe([gt_audio_filepath], batch_size=1, use_lhotse=False)[0].text
-                gt_audio_text = process_text(gt_audio_text)
+                if gt_audio_filepath is not None:
+                    gt_audio_text = asr_model.transcribe([gt_audio_filepath], batch_size=1, use_lhotse=False)[
+                        0
+                    ].text
+                    gt_audio_text = process_text(gt_audio_text)
+                else:
+                    gt_audio_text = None
         else:
             pred_text = transcribe_with_whisper(
                 whisper_model, whisper_processor, pred_audio_filepath, language, device
             )
             pred_text = process_text(pred_text)
-            gt_audio_text = transcribe_with_whisper(
-                whisper_model, whisper_processor, gt_audio_filepath, language, device
-            )
-            gt_audio_text = process_text(gt_audio_text)
+            if gt_audio_filepath is not None:
+                gt_audio_text = transcribe_with_whisper(
+                    whisper_model, whisper_processor, gt_audio_filepath, language, device
+                )
+                gt_audio_text = process_text(gt_audio_text)
+            else:
+                gt_audio_text = None
     except Exception as e:
         logging.info("Error during ASR: {}".format(e))
         pred_text = ""
@@ -318,19 +326,29 @@ def evaluate(
             sv_model_type=sv_model_type,
         )
 
-        # Ground truth vs. predicted
-        gt_speaker_embedding = extract_embedding_fn(audio_path=gt_audio_filepath)
-        pred_speaker_embedding = extract_embedding_fn(audio_path=pred_audio_filepath)
-        pred_gt_ssim = torch.nn.functional.cosine_similarity(
-            gt_speaker_embedding, pred_speaker_embedding, dim=0
-        ).item()
+        # Initialize SSIMs with a default since the context or ground truth audio
+        # may be unavailable.
+        pred_context_ssim = float('NaN')
+        gt_context_ssim = float('NaN')
+        pred_context_ssim_alternate = float('NaN')
+        gt_context_ssim_alternate = float('NaN')
+        pred_gt_ssim = float('NaN')
+        pred_gt_ssim_alternate = float('NaN')
+
+        if gt_audio_filepath is not None:
+            # Ground truth vs. predicted
+            gt_speaker_embedding = extract_embedding_fn(audio_path=gt_audio_filepath)
+            pred_speaker_embedding = extract_embedding_fn(audio_path=pred_audio_filepath)
+            pred_gt_ssim = torch.nn.functional.cosine_similarity(
+                gt_speaker_embedding, pred_speaker_embedding, dim=0
+            ).item()
 
-        # Ground truth vs. predicted (alternate model)
-        gt_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=gt_audio_filepath)
-        pred_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=pred_audio_filepath)
-        pred_gt_ssim_alternate = torch.nn.functional.cosine_similarity(
-            gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0
-        ).item()
+            # Ground truth vs. predicted (alternate model)
+            gt_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=gt_audio_filepath)
+            pred_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=pred_audio_filepath)
+            pred_gt_ssim_alternate = torch.nn.functional.cosine_similarity(
+                gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0
+            ).item()
 
         if context_audio_filepath is not None:
             context_speaker_embedding = extract_embedding_fn(audio_path=context_audio_filepath)
@@ -341,18 +359,20 @@ def evaluate(
                 pred_speaker_embedding, context_speaker_embedding, dim=0
             ).item()
             # Ground truth vs. context
-            gt_context_ssim = torch.nn.functional.cosine_similarity(
-                gt_speaker_embedding, context_speaker_embedding, dim=0
-            ).item()
+            if gt_audio_filepath is not None:
+                gt_context_ssim = torch.nn.functional.cosine_similarity(
+                    gt_speaker_embedding, context_speaker_embedding, dim=0
+                ).item()
 
             # Predicted vs. context (alternate model)
             pred_context_ssim_alternate = torch.nn.functional.cosine_similarity(
                 pred_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0
            ).item()
             # Ground truth vs. context (alternate model)
-            gt_context_ssim_alternate = torch.nn.functional.cosine_similarity(
-                gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0
-            ).item()
+            if gt_audio_filepath is not None:
+                gt_context_ssim_alternate = torch.nn.functional.cosine_similarity(
+                    gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0
+                ).item()
         total_generated_audio_seconds += get_wav_file_duration(pred_audio_filepath)
 
         filewise_metrics.append(
@@ -415,12 +435,20 @@ def evaluate(
         avg_metrics['ssim_gt_context_avg_alternate'] = sum(
             [m['gt_context_ssim_alternate'] for m in filewise_metrics]
         ) / len(filewise_metrics)
-    avg_metrics["cer_gt_audio_cumulative"] = word_error_rate_detail(
-        hypotheses=gt_audio_texts, references=gt_texts, use_cer=True
-    )[0]
-    avg_metrics["wer_gt_audio_cumulative"] = word_error_rate_detail(
-        hypotheses=gt_audio_texts, references=gt_texts, use_cer=False
-    )[0]
+    if not None in gt_audio_texts:
+        avg_metrics["cer_gt_audio_cumulative"] = word_error_rate_detail(
+            hypotheses=gt_audio_texts, references=gt_texts, use_cer=True
+        )[0]
+        avg_metrics["wer_gt_audio_cumulative"] = word_error_rate_detail(
+            hypotheses=gt_audio_texts, references=gt_texts, use_cer=False
+        )[0]
+    else:
+        avg_metrics["cer_gt_audio_cumulative"] = float('NaN')
+        avg_metrics["wer_gt_audio_cumulative"] = float('NaN')
+        logging.warning(
+            "Ground truth audio files are missing. Setting cumulative CER and WER for ground truth audio to NaN."
+        )
+
     avg_metrics["utmosv2_avg"] = sum([m['utmosv2'] for m in filewise_metrics]) / len(filewise_metrics)
     avg_metrics["total_gen_audio_seconds"] = total_generated_audio_seconds
     pprint.pprint(avg_metrics)

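Defaulting the per-file SSIMs to NaN keeps every record's metrics dict the same shape, and NaN (unlike 0.0) cannot be mistaken for a genuinely computed score. One consequence worth noting: NaN propagates through sum(), so any average over filewise metrics that includes a missing-input record also comes out NaN. A small illustrative sketch (not code from this commit):

import math

filewise_metrics = [
    {"pred_gt_ssim": 0.91},          # ground truth audio was available
    {"pred_gt_ssim": float('NaN')},  # ground truth audio was missing
]

ssims = [m["pred_gt_ssim"] for m in filewise_metrics]
print(sum(ssims) / len(ssims))  # nan -- the NaN propagates through the average

# A NaN-aware partial average, if skipping missing records is preferred:
valid = [s for s in ssims if not math.isnan(s)]
print(sum(valid) / len(valid) if valid else float('NaN'))  # 0.91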