Skip to content

Commit dcfdae4

Browse files
authored
Feat/language mapping (#18)
* feat(nllb): normalize language codes — when mapping, convert to lowercase before checking the language map
* refactor(translate): add `cache` query param — controls whether to save predictions to cache or return cached results; defaults to true unless the request specifies ?cache=false
* refactor(detect): add cache-skipping param
1 parent 308509b commit dcfdae4

File tree

5 files changed

+249
-34
lines changed

5 files changed

+249
-34
lines changed

babeltron/app/models/translation/nllb.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,16 @@ def _convert_lang_code(self, lang_code: str) -> str:
239239
"en": "eng_Latn",
240240
"fr": "fra_Latn",
241241
"es": "spa_Latn",
242+
"es-419": "spa_Latn",
242243
"de": "deu_Latn",
243244
"zh": "zho_Hans",
245+
"zh-cn": "zho_Hans",
246+
"zh-tw": "zho_Hant",
244247
"ar": "ara_Arab",
245248
"ru": "rus_Cyrl",
246249
"pt": "por_Latn",
250+
"pt-br": "por_Latn",
251+
"pt-pt": "por_Latn",
247252
"it": "ita_Latn",
248253
"ja": "jpn_Jpan",
249254
"ko": "kor_Hang",
@@ -311,6 +316,7 @@ def _convert_lang_code(self, lang_code: str) -> str:
311316
return lang_code
312317

313318
# If we have a mapping, use it
319+
lang_code = lang_code.lower()
314320
if lang_code in iso_to_nllb:
315321
return iso_to_nllb[lang_code]
316322

babeltron/app/routers/detect.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@ class DetectionRequest(BaseModel):
2424
description="The text to detect source language",
2525
example="Hello, how are you?",
2626
)
27+
cache: bool = Field(
28+
True,
29+
description="Whether to use and store results in cache. Set to false to bypass cache.",
30+
example=True,
31+
)
2732

2833
class Config:
29-
json_schema_extra = {
30-
"example": {
31-
"text": "Hello, how are you?",
32-
}
33-
}
34+
json_schema_extra = {"example": {"text": "Hello, how are you?", "cache": True}}
3435

3536

3637
class DetectionResponse(BaseModel):
@@ -52,6 +53,8 @@ class DetectionResponse(BaseModel):
5253
highly accurate even for short text snippets.
5354
5455
Provide the text to detect source language.
56+
57+
Set cache=false to bypass the cache service and always perform a fresh detection.
5558
""",
5659
response_description="The detected language",
5760
status_code=status.HTTP_200_OK,
@@ -60,20 +63,23 @@ class DetectionResponse(BaseModel):
6063
async def detect(request: DetectionRequest):
6164
current_span = trace.get_current_span()
6265
current_span.set_attribute("text_length", len(request.text))
66+
current_span.set_attribute("cache_enabled", request.cache)
6367

64-
# Check cache for existing detection result
65-
cached_result = cache_service.get_detection(request.text)
66-
if cached_result:
67-
logging.info("Cache hit for language detection")
68-
current_span.set_attribute("cache_hit", True)
68+
# Check cache for existing detection result only if caching is enabled
69+
cached_result = None
70+
if request.cache:
71+
cached_result = cache_service.get_detection(request.text)
72+
if cached_result:
73+
logging.info("Cache hit for language detection")
74+
current_span.set_attribute("cache_hit", True)
6975

70-
# Add the cached flag to the response
71-
cached_result["cached"] = True
72-
return cached_result
76+
cached_result["cached"] = True
77+
return cached_result
78+
79+
current_span.set_attribute("cache_hit", False)
7380

7481
# Use the pre-loaded model based on model_type
7582
model = detection_model
76-
current_span.set_attribute("cache_hit", False)
7783

7884
# Check if model is None
7985
if model is None:
@@ -116,8 +122,8 @@ async def detect(request: DetectionRequest):
116122
"cached": False,
117123
}
118124

119-
# Cache the result
120-
cache_service.save_detection(request.text, response)
125+
if request.cache:
126+
cache_service.save_detection(request.text, response)
121127

122128
return response
123129

babeltron/app/routers/translate.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,19 @@ class TranslationRequest(BaseModel):
3636
tgt_lang: str = Field(
3737
..., description="Target language code (ISO 639-1)", example="es"
3838
)
39+
cache: bool = Field(
40+
True,
41+
description="Whether to use and store results in cache. Set to false to bypass cache.",
42+
example=True,
43+
)
3944

4045
class Config:
4146
json_schema_extra = {
4247
"example": {
4348
"text": "Hello, how are you?",
4449
"src_lang": "en",
4550
"tgt_lang": "es",
51+
"cache": True,
4652
}
4753
}
4854

@@ -75,6 +81,8 @@ class TranslationResponse(BaseModel):
7581
For automatic source language detection, set src_lang to "auto" or leave it empty.
7682
7783
The model used for translation is determined by the BABELTRON_MODEL_TYPE environment variable.
84+
85+
Set cache=false to bypass the cache service and always perform a fresh translation.
7886
""",
7987
response_description="The translated text in the target language",
8088
status_code=status.HTTP_200_OK,
@@ -84,6 +92,7 @@ async def translate(request: TranslationRequest):
8492
current_span = trace.get_current_span()
8593
current_span.set_attribute("text_length", len(request.text))
8694
current_span.set_attribute("tgt_lang", request.tgt_lang)
95+
current_span.set_attribute("cache_enabled", request.cache)
8796

8897
current_span.set_attribute("model_type", translation_model.model_type)
8998

@@ -145,21 +154,20 @@ async def translate(request: TranslationRequest):
145154
current_span.set_attribute("src_lang", src_lang)
146155
current_span.set_attribute("auto_detection", False)
147156

148-
# Check cache for existing translation
149-
cached_result = cache_service.get_translation(
150-
request.text, src_lang, request.tgt_lang
151-
)
152-
if cached_result:
153-
logging.info(f"Cache hit for translation: {src_lang} -> {request.tgt_lang}")
154-
current_span.set_attribute("cache_hit", True)
157+
# Check cache for existing translation only if caching is enabled
158+
cached_result = None
159+
if request.cache:
160+
cached_result = cache_service.get_translation(
161+
request.text, src_lang, request.tgt_lang
162+
)
163+
if cached_result:
164+
logging.info(f"Cache hit for translation: {src_lang} -> {request.tgt_lang}")
165+
current_span.set_attribute("cache_hit", True)
155166

156-
# Add the cached flag to the response
157-
cached_result["cached"] = True
158-
return cached_result
167+
# Add the cached flag to the response
168+
cached_result["cached"] = True
169+
return cached_result
159170

160-
logging.info(
161-
f"Translating text from {src_lang} to {request.tgt_lang} using {translation_model.model_type} model"
162-
)
163171
current_span.set_attribute("cache_hit", False)
164172

165173
if not translation_model.is_loaded:
@@ -199,10 +207,11 @@ async def translate(request: TranslationRequest):
199207
"cached": False,
200208
}
201209

202-
# Cache the result
203-
cache_service.save_translation(
204-
request.text, src_lang, request.tgt_lang, response
205-
)
210+
# Cache the result only if caching is enabled
211+
if request.cache:
212+
cache_service.save_translation(
213+
request.text, src_lang, request.tgt_lang, response
214+
)
206215

207216
return response
208217

tests/unit/app/routers/test_detect.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from unittest.mock import patch
2+
from unittest.mock import patch, MagicMock
33
from fastapi import status
44
from fastapi.testclient import TestClient
55

@@ -83,3 +83,96 @@ def test_detect_invalid_request(client):
8383
json={}, # Missing required field 'text'
8484
)
8585
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
86+
87+
88+
@patch("babeltron.app.models.detection.factory.get_detection_model")
89+
@patch("babeltron.app.routers.detect.detection_model", new_callable=MagicMock)
90+
def test_detect_with_cache_disabled(mock_detection_model, mock_get_model, client):
91+
# Create a mock model
92+
mock_model = MagicMock()
93+
mock_model.is_loaded = True
94+
mock_model.architecture = "lingua"
95+
mock_model.detect.return_value = ("fr", 0.95)
96+
97+
# Make the factory return our mock model
98+
mock_get_model.return_value = mock_model
99+
100+
# Configure the detection_model mock
101+
mock_detection_model.is_loaded = True
102+
mock_detection_model.architecture = "lingua"
103+
mock_detection_model.detect.return_value = ("fr", 0.95)
104+
105+
# Test data with cache disabled
106+
test_data = {
107+
"text": "Bonjour, comment ça va?",
108+
"cache": False
109+
}
110+
111+
# Mock the cache service to return a cached result
112+
with patch("babeltron.app.routers.detect.cache_service") as mock_cache:
113+
mock_cache.get_detection.return_value = {
114+
"language": "en",
115+
"confidence": 0.98,
116+
"cached": True
117+
}
118+
119+
response = client.post("/api/v1/detect", json=test_data)
120+
assert response.status_code == status.HTTP_200_OK
121+
data = response.json()
122+
assert data["language"] == "fr" # Should use fresh detection, not cached
123+
assert data["confidence"] == 0.95
124+
assert data["cached"] is False
125+
126+
# Verify the model was called correctly
127+
mock_detection_model.detect.assert_called_once()
128+
args, kwargs = mock_detection_model.detect.call_args
129+
assert args[0] == "Bonjour, comment ça va?"
130+
131+
# Verify cache was not used
132+
mock_cache.get_detection.assert_not_called()
133+
mock_cache.save_detection.assert_not_called()
134+
135+
136+
@patch("babeltron.app.models.detection.factory.get_detection_model")
137+
@patch("babeltron.app.routers.detect.detection_model", new_callable=MagicMock)
138+
def test_detect_with_cache_enabled(mock_detection_model, mock_get_model, client):
139+
# Create a mock model
140+
mock_model = MagicMock()
141+
mock_model.is_loaded = True
142+
mock_model.architecture = "lingua"
143+
mock_model.detect.return_value = ("fr", 0.95)
144+
145+
# Make the factory return our mock model
146+
mock_get_model.return_value = mock_model
147+
148+
# Configure the detection_model mock
149+
mock_detection_model.is_loaded = True
150+
mock_detection_model.architecture = "lingua"
151+
mock_detection_model.detect.return_value = ("fr", 0.95)
152+
153+
# Test data with cache enabled (default)
154+
test_data = {
155+
"text": "Bonjour, comment ça va?"
156+
}
157+
158+
# Mock the cache service to return a cached result
159+
with patch("babeltron.app.routers.detect.cache_service") as mock_cache:
160+
mock_cache.get_detection.return_value = {
161+
"language": "en",
162+
"confidence": 0.98,
163+
"cached": True
164+
}
165+
166+
response = client.post("/api/v1/detect", json=test_data)
167+
assert response.status_code == status.HTTP_200_OK
168+
data = response.json()
169+
assert data["language"] == "en" # Should use cached result
170+
assert data["confidence"] == 0.98
171+
assert data["cached"] is True
172+
173+
# Verify the model was not called (using cached result)
174+
mock_detection_model.detect.assert_not_called()
175+
176+
# Verify cache was used
177+
mock_cache.get_detection.assert_called_once()
178+
mock_cache.save_detection.assert_not_called() # Should not save since we used cached result

tests/unit/app/routers/test_translate.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,104 @@ def test_translate_detection_error(mock_m2m_model, client):
454454
assert "detail" in data
455455
assert "detection error" in data["detail"].lower()
456456
assert "test error" in data["detail"].lower()
457+
458+
459+
@patch("babeltron.app.models.translation.factory.get_translation_model")
460+
@patch("babeltron.app.routers.translate.translation_model", new_callable=MagicMock)
461+
def test_translate_with_cache_disabled(mock_translation_model, mock_get_model, mock_m2m_model, client):
462+
# Set up both mocks to return our mock model
463+
mock_get_model.return_value = mock_m2m_model
464+
465+
# Configure the translation_model mock
466+
for attr_name in ["is_loaded", "architecture", "translate"]:
467+
setattr(mock_translation_model, attr_name, getattr(mock_m2m_model, attr_name))
468+
469+
# Set the model_type attribute on the mock
470+
mock_translation_model.model_type = "m2m100"
471+
mock_m2m_model.model_type = "m2m100"
472+
473+
# Test data with cache disabled
474+
test_data = {
475+
"text": "Hello world",
476+
"src_lang": "en",
477+
"tgt_lang": "fr",
478+
"cache": False
479+
}
480+
481+
# Mock the cache service to return a cached result
482+
with patch("babeltron.app.routers.translate.cache_service") as mock_cache:
483+
mock_cache.get_translation.return_value = {
484+
"translation": "Cached translation",
485+
"model_type": "m2m100",
486+
"architecture": "cpu_compiled",
487+
"cached": True
488+
}
489+
490+
response = client.post("/api/v1/translate", json=test_data)
491+
assert response.status_code == status.HTTP_200_OK
492+
data = response.json()
493+
assert data["translation"] == "Bonjour le monde" # Should use fresh translation, not cached
494+
assert data["model_type"] == "m2m100"
495+
assert data["architecture"] == "cpu_compiled"
496+
assert data["detected_lang"] is None
497+
assert data["detection_confidence"] is None
498+
assert data["cached"] is False
499+
500+
# Verify the model was called correctly
501+
mock_m2m_model.translate.assert_called_once()
502+
args, kwargs = mock_m2m_model.translate.call_args
503+
assert args[0] == "Hello world"
504+
assert args[1] == "en"
505+
assert args[2] == "fr"
506+
507+
# Verify cache was not used
508+
mock_cache.get_translation.assert_not_called()
509+
mock_cache.save_translation.assert_not_called()
510+
511+
512+
@patch("babeltron.app.models.translation.factory.get_translation_model")
513+
@patch("babeltron.app.routers.translate.translation_model", new_callable=MagicMock)
514+
def test_translate_with_cache_enabled(mock_translation_model, mock_get_model, mock_m2m_model, client):
515+
# Set up both mocks to return our mock model
516+
mock_get_model.return_value = mock_m2m_model
517+
518+
# Configure the translation_model mock
519+
for attr_name in ["is_loaded", "architecture", "translate"]:
520+
setattr(mock_translation_model, attr_name, getattr(mock_m2m_model, attr_name))
521+
522+
# Set the model_type attribute on the mock
523+
mock_translation_model.model_type = "m2m100"
524+
mock_m2m_model.model_type = "m2m100"
525+
526+
# Test data with cache enabled (default)
527+
test_data = {
528+
"text": "Hello world",
529+
"src_lang": "en",
530+
"tgt_lang": "fr"
531+
}
532+
533+
# Mock the cache service to return a cached result
534+
with patch("babeltron.app.routers.translate.cache_service") as mock_cache:
535+
mock_cache.get_translation.return_value = {
536+
"translation": "Cached translation",
537+
"model_type": "m2m100",
538+
"architecture": "cpu_compiled",
539+
"cached": True
540+
}
541+
542+
response = client.post("/api/v1/translate", json=test_data)
543+
assert response.status_code == status.HTTP_200_OK
544+
data = response.json()
545+
assert data["translation"] == "Cached translation" # Should use cached translation
546+
assert data["model_type"] == "m2m100"
547+
assert data["architecture"] == "cpu_compiled"
548+
assert data["detected_lang"] is None
549+
assert data["detection_confidence"] is None
550+
assert data["cached"] is True
551+
552+
# Verify the model was not called (using cached result)
553+
mock_m2m_model.translate.assert_not_called()
554+
555+
# Verify cache was used
556+
mock_cache.get_translation.assert_called_once()
557+
mock_cache.save_translation.assert_not_called() # Should not save since we used cached result

0 commit comments

Comments (0)