Skip to content

Commit bca34e5

Browse files
committed
test: fix NLLB model fixtures
1 parent 831a0b6 commit bca34e5

File tree

1 file changed

+67
-28
lines changed

1 file changed

+67
-28
lines changed

tests/unit/app/routers/test_translate.py

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
test_app = FastAPI()
1111
test_app.include_router(translate_router.router, prefix="/api/v1")
1212

13+
# Common patches for model loading and initialization that should be applied across all tests
14+
model_patches = [
15+
patch("transformers.AutoModelForSeq2SeqLM.from_pretrained"),
16+
patch("transformers.AutoTokenizer.from_pretrained"),
17+
patch("babeltron.app.models.nllb.get_model_path", return_value="/mocked/path"),
18+
patch("babeltron.app.models.m2m100.get_model_path", return_value="/mocked/path"),
19+
]
20+
1321

1422
@pytest.fixture
1523
def client():
@@ -18,22 +26,57 @@ def client():
1826
yield client
1927

2028

29+
@pytest.fixture(autouse=True)
30+
def patch_model_loading():
31+
"""This fixture patches model loading across all tests"""
32+
# Start all the patches
33+
started_patches = []
34+
for p in model_patches:
35+
started_patches.append(p.start())
36+
37+
yield
38+
39+
# Stop all patches after the test is done
40+
for p in model_patches:
41+
p.stop()
42+
43+
44+
@pytest.fixture
45+
def mock_nllb_model():
46+
"""Create a mock NLLB model"""
47+
mock = MagicMock()
48+
mock.is_loaded = True
49+
mock.architecture = "mps_fp16"
50+
mock._model_path = "/mocked/path"
51+
mock.translate.return_value = "Hola, ¿cómo está?"
52+
53+
with patch("babeltron.app.models.nllb.NLLBTranslationModel.__new__", return_value=mock):
54+
yield mock
55+
56+
57+
@pytest.fixture
58+
def mock_m2m_model():
59+
"""Create a mock M2M100 model"""
60+
mock = MagicMock()
61+
mock.is_loaded = True
62+
mock.architecture = "cpu_compiled"
63+
mock._model_path = "/mocked/path"
64+
mock.translate.return_value = "Bonjour le monde"
65+
66+
with patch("babeltron.app.models.m2m100.M2M100TranslationModel.__new__", return_value=mock):
67+
yield mock
68+
69+
2170
# Patch both the factory function and the global translation_model
2271
@patch("babeltron.app.models.factory.get_translation_model")
2372
@patch("babeltron.app.routers.translate.translation_model", new_callable=MagicMock)
24-
def test_translate_success(mock_translation_model, mock_get_model, client):
25-
# Create a mock model
26-
mock_model = MagicMock()
27-
mock_model.is_loaded = True
28-
mock_model.architecture = "cpu_compiled"
29-
mock_model.translate.return_value = "Bonjour le monde"
30-
73+
def test_translate_success(mock_translation_model, mock_get_model, mock_m2m_model, client):
3174
# Set up both mocks to return our mock model
32-
mock_get_model.return_value = mock_model
75+
mock_get_model.return_value = mock_m2m_model
3376

3477
# Configure the translation_model mock
3578
for attr_name in ["is_loaded", "architecture", "translate"]:
36-
setattr(mock_translation_model, attr_name, getattr(mock_model, attr_name))
79+
setattr(mock_translation_model, attr_name, getattr(mock_m2m_model, attr_name))
3780

3881
# Test data
3982
test_data = {
@@ -50,20 +93,14 @@ def test_translate_success(mock_translation_model, mock_get_model, client):
5093
assert data["architecture"] == "cpu_compiled"
5194

5295
# Verify the model was called correctly
53-
mock_model.translate.assert_called_once()
54-
args, kwargs = mock_model.translate.call_args
96+
mock_m2m_model.translate.assert_called_once()
97+
args, kwargs = mock_m2m_model.translate.call_args
5598
assert args[0] == "Hello world"
5699
assert args[1] == "en"
57100
assert args[2] == "fr"
58101

59102

60-
def test_translate_with_model_type(client):
61-
# Create a mock model for NLLB with translate method
62-
nllb_mock = MagicMock()
63-
nllb_mock.is_loaded = True
64-
nllb_mock.architecture = "mps_fp16"
65-
nllb_mock.translate.return_value = "Hola, ¿cómo está?"
66-
103+
def test_translate_with_model_type(mock_nllb_model, client):
67104
# Create a mock for the default model
68105
default_mock = MagicMock()
69106
default_mock.is_loaded = True
@@ -72,8 +109,8 @@ def test_translate_with_model_type(client):
72109
# Create a mock tracer
73110
mock_tracer = MagicMock()
74111

75-
# Patch the factory to return our NLLB mock when requested
76-
with patch("babeltron.app.models.factory.get_translation_model", return_value=nllb_mock), \
112+
# Patch the factory to return our NLLB mock
113+
with patch("babeltron.app.models.factory.get_translation_model", return_value=mock_nllb_model), \
77114
patch("babeltron.app.routers.translate.translation_model", default_mock), \
78115
patch("opentelemetry.trace.get_tracer", return_value=mock_tracer):
79116

@@ -103,6 +140,7 @@ def test_translate_model_not_loaded(mock_translation_model, mock_get_model, clie
103140
# Create a mock model that's not loaded
104141
mock_model = MagicMock()
105142
mock_model.is_loaded = False
143+
mock_model._model_path = "/mocked/path"
106144

107145
# Make the factory return our mock model
108146
mock_get_model.return_value = mock_model
@@ -130,6 +168,7 @@ def test_translate_model_error(mock_translation_model, mock_get_model, client):
130168
# Create a mock model that raises an error
131169
mock_model = MagicMock()
132170
mock_model.is_loaded = True
171+
mock_model._model_path = "/mocked/path"
133172

134173
# Make sure the exception is raised when translate is called
135174
mock_model.translate.side_effect = Exception("Test error")
@@ -164,6 +203,7 @@ def test_languages_endpoint(mock_translation_model, mock_get_model, client):
164203
# Create a mock model
165204
mock_model = MagicMock()
166205
mock_model.is_loaded = True
206+
mock_model._model_path = "/mocked/path"
167207
mock_model.get_languages.return_value = ["en", "fr", "es", "de"]
168208

169209
# Make the factory return our mock model
@@ -185,17 +225,15 @@ def test_languages_endpoint(mock_translation_model, mock_get_model, client):
185225

186226
@patch("babeltron.app.models.factory.get_translation_model")
187227
@patch("babeltron.app.routers.translate.translation_model", new_callable=MagicMock)
188-
def test_languages_with_model_type(mock_translation_model, mock_get_model, client):
189-
# Create a mock model
190-
mock_model = MagicMock()
191-
mock_model.is_loaded = True
192-
mock_model.get_languages.return_value = ["en", "fr", "es", "de"]
228+
def test_languages_with_model_type(mock_translation_model, mock_get_model, mock_nllb_model, client):
229+
# Configure translation_model mock
230+
mock_translation_model.is_loaded = True
193231

194232
# Make the factory return our mock model
195-
mock_get_model.return_value = mock_model
233+
mock_get_model.return_value = mock_nllb_model
196234

197-
# Configure the translation_model mock
198-
mock_translation_model.is_loaded = True
235+
# Set up mock languages
236+
mock_nllb_model.get_languages.return_value = ["en", "fr", "es", "de"]
199237

200238
response = client.get("/api/v1/languages?model_type=nllb")
201239
assert response.status_code == status.HTTP_200_OK
@@ -207,6 +245,7 @@ def test_languages_model_not_loaded(mock_translation_model, mock_get_model, clie
207245
# Create a mock model that's not loaded
208246
mock_model = MagicMock()
209247
mock_model.is_loaded = False
248+
mock_model._model_path = "/mocked/path"
210249

211250
# Make the factory return our mock model
212251
mock_get_model.return_value = mock_model

0 commit comments

Comments
 (0)