diff --git a/.gitignore b/.gitignore index 56f17c7..e315e10 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ hw.sol inference_wavtokenizer .zvuk/** inference_wavtokenizer.py +__pycache__ diff --git a/aulate/metrics_evaluation.py b/aulate/metrics_evaluation.py index a916594..0489e43 100644 --- a/aulate/metrics_evaluation.py +++ b/aulate/metrics_evaluation.py @@ -152,7 +152,7 @@ def infer_text_to_audio(self, text: str, prompt: str, top_k: int = 50): # Use default implementation max_seq_length = 1024 - formatted_text = f"Say '{text.upper()}' {prompt}" + formatted_text = f"Say '{text.lower()}' {prompt}" text_tokenized = self.tokenizer(formatted_text, return_tensors="pt") text_input_tokens = text_tokenized["input_ids"].to(self.device) @@ -184,7 +184,7 @@ def calculate_metrics( reference_audio: Union[np.ndarray, torch.Tensor], generated_audio: Union[np.ndarray, torch.Tensor], ref_sr: int = 16000, - gen_sr: int = 44100, + gen_sr: int = 24000, ) -> AudioMetricsResult: """Calculate PESQ, STOI, SI-SDR, SIM-O, and SIM-R metrics""" # Resample reference and generated audio to a common sample rate for SI-SDR calculation @@ -397,7 +397,9 @@ def evaluate_batch( samples: list, prompt: Optional[str] = None, batch_metadata: Optional[Dict[str, Any]] = None, - save_audio: bool = False + save_audio: bool = False, + ref_sr: int = 16000, + gen_sr: int = 24000, ) -> pd.DataFrame: """ Evaluate metrics on a batch of samples. @@ -427,7 +429,12 @@ def evaluate_batch( generated_audio = generated_audio.audio_data.squeeze() text = None - metrics = self.calculate_metrics(reference_audio, generated_audio) + metrics = self.calculate_metrics( + reference_audio, + generated_audio, + ref_sr=ref_sr, + gen_sr=gen_sr + ) print({'PESQ': metrics.pesq, 'STOI': metrics.stoi, 'SI-SDR': metrics.si_sdr, @@ -453,8 +460,8 @@ def evaluate_batch( audio_filename_ = f"orig_{idx}.wav" audio_path = os.path.join(self.gen_audio_dir, audio_filename) audio_path_ = os.path.join(self.gen_audio_dir, audio_filename_) - sf.write(audio_path, generated_audio.astype(np.float32), 44100) - sf.write(audio_path_, reference_audio.astype(np.float32), 16000) + sf.write(audio_path, generated_audio.detach().cpu().numpy().astype(np.float32), gen_sr) + sf.write(audio_path_, reference_audio.astype(np.float32), ref_sr) result_dict['audio_path'] = audio_path results.append(result_dict) @@ -483,7 +490,11 @@ def evaluate_on_librispeech( ): """Evaluate metrics on LibriSpeech samples""" print(f"Loading LibriSpeech dataset ({subset})...") - dataset = torchaudio.datasets.LIBRISPEECH("./data", url=subset, download=True) + + dataset_path = "./data" + os.makedirs(dataset_path, exist_ok=True) + + dataset = torchaudio.datasets.LIBRISPEECH(dataset_path, url=subset, download=True) # Randomly sample entries indices = range(600,801)#random.sample(range(len(dataset)), num_samples) @@ -535,6 +546,6 @@ def evaluate_on_librispeech( results_df = evaluator.evaluate_on_librispeech( num_samples=50, - prompt="with a male speaker delivers a very monotone and high-pitched speech with a very fast speed in a setting with almost no noise, creating a clear and loud recording.", + prompt="with a voice of male speaker delivers a very monotone and high-pitched speech with a very fast speed in a setting with almost no noise, creating a clear and loud recording.", )