diff --git a/mteb/models/model_implementations/encodec_model.py b/mteb/models/model_implementations/encodec_model.py
index 0d922a1d19..43002fd6b2 100644
--- a/mteb/models/model_implementations/encodec_model.py
+++ b/mteb/models/model_implementations/encodec_model.py
@@ -48,6 +48,7 @@ def get_audio_embeddings(
         import torchaudio
 
         all_embeddings = []
+        max_samples = int(self.max_audio_length_seconds * self.sampling_rate)
 
         for batch in tqdm(
             inputs,
@@ -76,22 +77,54 @@
                     array = resampler(array)
 
                 array = array.squeeze()
+
+                # Handle edge case where squeeze results in 0-dim tensor
+                if array.dim() == 0:
+                    array = array.unsqueeze(0)
+
+                # Warn and handle empty audio
+                if array.shape[-1] == 0:
+                    logger.warning(
+                        f"Empty audio sample at index {idx}, using 1 second of silence."
+                    )
+                    array = torch.zeros(self.sampling_rate)  # 1 second of silence
+
+                # Truncate if too long (processor doesn't allow both padding and truncation)
+                if array.shape[-1] > max_samples:
+                    array = array[:max_samples]
+
+                # Ensure minimum length for encoder (Encodec needs ~320 samples per frame)
+                # Use 1 second minimum to be safe
+                min_samples = self.sampling_rate
+                if array.shape[-1] < min_samples:
+                    padding = torch.zeros(min_samples - array.shape[-1])
+                    array = torch.cat([array, padding])
+
                 audio_arrays.append(array.numpy())
 
             with torch.no_grad():
-                # Process audio through EnCodec's processor
-                max_samples = int(self.max_audio_length_seconds * self.sampling_rate)
-
-                feature_inputs = self.processor(
+                # Use processor for batch padding (truncation/min-length done manually above)
+                processed = self.processor(
                     raw_audio=audio_arrays,
                     sampling_rate=self.sampling_rate,
+                    padding=True,
                     return_tensors="pt",
-                    padding="max_length",
-                    max_length=max_samples,
-                ).to(self.device)
+                )
+                input_values = processed["input_values"].to(self.device)
+
+                # Add channel dimension if needed (B, T) -> (B, 1, T)
+                if input_values.dim() == 2:
+                    input_values = input_values.unsqueeze(1)
 
                 # Get the latent representations directly from the encoder
-                latent = self.model.encoder(feature_inputs.input_values)
+                latent = self.model.encoder(input_values)
+
+                # Validate latent has time frames
+                if latent.shape[2] == 0:
+                    raise ValueError(
+                        f"Encodec encoder produced 0 time frames. "
+                        f"Input shape: {input_values.shape}, latent shape: {latent.shape}"
+                    )
 
                 # Apply mean pooling over the time dimension to get fixed-size embeddings
                 embeddings = torch.mean(latent, dim=2)  # Average over time dimension
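
For reference, a minimal standalone sketch of the per-sample length handling this patch introduces. The helper name _prepare_waveform and the 24 kHz / 30 s example values are illustrative assumptions, not part of the patch:

import torch

def _prepare_waveform(array, sampling_rate, max_audio_length_seconds):
    # Truncate to the configured maximum and pad to a 1-second minimum,
    # mirroring the per-sample handling in the diff above.
    max_samples = int(max_audio_length_seconds * sampling_rate)
    min_samples = sampling_rate  # 1 second minimum

    array = array.squeeze()
    if array.dim() == 0:            # squeeze() collapsed a 1-sample tensor to 0-dim
        array = array.unsqueeze(0)
    if array.shape[-1] == 0:        # empty clip -> 1 second of silence
        array = torch.zeros(sampling_rate)
    if array.shape[-1] > max_samples:
        array = array[:max_samples]
    if array.shape[-1] < min_samples:
        array = torch.cat([array, torch.zeros(min_samples - array.shape[-1])])
    return array

# A 0.1 s clip at 24 kHz is padded up to the 1-second minimum (24 000 samples),
# while anything longer than max_audio_length_seconds would be truncated.
print(_prepare_waveform(torch.randn(2400), 24000, 30.0).shape)  # torch.Size([24000])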