Skip to content

Commit f304402

Browse files
committed
Revert "Removing the loaded model cache in MultiSourceSeq2Seq."
This reverts commit 8f6f261.
1 parent d3875c9 commit f304402

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

roosterize/ml/naming/MultiSourceSeq2Seq.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ def __init__(self, model_dir: Path, model_spec: ModelSpec):
237237

238238
# Cache for processing data
239239
self.data_cache: dict = dict()
240+
# Cache for loaded model during translation
241+
self.loaded_model_cache = None
240242
return
241243

242244
def get_input(
@@ -848,7 +850,15 @@ def eval_impl(
848850
# translate.main
849851
ArgumentParser.validate_translate_opts(opt)
850852

851-
translator = MultiSourceTranslator.build_translator(self.config.get_src_types(), opt, report_score=False)
853+
# Cached model loading
854+
if self.loaded_model_cache is None:
855+
self.loaded_model_cache = MultiSourceTranslator.load_model(self.config.get_src_types(), opt)
856+
translator = MultiSourceTranslator.build_translator(
857+
self.config.get_src_types(),
858+
opt,
859+
loaded_model=self.loaded_model_cache,
860+
report_score=False,
861+
)
852862

853863
has_target = True
854864
raw_data_keys = [f"src.{src_type}" for src_type in self.config.get_src_types()] + (["tgt"] if has_target else [])

roosterize/ml/onmt/MultiSourceTranslator.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,27 @@
2020
from seutil import LoggingUtils
2121

2222

23+
class LoadedModel(NamedTuple):
24+
fields: any
25+
model: any
26+
model_opt: any
27+
28+
2329
class MultiSourceTranslator(CustomTranslator):
2430

2531
logger = LoggingUtils.get_logger(__name__)
2632

33+
@classmethod
34+
def load_model(cls, src_types, opt) -> LoadedModel:
35+
fields, model, model_opt = MultiSourceModelBuilder.load_test_model(src_types, opt)
36+
return LoadedModel(fields, model, model_opt)
37+
2738
@classmethod
2839
def build_translator(
2940
cls,
3041
src_types,
3142
opt,
43+
loaded_model: LoadedModel = None,
3244
report_score=True,
3345
logger=None,
3446
out_file=None,
@@ -38,9 +50,10 @@ def build_translator(
3850

3951
assert len(opt.models) == 1, "ensemble model is not supported"
4052

41-
# load_test_model = onmt.decoders.ensemble.load_test_model \
42-
# if len(opt.models) > 1 else onmt.model_builder.load_test_model
43-
fields, model, model_opt = MultiSourceModelBuilder.load_test_model(src_types, opt)
53+
if loaded_model is None:
54+
fields, model, model_opt = MultiSourceModelBuilder.load_test_model(src_types, opt)
55+
else:
56+
fields, model, model_opt = loaded_model
4457

4558
scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)
4659

0 commit comments

Comments
 (0)