@@ -46,19 +46,27 @@ def translate(self, model, gpu_rank, step):
4646 # Translator #
4747 # ########## #
4848
49+ # Build translator from options
50+ model_config = self .config .model
51+ model_config ._validate_model_config ()
52+
4953 # This is somewhat broken and we shall remove or improve
5054 # (take 'inference' field of config if exists?)
5155 # Set "default" translation options on empty cfgfile
52- predict_config = PredictConfig (model_path = ["dummy" ], src = "dummy" )
53- predict_config .compute_dtype = self .config .training .compute_dtype
54- if predict_config .transforms_configs .prefix .tgt_prefix != "" :
55- predict_config .tgt_file_prefix = True
56- predict_config .beam_size = 1 # prevent OOM when GPU is almost full at training
57- predict_config ._validate_predict_config ()
58- # Build translator from options
56+ self .config .training .num_workers = 0
57+ predict_config = PredictConfig (
58+ model_path = ["dummy" ],
59+ src = self .config .data ["valid" ].path_src ,
60+ compute_dtype = self .config .training .compute_dtype ,
61+ beam_size = 1 ,
62+ transforms = self .config .transforms ,
63+ transforms_configs = self .config .transforms_configs ,
64+ model = model_config ,
65+ tgt_file_prefix = self .config .transforms_configs .prefix .tgt_prefix != "" ,
66+ gpu_ranks = [gpu_rank ],
67+ )
68+
5969 scorer = GNMTGlobalScorer .from_config (predict_config )
60- model_config = self .config .model
61- model_config ._validate_model_config ()
6270 translator = Translator .from_config ( # we need to review opt/config stuff in translator
6371 model ,
6472 self .vocabs ,
@@ -76,11 +84,6 @@ def translate(self, model, gpu_rank, step):
7684 # ################### #
7785
7886 # Reinstantiate the validation iterator
79- self .config .training .num_workers = 0
80- predict_config .src = self .config .data ["valid" ].path_src
81- predict_config .transforms = self .config .transforms
82- predict_config .transforms_configs = self .config .transforms_configs
83- predict_config .model = model_config
8487 # Retrieve raw references and sources
8588 with codecs .open (self .config .data ["valid" ].path_tgt , "r" , encoding = "utf-8" ) as f :
8689 raw_refs = [line .strip ("\n " ) for line in f if line .strip ("\n " )]
0 commit comments