
Commit 8477591

Authored by AdnanElAssadi56, Samoed, and isaac-chung
[MAEB] Add safety checks for encodec model (#3858)
* Add safety checks for encodec model
* Add Encodec change requests
* Handle min_length and validate time frames
* Run lint
* Fix import

Co-authored-by: Roman Solomatin <[email protected]>
Co-authored-by: Isaac Chung <[email protected]>
1 parent: fcb5e77

File tree

1 file changed

mteb/models/model_implementations/encodec_model.py

Lines changed: 41 additions & 8 deletions
@@ -48,6 +48,7 @@ def get_audio_embeddings(
         import torchaudio

         all_embeddings = []
+        max_samples = int(self.max_audio_length_seconds * self.sampling_rate)

         for batch in tqdm(
             inputs,
@@ -76,22 +77,54 @@ def get_audio_embeddings(
                     array = resampler(array)

                 array = array.squeeze()
+
+                # Handle edge case where squeeze results in 0-dim tensor
+                if array.dim() == 0:
+                    array = array.unsqueeze(0)
+
+                # Warn and handle empty audio
+                if array.shape[-1] == 0:
+                    logger.warning(
+                        f"Empty audio sample at index {idx}, using 1 second of silence."
+                    )
+                    array = torch.zeros(self.sampling_rate)  # 1 second of silence
+
+                # Truncate if too long (processor doesn't allow both padding and truncation)
+                if array.shape[-1] > max_samples:
+                    array = array[:max_samples]
+
+                # Ensure minimum length for encoder (Encodec needs ~320 samples per frame)
+                # Use 1 second minimum to be safe
+                min_samples = self.sampling_rate
+                if array.shape[-1] < min_samples:
+                    padding = torch.zeros(min_samples - array.shape[-1])
+                    array = torch.cat([array, padding])
+
                 audio_arrays.append(array.numpy())

             with torch.no_grad():
-                # Process audio through EnCodec's processor
-                max_samples = int(self.max_audio_length_seconds * self.sampling_rate)
-
-                feature_inputs = self.processor(
+                # Use processor for batch padding (truncation/min-length done manually above)
+                processed = self.processor(
                     raw_audio=audio_arrays,
                     sampling_rate=self.sampling_rate,
+                    padding=True,
                     return_tensors="pt",
-                    padding="max_length",
-                    max_length=max_samples,
-                ).to(self.device)
+                )
+                input_values = processed["input_values"].to(self.device)
+
+                # Add channel dimension if needed (B, T) -> (B, 1, T)
+                if input_values.dim() == 2:
+                    input_values = input_values.unsqueeze(1)

                 # Get the latent representations directly from the encoder
-                latent = self.model.encoder(feature_inputs.input_values)
+                latent = self.model.encoder(input_values)
+
+                # Validate latent has time frames
+                if latent.shape[2] == 0:
+                    raise ValueError(
+                        f"Encodec encoder produced 0 time frames. "
+                        f"Input shape: {input_values.shape}, latent shape: {latent.shape}"
+                    )

                 # Apply mean pooling over the time dimension to get fixed-size embeddings
                 embeddings = torch.mean(latent, dim=2)  # Average over time dimension
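To see the safeguards end to end outside of mteb, here is a minimal, self-contained sketch against the Hugging Face transformers EnCodec API. The `safe_waveform` helper name, the 30-second cap, and the `facebook/encodec_24khz` checkpoint are illustrative assumptions, not values taken from this file.

```python
# Sketch of the commit's safety checks, assuming the facebook/encodec_24khz
# checkpoint and a 30-second maximum; both are assumptions for illustration.
import torch
from transformers import AutoProcessor, EncodecModel

model = EncodecModel.from_pretrained("facebook/encodec_24khz").eval()
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")

sampling_rate = processor.sampling_rate           # 24_000 for this checkpoint
max_samples = int(30 * sampling_rate)             # assumed 30 s cap
min_samples = sampling_rate                       # 1 s floor, as in the commit


def safe_waveform(array: torch.Tensor) -> torch.Tensor:
    """Apply the commit's per-sample length fixes to one mono waveform."""
    array = array.squeeze()
    if array.dim() == 0:                          # squeeze() collapsed a 1-sample clip
        array = array.unsqueeze(0)
    if array.shape[-1] == 0:                      # empty clip -> 1 s of silence
        array = torch.zeros(sampling_rate)
    if array.shape[-1] > max_samples:             # truncate overly long clips
        array = array[:max_samples]
    if array.shape[-1] < min_samples:             # pad short clips up to the 1 s floor
        array = torch.cat([array, torch.zeros(min_samples - array.shape[-1])])
    return array


waveforms = [torch.randn(12_000), torch.empty(0)]  # a 0.5 s clip and an empty one
arrays = [safe_waveform(w).numpy() for w in waveforms]

with torch.no_grad():
    processed = processor(
        raw_audio=arrays, sampling_rate=sampling_rate,
        padding=True, return_tensors="pt",
    )
    input_values = processed["input_values"]
    if input_values.dim() == 2:                   # defensive: (B, T) -> (B, 1, T)
        input_values = input_values.unsqueeze(1)
    latent = model.encoder(input_values)          # (B, 128, frames)
    assert latent.shape[2] > 0, "encoder produced 0 time frames"
    embeddings = latent.mean(dim=2)               # (B, 128) fixed-size embeddings

print(embeddings.shape)  # torch.Size([2, 128])
```

The split of responsibilities mirrors the commit: truncation and the minimum-length floor are applied per sample before the processor (whose own comment notes it does not allow padding and truncation together), so the processor only pads each batch to its longest clip.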
