Skip to content

Commit 14d83da

Browse files
authored
feat(api): cache on translate API (#2)
1 parent 0715145 commit 14d83da

File tree

14 files changed

+625
-83
lines changed

14 files changed

+625
-83
lines changed

.coveragerc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[run]
22
source = babeltron
3-
omit =
3+
omit =
44
*/tests/*
55
*/test/*
66
*/venv/*
@@ -17,4 +17,4 @@ exclude_lines =
1717
raise ImportError
1818

1919
[html]
20-
directory = htmlcov
20+
directory = htmlcov

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ jobs:
5252
name: dist
5353
path: dist/
5454

55-
- name: Check code quality
55+
- name: Run linters
5656
run: make lint

Makefile

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.PHONY: check-poetry install test lint format help system-deps coverage coverage-html download-model download-model-small download-model-medium download-model-large serve serve-prod docker-build docker-run docker-compose-up docker-compose-down
1+
.PHONY: check-poetry install test lint format help system-deps coverage coverage-html download-model download-model-small download-model-medium download-model-large serve serve-prod docker-build docker-run docker-compose-up docker-compose-down pre-commit-install pre-commit-run
22

33
# Define model path variable with default value, can be overridden by environment
44
MODEL_PATH ?= ./models
@@ -131,3 +131,10 @@ docker-up: ## Build and start services with docker-compose
131131
docker-down:
132132
@echo "Stopping docker-compose services..."
133133
@docker-compose down
134+
135+
pre-commit-install:
136+
pip install pre-commit
137+
pre-commit install
138+
139+
pre-commit-run:
140+
pre-commit run --all-files

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ The Docker setup mounts the local `./models` directory to `/models` inside the c
169169

170170
If no models are found when starting the container, you'll be prompted to download the small model automatically.
171171

172+
## Contributing
173+
174+
Install pre-commit hooks with `make pre-commit-install` and refer to the [CONTRIBUTING.md](docs/CONTRIBUTING.md) file for more information.
175+
172176
## License
173177

174178
MIT License

babeltron/app/main.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
import os
2+
from contextlib import asynccontextmanager
13
from importlib.metadata import version
4+
from typing import AsyncIterator
25

36
from fastapi import FastAPI
47
from fastapi.middleware.cors import CORSMiddleware
8+
from fastapi_cache import FastAPICache
9+
from fastapi_cache.backends.inmemory import InMemoryBackend
10+
from fastapi_cache.backends.redis import RedisBackend
11+
from redis import asyncio as aioredis
512

613
from babeltron.app.utils import include_routers
714

@@ -10,6 +17,24 @@
1017
except ImportError:
1118
__version__ = "0.1.0-dev"
1219

20+
21+
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Initialize the FastAPICache backend from CACHE_URL before serving.

    Supported values: "in-memory" (process-local cache) or a "redis://..."
    URL. Any other value -- including an unset variable -- disables caching.
    """
    # Default to "" so a missing CACHE_URL cannot raise AttributeError on
    # the .startswith() checks below (original bug: os.environ.get returns
    # None when the variable is unset, crashing app startup).
    cache_url = os.environ.get("CACHE_URL", "")

    if cache_url.startswith("in-memory"):
        FastAPICache.init(InMemoryBackend(), prefix="babeltron")
        print("Using in-memory cache")
    elif cache_url.startswith("redis"):
        redis = aioredis.from_url(cache_url)
        FastAPICache.init(RedisBackend(redis), prefix="babeltron")
        print("Using Redis cache")
    else:
        print("No cache_url provided, not using cache")

    yield
36+
37+
1338
app = FastAPI(
1439
title="Babeltron Translation API",
1540
description="API for machine translation using NLLB models",
@@ -26,6 +51,7 @@
2651
docs_url="/docs",
2752
redoc_url="/redoc",
2853
openapi_url="/openapi.json",
54+
lifespan=lifespan,
2955
)
3056

3157
# Configure CORS
@@ -40,7 +66,6 @@
4066
# Include all routers
4167
include_routers(app)
4268

43-
# This allows running the app directly with uvicorn when this file is executed
4469
if __name__ == "__main__":
4570
import uvicorn
4671

babeltron/app/routers/translate.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
import os
2+
3+
import torch
14
from fastapi import APIRouter, HTTPException, status
5+
from fastapi_cache.decorator import cache
26
from pydantic import BaseModel, Field
37
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
4-
import os
5-
import torch
68

7-
from babeltron.app.utils import get_model_path
9+
from babeltron.app.utils import ORJsonCoder, cache_key_builder, get_model_path
810

911
router = APIRouter(tags=["Translation"])
1012

11-
MODEL_COMPRESSION_ENABLED = os.environ.get("MODEL_COMPRESSION_ENABLED", "true").lower() in ("true", "1", "yes")
13+
MODEL_COMPRESSION_ENABLED = os.environ.get(
14+
"MODEL_COMPRESSION_ENABLED", "true"
15+
).lower() in ("true", "1", "yes")
16+
CACHE_TTL_SECONDS = int(os.environ.get("CACHE_TTL_SECONDS", "3600"))
1217

1318
try:
1419
MODEL_PATH = get_model_path()
@@ -19,7 +24,7 @@
1924
if MODEL_COMPRESSION_ENABLED and torch.cuda.is_available():
2025
print("Applying FP16 model compression")
2126
model = model.half() # Convert to FP16 precision
22-
model = model.to('cuda') # Move to GPU
27+
model = model.to("cuda") # Move to GPU
2328
elif MODEL_COMPRESSION_ENABLED:
2429
print("FP16 compression enabled but GPU not available, using CPU")
2530
else:
@@ -71,6 +76,7 @@ class TranslationResponse(BaseModel):
7176
response_description="The translated text in the target language",
7277
status_code=status.HTTP_200_OK,
7378
)
79+
@cache(expire=CACHE_TTL_SECONDS, key_builder=cache_key_builder, coder=ORJsonCoder)
7480
async def translate(request: TranslationRequest):
7581
if model is None or tokenizer is None:
7682
raise HTTPException(
@@ -84,17 +90,19 @@ async def translate(request: TranslationRequest):
8490

8591
# Move input to GPU if model is on GPU
8692
if torch.cuda.is_available() and next(model.parameters()).is_cuda:
87-
encoded_text = {k: v.to('cuda') for k, v in encoded_text.items()}
93+
encoded_text = {k: v.to("cuda") for k, v in encoded_text.items()}
8894

8995
generated_tokens = model.generate(
9096
**encoded_text, forced_bos_token_id=tokenizer.get_lang_id(request.tgt_lang)
9197
)
92-
translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
98+
translation = tokenizer.batch_decode(
99+
generated_tokens, skip_special_tokens=True
100+
)[0]
93101
return {"translation": translation}
94102
except Exception as e:
95103
raise HTTPException(
96104
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
97-
detail=f"Error during translation: {str(e)}"
105+
detail=f"Error during translation: {str(e)}",
98106
)
99107

100108

babeltron/app/utils.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
import hashlib
12
import importlib
3+
import json
24
import os
35
import pkgutil
46
from pathlib import Path
7+
from typing import Any
58

6-
from fastapi import FastAPI
9+
import orjson
10+
from fastapi import FastAPI, Request, Response
11+
from fastapi.encoders import jsonable_encoder
12+
from fastapi_cache import Coder
713

814

915
def get_model_path() -> str:
@@ -36,3 +42,51 @@ def include_routers(app: FastAPI):
3642
module = importlib.import_module(f"{routers_package}.{module_name}")
3743
if hasattr(module, "router"):
3844
app.include_router(module.router)
45+
46+
47+
def cache_key_builder(
    func,
    namespace: str = "",
    request: "Request" = None,
    response: "Response" = None,
    *args,
    **kwargs,
) -> str:
    """Build a cache key of the form ``namespace:src_lang:tgt_lang:md5(text)``.

    Returns "" when no request is available, so keyless calls share a single
    cache slot. A body that cannot be parsed as JSON yields empty components.

    NOTE(review): fastapi-cache does not populate ``request.state.body`` by
    itself -- confirm middleware stores the raw body there, otherwise every
    request collapses onto the same key.
    """
    if request is None:
        return ""

    body_data = {}
    if hasattr(request, "state") and hasattr(request.state, "body"):
        try:
            body_data = json.loads(request.state.body)
        except (json.JSONDecodeError, AttributeError):
            pass

    src_lang = body_data.get("src_lang", "")
    # Bug fix: the translate request model uses "tgt_lang" (see
    # request.tgt_lang in routers/translate.py); the original read
    # "dst_lang", which never appears in the body, so every target
    # language mapped to the same cache entry. Keep "dst_lang" as a
    # fallback for backward compatibility.
    tgt_lang = body_data.get("tgt_lang", body_data.get("dst_lang", ""))
    text = body_data.get("text", "")

    # Hash the text so arbitrarily long payloads produce fixed-size keys.
    text_md5 = hashlib.md5(text.encode()).hexdigest() if text else ""

    return ":".join([namespace, src_lang, tgt_lang, text_md5])
79+
80+
81+
class ORJsonCoder(Coder):
    """fastapi-cache coder that (de)serializes cached values with orjson.

    Encoding falls back to FastAPI's jsonable_encoder for objects orjson
    cannot serialize natively (e.g. Pydantic models).
    """

    @classmethod
    def encode(cls, value: Any) -> bytes:
        # OPT_NON_STR_KEYS: allow non-string dict keys;
        # OPT_SERIALIZE_NUMPY: serialize numpy arrays/scalars directly.
        options = orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
        return orjson.dumps(value, default=jsonable_encoder, option=options)

    @classmethod
    def decode(cls, value: bytes) -> Any:
        return orjson.loads(value)

docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ services:
1111
- ./models:/models
1212
environment:
1313
- MODEL_PATH=/models
14+
- CACHE_URL=in-memory
1415
restart: unless-stopped
1516
healthcheck:
1617
test: ["CMD", "curl", "-f", "http://localhost:8000/healthz"]

0 commit comments

Comments
 (0)