Skip to content

Commit e0f6d00

Browse files
committed
feat(detect): add `cache` request parameter to allow skipping the detection cache
1 parent ceec70c commit e0f6d00

File tree

2 files changed

+116
-17
lines changed

2 files changed

+116
-17
lines changed

babeltron/app/routers/detect.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@ class DetectionRequest(BaseModel):
2424
description="The text to detect source language",
2525
example="Hello, how are you?",
2626
)
27+
cache: bool = Field(
28+
True,
29+
description="Whether to use and store results in cache. Set to false to bypass cache.",
30+
example=True,
31+
)
2732

2833
class Config:
29-
json_schema_extra = {
30-
"example": {
31-
"text": "Hello, how are you?",
32-
}
33-
}
34+
json_schema_extra = {"example": {"text": "Hello, how are you?", "cache": True}}
3435

3536

3637
class DetectionResponse(BaseModel):
@@ -52,6 +53,8 @@ class DetectionResponse(BaseModel):
5253
highly accurate even for short text snippets.
5354
5455
Provide the text to detect source language.
56+
57+
Set cache=false to bypass the cache service and always perform a fresh detection.
5558
""",
5659
response_description="The detected language",
5760
status_code=status.HTTP_200_OK,
@@ -60,20 +63,23 @@ class DetectionResponse(BaseModel):
6063
async def detect(request: DetectionRequest):
6164
current_span = trace.get_current_span()
6265
current_span.set_attribute("text_length", len(request.text))
66+
current_span.set_attribute("cache_enabled", request.cache)
6367

64-
# Check cache for existing detection result
65-
cached_result = cache_service.get_detection(request.text)
66-
if cached_result:
67-
logging.info("Cache hit for language detection")
68-
current_span.set_attribute("cache_hit", True)
68+
# Check cache for existing detection result only if caching is enabled
69+
cached_result = None
70+
if request.cache:
71+
cached_result = cache_service.get_detection(request.text)
72+
if cached_result:
73+
logging.info("Cache hit for language detection")
74+
current_span.set_attribute("cache_hit", True)
6975

70-
# Add the cached flag to the response
71-
cached_result["cached"] = True
72-
return cached_result
76+
cached_result["cached"] = True
77+
return cached_result
78+
79+
current_span.set_attribute("cache_hit", False)
7380

7481
# Use the pre-loaded model based on model_type
7582
model = detection_model
76-
current_span.set_attribute("cache_hit", False)
7783

7884
# Check if model is None
7985
if model is None:
@@ -116,8 +122,8 @@ async def detect(request: DetectionRequest):
116122
"cached": False,
117123
}
118124

119-
# Cache the result
120-
cache_service.save_detection(request.text, response)
125+
if request.cache:
126+
cache_service.save_detection(request.text, response)
121127

122128
return response
123129

tests/unit/app/routers/test_detect.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from unittest.mock import patch
2+
from unittest.mock import patch, MagicMock
33
from fastapi import status
44
from fastapi.testclient import TestClient
55

@@ -83,3 +83,96 @@ def test_detect_invalid_request(client):
8383
json={}, # Missing required field 'text'
8484
)
8585
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
86+
87+
88+
@patch("babeltron.app.models.detection.factory.get_detection_model")
@patch("babeltron.app.routers.detect.detection_model", new_callable=MagicMock)
def test_detect_with_cache_disabled(mock_detection_model, mock_get_model, client):
    """cache=False must force a fresh detection and bypass the cache service entirely."""
    # Stand-in model handed out by the factory (patched for isolation).
    mock_model = MagicMock()
    mock_model.is_loaded = True
    mock_model.architecture = "lingua"
    mock_model.detect.return_value = ("fr", 0.95)
    mock_get_model.return_value = mock_model

    # The pre-loaded module-level model the route actually consults.
    mock_detection_model.is_loaded = True
    mock_detection_model.architecture = "lingua"
    mock_detection_model.detect.return_value = ("fr", 0.95)

    # Request payload with caching explicitly disabled.
    payload = {"text": "Bonjour, comment ça va?", "cache": False}

    with patch("babeltron.app.routers.detect.cache_service") as mock_cache:
        # Seed the cache with a conflicting entry — it must be ignored.
        mock_cache.get_detection.return_value = {
            "language": "en",
            "confidence": 0.98,
            "cached": True,
        }

        response = client.post("/api/v1/detect", json=payload)
        assert response.status_code == status.HTTP_200_OK

        body = response.json()
        # Fresh detection result, not the cached English entry.
        assert body["language"] == "fr"
        assert body["confidence"] == 0.95
        assert body["cached"] is False

        # The model ran exactly once, on the submitted text.
        mock_detection_model.detect.assert_called_once()
        call_args, _ = mock_detection_model.detect.call_args
        assert call_args[0] == "Bonjour, comment ça va?"

        # Neither cache read nor cache write happened.
        mock_cache.get_detection.assert_not_called()
        mock_cache.save_detection.assert_not_called()
134+
135+
136+
@patch("babeltron.app.models.detection.factory.get_detection_model")
@patch("babeltron.app.routers.detect.detection_model", new_callable=MagicMock)
def test_detect_with_cache_enabled(mock_detection_model, mock_get_model, client):
    """With caching left at its default (enabled), a cache hit is returned and the model is never invoked."""
    # Stand-in model handed out by the factory (patched for isolation).
    mock_model = MagicMock()
    mock_model.is_loaded = True
    mock_model.architecture = "lingua"
    mock_model.detect.return_value = ("fr", 0.95)
    mock_get_model.return_value = mock_model

    # The pre-loaded module-level model the route would fall back to.
    mock_detection_model.is_loaded = True
    mock_detection_model.architecture = "lingua"
    mock_detection_model.detect.return_value = ("fr", 0.95)

    # No "cache" key: the field defaults to True.
    payload = {"text": "Bonjour, comment ça va?"}

    with patch("babeltron.app.routers.detect.cache_service") as mock_cache:
        # A cached entry exists — the route should serve it verbatim.
        mock_cache.get_detection.return_value = {
            "language": "en",
            "confidence": 0.98,
            "cached": True,
        }

        response = client.post("/api/v1/detect", json=payload)
        assert response.status_code == status.HTTP_200_OK

        body = response.json()
        # The cached result wins over what the model would return.
        assert body["language"] == "en"
        assert body["confidence"] == 0.98
        assert body["cached"] is True

        # The model was never invoked — the cache satisfied the request.
        mock_detection_model.detect.assert_not_called()

        # One cache read; no write, since nothing new was computed.
        mock_cache.get_detection.assert_called_once()
        mock_cache.save_detection.assert_not_called()  # Should not save since we used cached result

0 commit comments

Comments (0)