Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions babeltron/app/models/translation/nllb.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,16 @@ def _convert_lang_code(self, lang_code: str) -> str:
"en": "eng_Latn",
"fr": "fra_Latn",
"es": "spa_Latn",
"es-419": "spa_Latn",
"de": "deu_Latn",
"zh": "zho_Hans",
"zh-cn": "zho_Hans",
"zh-tw": "zho_Hant",
"ar": "ara_Arab",
"ru": "rus_Cyrl",
"pt": "por_Latn",
"pt-br": "por_Latn",
"pt-pt": "por_Latn",
"it": "ita_Latn",
"ja": "jpn_Jpan",
"ko": "kor_Hang",
Expand Down Expand Up @@ -311,6 +316,7 @@ def _convert_lang_code(self, lang_code: str) -> str:
return lang_code

# If we have a mapping, use it
lang_code = lang_code.lower()
if lang_code in iso_to_nllb:
return iso_to_nllb[lang_code]

Expand Down
38 changes: 22 additions & 16 deletions babeltron/app/routers/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ class DetectionRequest(BaseModel):
description="The text to detect source language",
example="Hello, how are you?",
)
cache: bool = Field(
True,
description="Whether to use and store results in cache. Set to false to bypass cache.",
example=True,
)

class Config:
json_schema_extra = {
"example": {
"text": "Hello, how are you?",
}
}
json_schema_extra = {"example": {"text": "Hello, how are you?", "cache": True}}


class DetectionResponse(BaseModel):
Expand All @@ -52,6 +53,8 @@ class DetectionResponse(BaseModel):
highly accurate even for short text snippets.

Provide the text to detect source language.

Set cache=false to bypass the cache service and always perform a fresh detection.
""",
response_description="The detected language",
status_code=status.HTTP_200_OK,
Expand All @@ -60,20 +63,23 @@ class DetectionResponse(BaseModel):
async def detect(request: DetectionRequest):
current_span = trace.get_current_span()
current_span.set_attribute("text_length", len(request.text))
current_span.set_attribute("cache_enabled", request.cache)

# Check cache for existing detection result
cached_result = cache_service.get_detection(request.text)
if cached_result:
logging.info("Cache hit for language detection")
current_span.set_attribute("cache_hit", True)
# Check cache for existing detection result only if caching is enabled
cached_result = None
if request.cache:
cached_result = cache_service.get_detection(request.text)
if cached_result:
logging.info("Cache hit for language detection")
current_span.set_attribute("cache_hit", True)

# Add the cached flag to the response
cached_result["cached"] = True
return cached_result
cached_result["cached"] = True
return cached_result

current_span.set_attribute("cache_hit", False)

# Use the pre-loaded model based on model_type
model = detection_model
current_span.set_attribute("cache_hit", False)

# Check if model is None
if model is None:
Expand Down Expand Up @@ -116,8 +122,8 @@ async def detect(request: DetectionRequest):
"cached": False,
}

# Cache the result
cache_service.save_detection(request.text, response)
if request.cache:
cache_service.save_detection(request.text, response)

return response

Expand Down
43 changes: 26 additions & 17 deletions babeltron/app/routers/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ class TranslationRequest(BaseModel):
tgt_lang: str = Field(
..., description="Target language code (ISO 639-1)", example="es"
)
cache: bool = Field(
True,
description="Whether to use and store results in cache. Set to false to bypass cache.",
example=True,
)

class Config:
json_schema_extra = {
"example": {
"text": "Hello, how are you?",
"src_lang": "en",
"tgt_lang": "es",
"cache": True,
}
}

Expand Down Expand Up @@ -75,6 +81,8 @@ class TranslationResponse(BaseModel):
For automatic source language detection, set src_lang to "auto" or leave it empty.

The model used for translation is determined by the BABELTRON_MODEL_TYPE environment variable.

Set cache=false to bypass the cache service and always perform a fresh translation.
""",
response_description="The translated text in the target language",
status_code=status.HTTP_200_OK,
Expand All @@ -84,6 +92,7 @@ async def translate(request: TranslationRequest):
current_span = trace.get_current_span()
current_span.set_attribute("text_length", len(request.text))
current_span.set_attribute("tgt_lang", request.tgt_lang)
current_span.set_attribute("cache_enabled", request.cache)

current_span.set_attribute("model_type", translation_model.model_type)

Expand Down Expand Up @@ -145,21 +154,20 @@ async def translate(request: TranslationRequest):
current_span.set_attribute("src_lang", src_lang)
current_span.set_attribute("auto_detection", False)

# Check cache for existing translation
cached_result = cache_service.get_translation(
request.text, src_lang, request.tgt_lang
)
if cached_result:
logging.info(f"Cache hit for translation: {src_lang} -> {request.tgt_lang}")
current_span.set_attribute("cache_hit", True)
# Check cache for existing translation only if caching is enabled
cached_result = None
if request.cache:
cached_result = cache_service.get_translation(
request.text, src_lang, request.tgt_lang
)
if cached_result:
logging.info(f"Cache hit for translation: {src_lang} -> {request.tgt_lang}")
current_span.set_attribute("cache_hit", True)

# Add the cached flag to the response
cached_result["cached"] = True
return cached_result
# Add the cached flag to the response
cached_result["cached"] = True
return cached_result

logging.info(
f"Translating text from {src_lang} to {request.tgt_lang} using {translation_model.model_type} model"
)
current_span.set_attribute("cache_hit", False)

if not translation_model.is_loaded:
Expand Down Expand Up @@ -199,10 +207,11 @@ async def translate(request: TranslationRequest):
"cached": False,
}

# Cache the result
cache_service.save_translation(
request.text, src_lang, request.tgt_lang, response
)
# Cache the result only if caching is enabled
if request.cache:
cache_service.save_translation(
request.text, src_lang, request.tgt_lang, response
)

return response

Expand Down
95 changes: 94 additions & 1 deletion tests/unit/app/routers/test_detect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from unittest.mock import patch
from unittest.mock import patch, MagicMock
from fastapi import status
from fastapi.testclient import TestClient

Expand Down Expand Up @@ -83,3 +83,96 @@ def test_detect_invalid_request(client):
json={}, # Missing required field 'text'
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@patch("babeltron.app.models.detection.factory.get_detection_model")
@patch("babeltron.app.routers.detect.detection_model", new_callable=MagicMock)
def test_detect_with_cache_disabled(mock_detection_model, mock_get_model, client):
    """Detection with ``cache`` set to false must not touch the cache service.

    The cache mock is primed with an English entry while both model mocks
    answer French, so the response reveals which source produced the result.
    """
    fresh_detection = ("fr", 0.95)

    # Model returned by the patched factory.
    factory_model = MagicMock(is_loaded=True, architecture="lingua")
    factory_model.detect.return_value = fresh_detection
    mock_get_model.return_value = factory_model

    # Model object patched directly into the detect router module.
    mock_detection_model.configure_mock(is_loaded=True, architecture="lingua")
    mock_detection_model.detect.return_value = fresh_detection

    # Request body that explicitly opts out of caching.
    payload = {"text": "Bonjour, comment ça va?", "cache": False}
    primed_entry = {"language": "en", "confidence": 0.98, "cached": True}

    with patch("babeltron.app.routers.detect.cache_service") as mock_cache:
        # A cache entry exists, but it is expected to be ignored.
        mock_cache.get_detection.return_value = primed_entry

        response = client.post("/api/v1/detect", json=payload)
        assert response.status_code == status.HTTP_200_OK

        body = response.json()
        assert body["language"] == "fr"  # fresh detection, not the cached "en"
        assert body["confidence"] == 0.95
        assert body["cached"] is False

        # The model ran exactly once, on the submitted text.
        mock_detection_model.detect.assert_called_once()
        positional = mock_detection_model.detect.call_args[0]
        assert positional[0] == "Bonjour, comment ça va?"

        # The cache was neither read nor written.
        mock_cache.get_detection.assert_not_called()
        mock_cache.save_detection.assert_not_called()


@patch("babeltron.app.models.detection.factory.get_detection_model")
@patch("babeltron.app.routers.detect.detection_model", new_callable=MagicMock)
def test_detect_with_cache_enabled(mock_detection_model, mock_get_model, client):
    """Detection without a ``cache`` field must serve the cached entry.

    The cache mock holds an English entry while both model mocks would
    answer French; the response is expected to carry the cached values and
    the model must stay untouched.
    """
    fresh_detection = ("fr", 0.95)

    # Model returned by the patched factory.
    factory_model = MagicMock(is_loaded=True, architecture="lingua")
    factory_model.detect.return_value = fresh_detection
    mock_get_model.return_value = factory_model

    # Model object patched directly into the detect router module.
    mock_detection_model.configure_mock(is_loaded=True, architecture="lingua")
    mock_detection_model.detect.return_value = fresh_detection

    # No "cache" key: the request relies on the field's default.
    payload = {"text": "Bonjour, comment ça va?"}
    primed_entry = {"language": "en", "confidence": 0.98, "cached": True}

    with patch("babeltron.app.routers.detect.cache_service") as mock_cache:
        mock_cache.get_detection.return_value = primed_entry

        response = client.post("/api/v1/detect", json=payload)
        assert response.status_code == status.HTTP_200_OK

        body = response.json()
        assert body["language"] == "en"  # value taken from the cache entry
        assert body["confidence"] == 0.98
        assert body["cached"] is True

        # A cache hit means the model is never consulted.
        mock_detection_model.detect.assert_not_called()

        # One cache read, and no write because nothing new was computed.
        mock_cache.get_detection.assert_called_once()
        mock_cache.save_detection.assert_not_called()
101 changes: 101 additions & 0 deletions tests/unit/app/routers/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,104 @@ def test_translate_detection_error(mock_m2m_model, client):
assert "detail" in data
assert "detection error" in data["detail"].lower()
assert "test error" in data["detail"].lower()


@patch("babeltron.app.models.translation.factory.get_translation_model")
@patch("babeltron.app.routers.translate.translation_model", new_callable=MagicMock)
def test_translate_with_cache_disabled(mock_translation_model, mock_get_model, mock_m2m_model, client):
    """Translation with ``cache`` set to false must not touch the cache service.

    The cache mock is primed with "Cached translation"; the response is
    expected to carry the model's output instead (presumably supplied by the
    ``mock_m2m_model`` fixture, which is defined outside this test).
    """
    # The factory hands out the fixture model.
    mock_get_model.return_value = mock_m2m_model

    # Mirror the fixture's behaviour onto the router-level model mock.
    mock_translation_model.is_loaded = mock_m2m_model.is_loaded
    mock_translation_model.architecture = mock_m2m_model.architecture
    mock_translation_model.translate = mock_m2m_model.translate

    # Both mocks report the same model type.
    mock_translation_model.model_type = "m2m100"
    mock_m2m_model.model_type = "m2m100"

    # Request body that explicitly opts out of caching.
    payload = {
        "text": "Hello world",
        "src_lang": "en",
        "tgt_lang": "fr",
        "cache": False,
    }
    primed_entry = {
        "translation": "Cached translation",
        "model_type": "m2m100",
        "architecture": "cpu_compiled",
        "cached": True,
    }

    with patch("babeltron.app.routers.translate.cache_service") as mock_cache:
        # A cache entry exists, but it is expected to be ignored.
        mock_cache.get_translation.return_value = primed_entry

        response = client.post("/api/v1/translate", json=payload)
        assert response.status_code == status.HTTP_200_OK

        body = response.json()
        assert body["translation"] == "Bonjour le monde"  # fresh, not cached
        assert body["model_type"] == "m2m100"
        assert body["architecture"] == "cpu_compiled"
        assert body["detected_lang"] is None
        assert body["detection_confidence"] is None
        assert body["cached"] is False

        # The model ran exactly once with (text, source, target).
        mock_m2m_model.translate.assert_called_once()
        positional = mock_m2m_model.translate.call_args[0]
        assert positional[0] == "Hello world"
        assert positional[1] == "en"
        assert positional[2] == "fr"

        # The cache was neither read nor written.
        mock_cache.get_translation.assert_not_called()
        mock_cache.save_translation.assert_not_called()


@patch("babeltron.app.models.translation.factory.get_translation_model")
@patch("babeltron.app.routers.translate.translation_model", new_callable=MagicMock)
def test_translate_with_cache_enabled(mock_translation_model, mock_get_model, mock_m2m_model, client):
    """Translation without a ``cache`` field must serve the cached entry.

    The cache mock holds "Cached translation"; the response is expected to
    return it and the translation model must stay untouched.
    """
    # The factory hands out the fixture model.
    mock_get_model.return_value = mock_m2m_model

    # Mirror the fixture's behaviour onto the router-level model mock.
    mock_translation_model.is_loaded = mock_m2m_model.is_loaded
    mock_translation_model.architecture = mock_m2m_model.architecture
    mock_translation_model.translate = mock_m2m_model.translate

    # Both mocks report the same model type.
    mock_translation_model.model_type = "m2m100"
    mock_m2m_model.model_type = "m2m100"

    # No "cache" key: the request relies on the field's default.
    payload = {
        "text": "Hello world",
        "src_lang": "en",
        "tgt_lang": "fr",
    }
    primed_entry = {
        "translation": "Cached translation",
        "model_type": "m2m100",
        "architecture": "cpu_compiled",
        "cached": True,
    }

    with patch("babeltron.app.routers.translate.cache_service") as mock_cache:
        mock_cache.get_translation.return_value = primed_entry

        response = client.post("/api/v1/translate", json=payload)
        assert response.status_code == status.HTTP_200_OK

        body = response.json()
        assert body["translation"] == "Cached translation"  # served from cache
        assert body["model_type"] == "m2m100"
        assert body["architecture"] == "cpu_compiled"
        assert body["detected_lang"] is None
        assert body["detection_confidence"] is None
        assert body["cached"] is True

        # A cache hit means the model is never consulted.
        mock_m2m_model.translate.assert_not_called()

        # One cache read, and no write because nothing new was computed.
        mock_cache.get_translation.assert_called_once()
        mock_cache.save_translation.assert_not_called()