Skip to content

Commit 8f6f261

Browse files
committed
Removing the loaded model cache in MultiSourceSeq2Seq.
1 parent 8a80da6 commit 8f6f261

File tree

2 files changed

+4
-27
lines changed

2 files changed

+4
-27
lines changed

roosterize/ml/naming/MultiSourceSeq2Seq.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,6 @@ def __init__(self, model_dir: Path, model_spec: ModelSpec):
237237

238238
# Cache for processing data
239239
self.data_cache: dict = dict()
240-
# Cache for loaded model during translation
241-
self.loaded_model_cache = None
242240
return
243241

244242
def get_input(
@@ -848,15 +846,7 @@ def eval_impl(
848846
# translate.main
849847
ArgumentParser.validate_translate_opts(opt)
850848

851-
# Cached model loading
852-
if self.loaded_model_cache is None:
853-
self.loaded_model_cache = MultiSourceTranslator.load_model(self.config.get_src_types(), opt)
854-
translator = MultiSourceTranslator.build_translator(
855-
self.config.get_src_types(),
856-
opt,
857-
loaded_model=self.loaded_model_cache,
858-
report_score=False,
859-
)
849+
translator = MultiSourceTranslator.build_translator(self.config.get_src_types(), opt, report_score=False)
860850

861851
has_target = True
862852
raw_data_keys = [f"src.{src_type}" for src_type in self.config.get_src_types()] + (["tgt"] if has_target else [])

roosterize/ml/onmt/MultiSourceTranslator.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,15 @@
2020
from seutil import LoggingUtils
2121

2222

23-
class LoadedModel(NamedTuple):
24-
fields: any
25-
model: any
26-
model_opt: any
27-
28-
2923
class MultiSourceTranslator(CustomTranslator):
3024

3125
logger = LoggingUtils.get_logger(__name__)
3226

33-
@classmethod
34-
def load_model(cls, src_types, opt) -> LoadedModel:
35-
fields, model, model_opt = MultiSourceModelBuilder.load_test_model(src_types, opt)
36-
return LoadedModel(fields, model, model_opt)
37-
3827
@classmethod
3928
def build_translator(
4029
cls,
4130
src_types,
4231
opt,
43-
loaded_model: LoadedModel = None,
4432
report_score=True,
4533
logger=None,
4634
out_file=None,
@@ -50,10 +38,9 @@ def build_translator(
5038

5139
assert len(opt.models) == 1, "ensemble model is not supported"
5240

53-
if loaded_model is None:
54-
fields, model, model_opt = MultiSourceModelBuilder.load_test_model(src_types, opt)
55-
else:
56-
fields, model, model_opt = loaded_model
41+
# load_test_model = onmt.decoders.ensemble.load_test_model \
42+
# if len(opt.models) > 1 else onmt.model_builder.load_test_model
43+
fields, model, model_opt = MultiSourceModelBuilder.load_test_model(src_types, opt)
5744

5845
scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)
5946

0 commit comments

Comments
 (0)