Skip to content

Commit 79a10be

Browse files
authored
remove verbosity at validation/scoring (#185)
1 parent 0ac626a commit 79a10be

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

eole/utils/scoring_utils.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,27 @@ def translate(self, model, gpu_rank, step):
4646
# Translator #
4747
# ########## #
4848

49+
# Build translator from options
50+
model_config = self.config.model
51+
model_config._validate_model_config()
52+
4953
# This is somewhat broken and we shall remove or improve
5054
# (take 'inference' field of config if exists?)
5155
# Set "default" translation options on empty cfgfile
52-
predict_config = PredictConfig(model_path=["dummy"], src="dummy")
53-
predict_config.compute_dtype = self.config.training.compute_dtype
54-
if predict_config.transforms_configs.prefix.tgt_prefix != "":
55-
predict_config.tgt_file_prefix = True
56-
predict_config.beam_size = 1 # prevent OOM when GPU is almost full at training
57-
predict_config._validate_predict_config()
58-
# Build translator from options
56+
self.config.training.num_workers = 0
57+
predict_config = PredictConfig(
58+
model_path=["dummy"],
59+
src=self.config.data["valid"].path_src,
60+
compute_dtype=self.config.training.compute_dtype,
61+
beam_size=1,
62+
transforms=self.config.transforms,
63+
transforms_configs=self.config.transforms_configs,
64+
model=model_config,
65+
tgt_file_prefix=self.config.transforms_configs.prefix.tgt_prefix != "",
66+
gpu_ranks=[gpu_rank],
67+
)
68+
5969
scorer = GNMTGlobalScorer.from_config(predict_config)
60-
model_config = self.config.model
61-
model_config._validate_model_config()
6270
translator = Translator.from_config( # we need to review opt/config stuff in translator
6371
model,
6472
self.vocabs,
@@ -76,11 +84,6 @@ def translate(self, model, gpu_rank, step):
7684
# ################### #
7785

7886
# Reinstantiate the validation iterator
79-
self.config.training.num_workers = 0
80-
predict_config.src = self.config.data["valid"].path_src
81-
predict_config.transforms = self.config.transforms
82-
predict_config.transforms_configs = self.config.transforms_configs
83-
predict_config.model = model_config
8487
# Retrieve raw references and sources
8588
with codecs.open(self.config.data["valid"].path_tgt, "r", encoding="utf-8") as f:
8689
raw_refs = [line.strip("\n") for line in f if line.strip("\n")]

0 commit comments

Comments
 (0)