Commit 55240c1

feat: load the model once and keep it loaded (#82)
Signed-off-by: Anupam Kumar <kyteinsky@gmail.com>
1 parent 1448627 commit 55240c1

File tree

2 files changed: +31 -38 lines changed

lib/Service.py
lib/main.py


lib/Service.py

Lines changed: 30 additions & 38 deletions
@@ -7,7 +7,6 @@
 import json
 import logging
 import os
-from contextlib import contextmanager
 from copy import deepcopy
 from time import perf_counter
 from typing import TypedDict
@@ -33,39 +32,6 @@ class TranslateRequest(TypedDict):
 ctranslate2.set_random_seed(420)
 
 
-@contextmanager
-def translate_context(config: dict):
-    try:
-        tokenizer = SentencePieceProcessor()
-        tokenizer.Load(os.path.join(config["loader"]["model_path"], config["tokenizer_file"]))
-
-        translator = ctranslate2.Translator(
-            **{
-                # only NVIDIA GPUs are supported by CTranslate2 for now
-                "device": "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu",
-                **config["loader"],
-            }
-        )
-    except KeyError as e:
-        raise ServiceException(
-            "Incorrect config file, ensure all required keys are present from the default config"
-        ) from e
-    except Exception as e:
-        raise ServiceException("Error loading the translation model") from e
-
-    try:
-        start = perf_counter()
-        yield (tokenizer, translator)
-        elapsed = perf_counter() - start
-        logger.info(f"time taken: {elapsed:.2f}s")
-    except Exception as e:
-        raise ServiceException("Error translating the input text") from e
-    finally:
-        del tokenizer
-        # todo: offload to cpu?
-        del translator
-
-
 class Service:
     def __init__(self, config: dict):
         global logger
@@ -94,12 +60,34 @@ def load_config(self, config: dict):
 
         self.config = config_copy
 
+    def load_model(self):
+        try:
+            self.tokenizer = SentencePieceProcessor()
+            self.tokenizer.Load(os.path.join(self.config["loader"]["model_path"], self.config["tokenizer_file"]))
+
+            self.translator = ctranslate2.Translator(
+                **{
+                    "device": "cuda" if os.getenv("COMPUTE_DEVICE") == "CUDA" else "cpu",
+                    **self.config["loader"],
+                }
+            )
+        except KeyError as e:
+            raise ServiceException(
+                "Incorrect config file, ensure all required keys are present from the default config"
+            ) from e
+        except Exception as e:
+            raise ServiceException("Error loading the translation model") from e
+
     def translate(self, data: TranslateRequest) -> str:
         logger.debug(f"translating text to: {data['target_language']}")
 
-        with translate_context(self.config) as (tokenizer, translator):
-            input_tokens = tokenizer.Encode(f"<2{data['target_language']}> {clean_text(data['input'])}", out_type=str)
-            results = translator.translate_batch(
+        try:
+            start = perf_counter()
+            input_tokens = self.tokenizer.Encode(
+                f"<2{data['target_language']}> {clean_text(data['input'])}",
+                out_type=str,
+            )
+            results = self.translator.translate_batch(
                 [input_tokens],
                 batch_type="tokens",
                 **self.config["inference"],
@@ -109,7 +97,11 @@ def translate(self, data: TranslateRequest) -> str:
                 raise ServiceException("Empty result returned from translator")
 
             # todo: handle multiple hypotheses
-            translation = tokenizer.Decode(results[0].hypotheses[0])
+            translation = self.tokenizer.Decode(results[0].hypotheses[0])
+            elapsed = perf_counter() - start
+            logger.info(f"time taken: {elapsed:.2f}s")
+        except Exception as e:
+            raise ServiceException("Error translating the input text") from e
 
         logger.debug(f"Translated string: {translation}")
         return translation
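
For context, a minimal usage sketch of the refactored Service (not part of this commit): the tokenizer and translator are now loaded once via load_model() and reused by every subsequent translate() call instead of being rebuilt per request. The import path, config values, file names and language codes below are hypothetical placeholders, not taken from this repository.

# Minimal sketch, assuming a config dict shaped like the keys Service.py reads;
# all concrete values here are made up for illustration.
from Service import Service  # assumed import path for lib/Service.py

config = {
    "tokenizer_file": "spm.model",                 # hypothetical tokenizer file name
    "loader": {"model_path": "models/translate"},  # hypothetical path, unpacked into ctranslate2.Translator
    "inference": {"max_batch_size": 8192},         # hypothetical inference option for translate_batch
}

service = Service(config)
service.load_model()  # one-time load, as lib/main.py now does before entering its task loop

# later calls reuse the already-loaded tokenizer/translator
print(service.translate({"target_language": "de", "input": "Hello world"}))
print(service.translate({"target_language": "fr", "input": "Good morning"}))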

lib/main.py

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@ async def _(request: Request, exc: Exception):
 def task_fetch_thread(service: Service):
     global app_enabled
 
+    service.load_model()
     nc = NextcloudApp()
     while True:
         if not app_enabled.is_set():
