Commit 8aa12ae

Merge pull request #551 from beshkenadze/fix/lid-ecapa-parity-upstream
fix(lid): align ECAPA inference with SpeechBrain
2 parents 5fac1de + 4ec8854

File tree: 2 files changed (+39 / -2 lines)
mlx_audio/lid/models/ecapa_tdnn/ecapa_tdnn.py

Lines changed: 9 additions & 2 deletions

@@ -36,7 +36,7 @@ def __init__(self, in_dim: int, out_dim: int):
         self.norm = nn.BatchNorm(out_dim)
 
     def __call__(self, x: mx.array) -> mx.array:
-        return nn.relu(self.norm(self.linear(x)))
+        return self.norm(nn.leaky_relu(self.linear(x), negative_slope=0.01))
 
 
 class DNN(nn.Module):
@@ -66,6 +66,7 @@ def __init__(self, config: ModelConfig):
 
     def __call__(self, x: mx.array) -> mx.array:
         out = mx.squeeze(x, axis=1)
+        out = nn.leaky_relu(out, negative_slope=0.01)
         out = self.norm(out)
         out = self.DNN(out)
         out = self.out(out)
@@ -121,10 +122,16 @@ def __call__(self, mel_features: mx.array) -> mx.array:
         Returns:
             Log-probabilities ``[batch, num_classes]``.
         """
-        embeddings = self.embedding_model(mel_features)
+        normalized_mel_features = self.sentence_mean_normalize(mel_features)
+        embeddings = self.embedding_model(normalized_mel_features)
         embeddings = mx.expand_dims(embeddings, axis=1)
         return self.classifier(embeddings)
 
+    @staticmethod
+    def sentence_mean_normalize(mel_features: mx.array) -> mx.array:
+        """Mirror SpeechBrain's sentence-level mean-only InputNormalization."""
+        return mel_features - mx.mean(mel_features, axis=1, keepdims=True)
+
     def predict(
         self,
         audio: mx.array,
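
For context, the new sentence_mean_normalize helper subtracts the per-utterance mean of each mel bin over the time axis, mirroring SpeechBrain's sentence-level, mean-only InputNormalization. A minimal, self-contained sketch of the same computation (the function body is taken from the diff above; the toy tensor reuses the values from the new unit test):

import mlx.core as mx

def sentence_mean_normalize(mel_features: mx.array) -> mx.array:
    # Mean over the time axis (axis=1); keepdims=True lets broadcasting
    # subtract the same per-bin mean from every frame of the utterance.
    return mel_features - mx.mean(mel_features, axis=1, keepdims=True)

mel = mx.array([[[1.0, 3.0], [3.0, 5.0], [5.0, 7.0]]])  # [batch=1, time=3, mel bins=2]
normalized = sentence_mean_normalize(mel)
mx.eval(normalized)
print(mx.mean(normalized, axis=1))  # each mel bin now has zero mean over time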

mlx_audio/lid/tests/test_lid.py

Lines changed: 30 additions & 0 deletions

@@ -2,6 +2,7 @@
 from unittest.mock import MagicMock, patch
 
 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 
 
@@ -329,6 +330,35 @@ def test_forward_log_probs_sum(self):
         total = float(mx.sum(probs[0]).item())
         self.assertAlmostEqual(total, 1.0, places=3)
 
+    def test_sentence_mean_normalize_centers_each_mel_bin(self):
+        mel = mx.array([[[1.0, 3.0], [3.0, 5.0], [5.0, 7.0]]])
+        normalized = self.Model.sentence_mean_normalize(mel)
+        mean_per_bin = mx.mean(normalized, axis=1)
+        mx.eval(mean_per_bin)
+
+        self.assertAlmostEqual(float(mean_per_bin[0, 0].item()), 0.0, places=5)
+        self.assertAlmostEqual(float(mean_per_bin[0, 1].item()), 0.0, places=5)
+
+    def test_classifier_matches_speechbrain_order(self):
+        model = self.Model(self.config)
+        model.eval()
+        classifier = model.classifier
+        x = mx.random.normal((1, 1, self.config.embedding_dim))
+
+        expected = mx.squeeze(x, axis=1)
+        expected = nn.leaky_relu(expected, negative_slope=0.01)
+        expected = classifier.norm(expected)
+        expected = classifier.DNN.block_0.linear(expected)
+        expected = nn.leaky_relu(expected, negative_slope=0.01)
+        expected = classifier.DNN.block_0.norm(expected)
+        expected = classifier.out(expected)
+        expected = mx.log(mx.softmax(expected, axis=-1) + 1e-10)
+
+        actual = classifier(x)
+        mx.eval(expected, actual)
+
+        self.assertTrue(mx.allclose(actual, expected, atol=1e-5, rtol=1e-5).item())
+
     def test_predict_returns_sorted(self):
         model = self.Model(self.config)
         labels = {str(i): f"lang_{i}" for i in range(10)}
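
The parity test above pins the classifier's operation order to SpeechBrain's: the activation comes before each BatchNorm, and the output is log(softmax + 1e-10). A rough standalone sketch of that order, with hypothetical layer sizes (the real classifier and its norm/DNN.block_0/out attributes live in ecapa_tdnn.py):

import mlx.core as mx
import mlx.nn as nn

embedding_dim, hidden_dim, num_classes = 256, 512, 10  # hypothetical sizes

norm_in = nn.BatchNorm(embedding_dim)
block_linear = nn.Linear(embedding_dim, hidden_dim)
block_norm = nn.BatchNorm(hidden_dim)
out_proj = nn.Linear(hidden_dim, num_classes)
for layer in (norm_in, block_norm):
    layer.eval()  # use running statistics, as the test does via model.eval()

x = mx.random.normal((1, 1, embedding_dim))  # [batch, 1, embedding_dim]
h = mx.squeeze(x, axis=1)
h = nn.leaky_relu(h, negative_slope=0.01)    # activation precedes the input norm
h = norm_in(h)
h = block_norm(nn.leaky_relu(block_linear(h), negative_slope=0.01))
logits = out_proj(h)
log_probs = mx.log(mx.softmax(logits, axis=-1) + 1e-10)
mx.eval(log_probs)
print(log_probs.shape)  # (1, 10)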
