1010test_app = FastAPI ()
1111test_app .include_router (translate_router .router , prefix = "/api/v1" )
1212
13+ # Common patches for model loading and initialization that should be applied across all tests
14+ model_patches = [
15+ patch ("transformers.AutoModelForSeq2SeqLM.from_pretrained" ),
16+ patch ("transformers.AutoTokenizer.from_pretrained" ),
17+ patch ("babeltron.app.models.nllb.get_model_path" , return_value = "/mocked/path" ),
18+ patch ("babeltron.app.models.m2m100.get_model_path" , return_value = "/mocked/path" ),
19+ ]
20+
1321
1422@pytest .fixture
1523def client ():
@@ -18,22 +26,57 @@ def client():
1826 yield client
1927
2028
29+ @pytest .fixture (autouse = True )
30+ def patch_model_loading ():
31+ """This fixture patches model loading across all tests"""
32+ # Start all the patches
33+ started_patches = []
34+ for p in model_patches :
35+ started_patches .append (p .start ())
36+
37+ yield
38+
39+ # Stop all patches after the test is done
40+ for p in model_patches :
41+ p .stop ()
42+
43+
44+ @pytest .fixture
45+ def mock_nllb_model ():
46+ """Create a mock NLLB model"""
47+ mock = MagicMock ()
48+ mock .is_loaded = True
49+ mock .architecture = "mps_fp16"
50+ mock ._model_path = "/mocked/path"
51+ mock .translate .return_value = "Hola, ¿cómo está?"
52+
53+ with patch ("babeltron.app.models.nllb.NLLBTranslationModel.__new__" , return_value = mock ):
54+ yield mock
55+
56+
57+ @pytest .fixture
58+ def mock_m2m_model ():
59+ """Create a mock M2M100 model"""
60+ mock = MagicMock ()
61+ mock .is_loaded = True
62+ mock .architecture = "cpu_compiled"
63+ mock ._model_path = "/mocked/path"
64+ mock .translate .return_value = "Bonjour le monde"
65+
66+ with patch ("babeltron.app.models.m2m100.M2M100TranslationModel.__new__" , return_value = mock ):
67+ yield mock
68+
69+
2170# Patch both the factory function and the global translation_model
2271@patch ("babeltron.app.models.factory.get_translation_model" )
2372@patch ("babeltron.app.routers.translate.translation_model" , new_callable = MagicMock )
24- def test_translate_success (mock_translation_model , mock_get_model , client ):
25- # Create a mock model
26- mock_model = MagicMock ()
27- mock_model .is_loaded = True
28- mock_model .architecture = "cpu_compiled"
29- mock_model .translate .return_value = "Bonjour le monde"
30-
73+ def test_translate_success (mock_translation_model , mock_get_model , mock_m2m_model , client ):
3174 # Set up both mocks to return our mock model
32- mock_get_model .return_value = mock_model
75+ mock_get_model .return_value = mock_m2m_model
3376
3477 # Configure the translation_model mock
3578 for attr_name in ["is_loaded" , "architecture" , "translate" ]:
36- setattr (mock_translation_model , attr_name , getattr (mock_model , attr_name ))
79+ setattr (mock_translation_model , attr_name , getattr (mock_m2m_model , attr_name ))
3780
3881 # Test data
3982 test_data = {
@@ -50,20 +93,14 @@ def test_translate_success(mock_translation_model, mock_get_model, client):
5093 assert data ["architecture" ] == "cpu_compiled"
5194
5295 # Verify the model was called correctly
53- mock_model .translate .assert_called_once ()
54- args , kwargs = mock_model .translate .call_args
96+ mock_m2m_model .translate .assert_called_once ()
97+ args , kwargs = mock_m2m_model .translate .call_args
5598 assert args [0 ] == "Hello world"
5699 assert args [1 ] == "en"
57100 assert args [2 ] == "fr"
58101
59102
60- def test_translate_with_model_type (client ):
61- # Create a mock model for NLLB with translate method
62- nllb_mock = MagicMock ()
63- nllb_mock .is_loaded = True
64- nllb_mock .architecture = "mps_fp16"
65- nllb_mock .translate .return_value = "Hola, ¿cómo está?"
66-
103+ def test_translate_with_model_type (mock_nllb_model , client ):
67104 # Create a mock for the default model
68105 default_mock = MagicMock ()
69106 default_mock .is_loaded = True
@@ -72,8 +109,8 @@ def test_translate_with_model_type(client):
72109 # Create a mock tracer
73110 mock_tracer = MagicMock ()
74111
75- # Patch the factory to return our NLLB mock when requested
76- with patch ("babeltron.app.models.factory.get_translation_model" , return_value = nllb_mock ), \
112+ # Patch the factory to return our NLLB mock
113+ with patch ("babeltron.app.models.factory.get_translation_model" , return_value = mock_nllb_model ), \
77114 patch ("babeltron.app.routers.translate.translation_model" , default_mock ), \
78115 patch ("opentelemetry.trace.get_tracer" , return_value = mock_tracer ):
79116
@@ -103,6 +140,7 @@ def test_translate_model_not_loaded(mock_translation_model, mock_get_model, clie
103140 # Create a mock model that's not loaded
104141 mock_model = MagicMock ()
105142 mock_model .is_loaded = False
143+ mock_model ._model_path = "/mocked/path"
106144
107145 # Make the factory return our mock model
108146 mock_get_model .return_value = mock_model
@@ -130,6 +168,7 @@ def test_translate_model_error(mock_translation_model, mock_get_model, client):
130168 # Create a mock model that raises an error
131169 mock_model = MagicMock ()
132170 mock_model .is_loaded = True
171+ mock_model ._model_path = "/mocked/path"
133172
134173 # Make sure the exception is raised when translate is called
135174 mock_model .translate .side_effect = Exception ("Test error" )
@@ -164,6 +203,7 @@ def test_languages_endpoint(mock_translation_model, mock_get_model, client):
164203 # Create a mock model
165204 mock_model = MagicMock ()
166205 mock_model .is_loaded = True
206+ mock_model ._model_path = "/mocked/path"
167207 mock_model .get_languages .return_value = ["en" , "fr" , "es" , "de" ]
168208
169209 # Make the factory return our mock model
@@ -185,17 +225,15 @@ def test_languages_endpoint(mock_translation_model, mock_get_model, client):
185225
186226@patch ("babeltron.app.models.factory.get_translation_model" )
187227@patch ("babeltron.app.routers.translate.translation_model" , new_callable = MagicMock )
188- def test_languages_with_model_type (mock_translation_model , mock_get_model , client ):
189- # Create a mock model
190- mock_model = MagicMock ()
191- mock_model .is_loaded = True
192- mock_model .get_languages .return_value = ["en" , "fr" , "es" , "de" ]
228+ def test_languages_with_model_type (mock_translation_model , mock_get_model , mock_nllb_model , client ):
229+ # Configure translation_model mock
230+ mock_translation_model .is_loaded = True
193231
194232 # Make the factory return our mock model
195- mock_get_model .return_value = mock_model
233+ mock_get_model .return_value = mock_nllb_model
196234
197- # Configure the translation_model mock
198- mock_translation_model . is_loaded = True
235+ # Set up mock languages
236+ mock_nllb_model . get_languages . return_value = [ "en" , "fr" , "es" , "de" ]
199237
200238 response = client .get ("/api/v1/languages?model_type=nllb" )
201239 assert response .status_code == status .HTTP_200_OK
@@ -207,6 +245,7 @@ def test_languages_model_not_loaded(mock_translation_model, mock_get_model, clie
207245 # Create a mock model that's not loaded
208246 mock_model = MagicMock ()
209247 mock_model .is_loaded = False
248+ mock_model ._model_path = "/mocked/path"
210249
211250 # Make the factory return our mock model
212251 mock_get_model .return_value = mock_model
0 commit comments