1+ import os
2+
3+ import torch
14from fastapi import APIRouter , HTTPException , status
5+ from fastapi_cache .decorator import cache
26from pydantic import BaseModel , Field
37from transformers import M2M100ForConditionalGeneration , M2M100Tokenizer
4- import os
5- import torch
68
7- from babeltron .app .utils import get_model_path
9+ from babeltron .app .utils import ORJsonCoder , cache_key_builder , get_model_path
810
# All endpoints in this module are grouped under the "Translation" tag
# in the generated OpenAPI documentation.
router = APIRouter(tags=["Translation"])
# FP16 model compression is opt-out: enabled unless the env var is set to a
# non-truthy value. Accepted truthy spellings: "true", "1", "yes" (any case).
MODEL_COMPRESSION_ENABLED = (
    os.environ.get("MODEL_COMPRESSION_ENABLED", "true").lower()
    in ("true", "1", "yes")
)

# Time-to-live for cached translation responses, in seconds (default: 1 hour).
CACHE_TTL_SECONDS = int(os.environ.get("CACHE_TTL_SECONDS", "3600"))
1217
1318try :
1419 MODEL_PATH = get_model_path ()
1924 if MODEL_COMPRESSION_ENABLED and torch .cuda .is_available ():
2025 print ("Applying FP16 model compression" )
2126 model = model .half () # Convert to FP16 precision
22- model = model .to (' cuda' ) # Move to GPU
27+ model = model .to (" cuda" ) # Move to GPU
2328 elif MODEL_COMPRESSION_ENABLED :
2429 print ("FP16 compression enabled but GPU not available, using CPU" )
2530 else :
@@ -71,6 +76,7 @@ class TranslationResponse(BaseModel):
7176 response_description = "The translated text in the target language" ,
7277 status_code = status .HTTP_200_OK ,
7378)
79+ @cache (expire = CACHE_TTL_SECONDS , key_builder = cache_key_builder , coder = ORJsonCoder )
7480async def translate (request : TranslationRequest ):
7581 if model is None or tokenizer is None :
7682 raise HTTPException (
@@ -84,17 +90,19 @@ async def translate(request: TranslationRequest):
8490
8591 # Move input to GPU if model is on GPU
8692 if torch .cuda .is_available () and next (model .parameters ()).is_cuda :
87- encoded_text = {k : v .to (' cuda' ) for k , v in encoded_text .items ()}
93+ encoded_text = {k : v .to (" cuda" ) for k , v in encoded_text .items ()}
8894
8995 generated_tokens = model .generate (
9096 ** encoded_text , forced_bos_token_id = tokenizer .get_lang_id (request .tgt_lang )
9197 )
92- translation = tokenizer .batch_decode (generated_tokens , skip_special_tokens = True )[0 ]
98+ translation = tokenizer .batch_decode (
99+ generated_tokens , skip_special_tokens = True
100+ )[0 ]
93101 return {"translation" : translation }
94102 except Exception as e :
95103 raise HTTPException (
96104 status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
97- detail = f"Error during translation: { str (e )} "
105+ detail = f"Error during translation: { str (e )} " ,
98106 )
99107
100108
0 commit comments