Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ hw.sol
inference_wavtokenizer
.zvuk/**
inference_wavtokenizer.py
__pycache__
27 changes: 19 additions & 8 deletions aulate/metrics_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def infer_text_to_audio(self, text: str, prompt: str, top_k: int = 50):

# Use default implementation
max_seq_length = 1024
formatted_text = f"Say '{text.upper()}' {prompt}"
formatted_text = f"Say '{text.lower()}' {prompt}"

text_tokenized = self.tokenizer(formatted_text, return_tensors="pt")
text_input_tokens = text_tokenized["input_ids"].to(self.device)
Expand Down Expand Up @@ -184,7 +184,7 @@ def calculate_metrics(
reference_audio: Union[np.ndarray, torch.Tensor],
generated_audio: Union[np.ndarray, torch.Tensor],
ref_sr: int = 16000,
gen_sr: int = 44100,
gen_sr: int = 24000,
) -> AudioMetricsResult:
"""Calculate PESQ, STOI, SI-SDR, SIM-O, and SIM-R metrics"""
# Resample reference and generated audio to a common sample rate for SI-SDR calculation
Expand Down Expand Up @@ -397,7 +397,9 @@ def evaluate_batch(
samples: list,
prompt: Optional[str] = None,
batch_metadata: Optional[Dict[str, Any]] = None,
save_audio: bool = False
save_audio: bool = False,
ref_sr: int = 16000,
gen_sr: int = 24000,
) -> pd.DataFrame:
"""
Evaluate metrics on a batch of samples.
Expand Down Expand Up @@ -427,7 +429,12 @@ def evaluate_batch(
generated_audio = generated_audio.audio_data.squeeze()
text = None

metrics = self.calculate_metrics(reference_audio, generated_audio)
metrics = self.calculate_metrics(
reference_audio,
generated_audio,
ref_sr=ref_sr,
gen_sr=gen_sr
)
print({'PESQ': metrics.pesq,
'STOI': metrics.stoi,
'SI-SDR': metrics.si_sdr,
Expand All @@ -453,8 +460,8 @@ def evaluate_batch(
audio_filename_ = f"orig_{idx}.wav"
audio_path = os.path.join(self.gen_audio_dir, audio_filename)
audio_path_ = os.path.join(self.gen_audio_dir, audio_filename_)
sf.write(audio_path, generated_audio.astype(np.float32), 44100)
sf.write(audio_path_, reference_audio.astype(np.float32), 16000)
sf.write(audio_path, generated_audio.detach().cpu().numpy().astype(np.float32), gen_sr)
sf.write(audio_path_, reference_audio.astype(np.float32), ref_sr)
result_dict['audio_path'] = audio_path

results.append(result_dict)
Expand Down Expand Up @@ -483,7 +490,11 @@ def evaluate_on_librispeech(
):
"""Evaluate metrics on LibriSpeech samples"""
print(f"Loading LibriSpeech dataset ({subset})...")
dataset = torchaudio.datasets.LIBRISPEECH("./data", url=subset, download=True)

dataset_path = "./data"
os.makedirs(dataset_path, exist_ok=True)

dataset = torchaudio.datasets.LIBRISPEECH(dataset_path, url=subset, download=True)

# Randomly sample entries
indices = range(600,801)#random.sample(range(len(dataset)), num_samples)
Expand Down Expand Up @@ -535,6 +546,6 @@ def evaluate_on_librispeech(

results_df = evaluator.evaluate_on_librispeech(
num_samples=50,
prompt="with a male speaker delivers a very monotone and high-pitched speech with a very fast speed in a setting with almost no noise, creating a clear and loud recording.",
prompt="with a voice of male speaker delivers a very monotone and high-pitched speech with a very fast speed in a setting with almost no noise, creating a clear and loud recording.",
)