mteb/models/model_implementations/encodec_model.py (49 changes: 41 additions & 8 deletions)
@@ -48,6 +48,7 @@ def get_audio_embeddings(
         import torchaudio

         all_embeddings = []
+        max_samples = int(self.max_audio_length_seconds * self.sampling_rate)

         for batch in tqdm(
             inputs,
@@ -76,22 +77,54 @@ def get_audio_embeddings(
                     array = resampler(array)

                 array = array.squeeze()
+
+                # Handle edge case where squeeze results in 0-dim tensor
+                if array.dim() == 0:
+                    array = array.unsqueeze(0)
+
+                # Warn and handle empty audio
+                if array.shape[-1] == 0:
+                    logger.warning(
+                        f"Empty audio sample at index {idx}, using 1 second of silence."
+                    )
+                    array = torch.zeros(self.sampling_rate)  # 1 second of silence
+
+                # Truncate if too long (processor doesn't allow both padding and truncation)
+                if array.shape[-1] > max_samples:
+                    array = array[:max_samples]
+
+                # Ensure minimum length for encoder (Encodec needs ~320 samples per frame)
+                # Use 1 second minimum to be safe
+                min_samples = self.sampling_rate
+                if array.shape[-1] < min_samples:
+                    padding = torch.zeros(min_samples - array.shape[-1])
+                    array = torch.cat([array, padding])

                 audio_arrays.append(array.numpy())

             with torch.no_grad():
                 # Process audio through EnCodec's processor
-                max_samples = int(self.max_audio_length_seconds * self.sampling_rate)

-                feature_inputs = self.processor(
+                # Use processor for batch padding (truncation/min-length done manually above)
+                processed = self.processor(
                     raw_audio=audio_arrays,
                     sampling_rate=self.sampling_rate,
-                    padding=True,
                     return_tensors="pt",
+                    padding="max_length",
+                    max_length=max_samples,
-                ).to(self.device)
+                )
+                input_values = processed["input_values"].to(self.device)

+                # Add channel dimension if needed (B, T) -> (B, 1, T)
+                if input_values.dim() == 2:
+                    input_values = input_values.unsqueeze(1)
+
                 # Get the latent representations directly from the encoder
-                latent = self.model.encoder(feature_inputs.input_values)
+                latent = self.model.encoder(input_values)

+                # Validate latent has time frames
+                if latent.shape[2] == 0:
+                    raise ValueError(
+                        f"Encodec encoder produced 0 time frames. "
+                        f"Input shape: {input_values.shape}, latent shape: {latent.shape}"
+                    )

                 # Apply mean pooling over the time dimension to get fixed-size embeddings
                 embeddings = torch.mean(latent, dim=2)  # Average over time dimension
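For readers who want to sanity-check the new per-sample length handling (truncate to max_samples, replace empty audio with silence, pad short clips up to one second), here is a minimal stand-alone sketch of the same logic. It assumes a 24 kHz Encodec sampling rate and a 30-second cap; the constants and the prepare_array helper are illustrative only and are not part of this PR.

import torch

SAMPLING_RATE = 24_000             # assumed Encodec sampling rate
MAX_AUDIO_LENGTH_SECONDS = 30      # assumed cap; the model class defines its own value
MAX_SAMPLES = MAX_AUDIO_LENGTH_SECONDS * SAMPLING_RATE


def prepare_array(array: torch.Tensor) -> torch.Tensor:
    """Illustrative re-implementation of the per-sample handling in the diff above."""
    array = array.squeeze()
    if array.dim() == 0:                 # squeeze produced a 0-dim tensor
        array = array.unsqueeze(0)
    if array.shape[-1] == 0:             # empty audio -> 1 second of silence
        array = torch.zeros(SAMPLING_RATE)
    if array.shape[-1] > MAX_SAMPLES:    # truncate overly long clips
        array = array[:MAX_SAMPLES]
    if array.shape[-1] < SAMPLING_RATE:  # pad short clips up to 1 second
        padding = torch.zeros(SAMPLING_RATE - array.shape[-1])
        array = torch.cat([array, padding])
    return array


print(prepare_array(torch.zeros(0)).shape)                   # torch.Size([24000])
print(prepare_array(torch.randn(1, 5_000)).shape)            # torch.Size([24000])
print(prepare_array(torch.randn(40 * SAMPLING_RATE)).shape)  # torch.Size([720000])

The one-second minimum is a conservative choice: per the comment in the diff, Encodec's encoder needs roughly 320 samples per output frame, so padding to a full second guarantees at least one time frame for the mean-pooling step.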