77import json
88import logging
99import os
10- from contextlib import contextmanager
1110from copy import deepcopy
1211from time import perf_counter
1312from typing import TypedDict
@@ -33,39 +32,6 @@ class TranslateRequest(TypedDict):
3332 ctranslate2 .set_random_seed (420 )
3433
3534
@contextmanager
def translate_context(config: dict):
    """Yield a ``(tokenizer, translator)`` pair built from *config* and time its use.

    Loads a SentencePiece tokenizer and a CTranslate2 translator from the
    paths in ``config``, yields them to the caller, and logs how long the
    context body took on successful exit.

    Raises:
        ServiceException: if required config keys are missing, the model
            fails to load, or the context body raises.
    """
    try:
        tokenizer = SentencePieceProcessor()
        tokenizer.Load(os.path.join(config["loader"]["model_path"], config["tokenizer_file"]))

        translator = ctranslate2.Translator(
            **{
                # only NVIDIA GPUs are supported by CTranslate2 for now
                "device": "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu",
                # loader config may override "device" (it is spread last on purpose)
                **config["loader"],
            }
        )
    except KeyError as e:
        raise ServiceException(
            "Incorrect config file, ensure all required keys are present from the default config"
        ) from e
    except Exception as e:
        raise ServiceException("Error loading the translation model") from e

    try:
        start = perf_counter()
        yield (tokenizer, translator)
        elapsed = perf_counter() - start
        logger.info(f"time taken: {elapsed:.2f}s")
    except ServiceException:
        # Fix: let domain errors raised inside the context body (e.g. an
        # empty-result check) propagate with their own, more specific
        # message instead of being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        raise ServiceException("Error translating the input text") from e
    finally:
        del tokenizer
        # todo: offload to cpu?
        del translator
68-
6935class Service :
7036 def __init__ (self , config : dict ):
7137 global logger
@@ -94,12 +60,34 @@ def load_config(self, config: dict):
9460
9561 self .config = config_copy
9662
def load_model(self):
    """Initialize ``self.tokenizer`` and ``self.translator`` from ``self.config``.

    Raises:
        ServiceException: when required config keys are absent, or when the
            tokenizer/translator fail to load for any other reason.
    """
    try:
        loader_cfg = self.config["loader"]

        self.tokenizer = SentencePieceProcessor()
        tokenizer_path = os.path.join(loader_cfg["model_path"], self.config["tokenizer_file"])
        self.tokenizer.Load(tokenizer_path)

        # Default device follows COMPUTE_DEVICE (only NVIDIA GPUs are
        # supported by CTranslate2); an explicit "device" key in the
        # loader config takes precedence, same as the original spread.
        translator_kwargs = dict(loader_cfg)
        translator_kwargs.setdefault(
            "device", "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu"
        )
        self.translator = ctranslate2.Translator(**translator_kwargs)
    except KeyError as e:
        raise ServiceException(
            "Incorrect config file, ensure all required keys are present from the default config"
        ) from e
    except Exception as e:
        raise ServiceException("Error loading the translation model") from e
80+
# NOTE(review): this span is a diff hunk pair for Service.translate; the
# boundary below ("@@ -109,7 +97,11") omits new-file lines 94-96 (the close
# of translate_batch(...) and the check that guards the empty-result raise),
# so the method body is incomplete in this view — confirm against the full file.
9781	def translate (self , data : TranslateRequest ) -> str :
9882	logger .debug (f"translating text to: { data ['target_language' ]} " )
9983	
100-	with translate_context (self .config ) as (tokenizer , translator ):
101-	input_tokens = tokenizer .Encode (f"<2{ data ['target_language' ]} > { clean_text (data ['input' ])} " , out_type = str )
102-	results = translator .translate_batch (
84+	try :
85+	start = perf_counter ()
86+	input_tokens = self .tokenizer .Encode (
87+	f"<2{ data ['target_language' ]} > { clean_text (data ['input' ])} " ,
88+	out_type = str ,
89+	)
90+	results = self .translator .translate_batch (
10391	[input_tokens ],
10492	batch_type = "tokens" ,
10593	** self .config ["inference" ],
@@ -109,7 +97,11 @@ def translate(self, data: TranslateRequest) -> str:
10997	raise ServiceException ("Empty result returned from translator" )
11098	
11199	# todo: handle multiple hypotheses
112-	translation = tokenizer .Decode (results [0 ].hypotheses [0 ])
100+	translation = self .tokenizer .Decode (results [0 ].hypotheses [0 ])
101+	elapsed = perf_counter () - start
102+	logger .info (f"time taken: { elapsed :.2f} s" )
# NOTE(review): the broad `except Exception` below appears to catch the
# ServiceException("Empty result returned from translator") raised above and
# re-wrap it as "Error translating the input text", masking the specific
# message — presumably unintended; consider re-raising ServiceException as-is.
103+	except Exception as e :
104+	raise ServiceException ("Error translating the input text" ) from e
113105	
114106	logger .debug (f"Translated string: { translation } " )
115107	return translation
0 commit comments