diff --git a/eole/predict/generator.py b/eole/predict/generator.py
index 69c83cc78..7fbff455f 100644
--- a/eole/predict/generator.py
+++ b/eole/predict/generator.py
@@ -197,7 +197,6 @@ def _score_target(self, batch, enc_out, src_len):
         log_probs, attn = self._decode_and_generate(
             src,
             None,
-            batch,
             src_len=src_len,
         )
 
diff --git a/eole/predict/inference.py b/eole/predict/inference.py
index b48db1a7a..c3bf37145 100644
--- a/eole/predict/inference.py
+++ b/eole/predict/inference.py
@@ -436,7 +436,7 @@ def _process_bucket(bucket_predictions):
             batch_data = self.predict_batch(batch, attn_debug)
 
             predictions = prediction_builder.from_batch(batch_data)
-            is_seq2seq = hasattr(self.model, "encoder") and hasattr(self.model, "decoder")
+            is_seq2seq = self.model.encoder is not None and self.model.decoder is not None
             if (
                 is_seq2seq
                 and self._tgt_sep_idx != self._tgt_unk_idx
diff --git a/eole/transforms/insert_mask_before_placeholder.py b/eole/transforms/insert_mask_before_placeholder.py
index 9f2c12518..1fe434c99 100755
--- a/eole/transforms/insert_mask_before_placeholder.py
+++ b/eole/transforms/insert_mask_before_placeholder.py
@@ -38,7 +38,8 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
         if response is not None:
             _src = "".join([prompt, response])
             example["src"] = _src.split(" ")
-            example["tgt"] = _src.split(" ")
+            if example["tgt"] is not None:
+                example["tgt"] = _src.split(" ")
         else:
             logger.info("The mask_before could not be inserted")
         return example
diff --git a/eole/transforms/tokenize_id.py b/eole/transforms/tokenize_id.py
index 7af1e0ad8..1542926f4 100644
--- a/eole/transforms/tokenize_id.py
+++ b/eole/transforms/tokenize_id.py
@@ -92,6 +92,9 @@ def tokenize_string(self, string, side="src", is_train=False):
             kwargs = {"max_length": self.max_length, "truncation": True}
         else:
             kwargs = {}
+        string = string.replace(DefaultTokens.SEP, "\n").replace(
+            DefaultTokens.MASK_BEFORE, self.tokenizers[side].pad_token
+        )
         tokens = self.tokenizers[side].encode(string, **kwargs)
         return tokens
 
diff --git a/eole/utils/loss.py b/eole/utils/loss.py
index 59cdabaf7..74ca3513c 100644
--- a/eole/utils/loss.py
+++ b/eole/utils/loss.py
@@ -255,9 +255,8 @@ def ignore_prompt(self, batch):
             batch: The current batch.
         """
         # Create a mask with zeros at prompt positions and ones at answer postions.
-        mask = batch["src"].squeeze(dim=-1) == self.padding_idx
-        mask = torch.cumsum(mask.int(), 1)
-        # Apply the mask on the target side.
+        mask = (batch["tgt"].squeeze(dim=-1) == self.padding_idx).cumsum(dim=1)
+        mask = mask >= mask.max(dim=1, keepdim=True).values
         batch["tgt"] *= mask.int()
         # Put the padding token index at the prompt positions.
         batch["tgt"] += self.padding_idx * (1 - mask.int())
diff --git a/eole/utils/scoring_utils.py b/eole/utils/scoring_utils.py
index b0c63efa0..cee6f0729 100644
--- a/eole/utils/scoring_utils.py
+++ b/eole/utils/scoring_utils.py
@@ -1,6 +1,6 @@
 import codecs
 import os
-from eole.predict import GNMTGlobalScorer, Translator
+from eole.predict import GNMTGlobalScorer, Translator, GeneratorLM
 from eole.config.run import (
     PredictConfig,
 )  # probably should be done differently, but might work for now
@@ -42,11 +42,11 @@ def translate(self, model, gpu_rank, step):
             preds (list): Detokenized predictions
             texts_ref (list): Detokenized target sentences
         """
 
-        # ########## #
-        # Translator #
-        # ########## #
+        # ######### #
+        # Predictor #
+        # ######### #
 
-        # Build translator from options
+        # Build predictor from options
         model_config = self.config.model
         model_config._validate_model_config()
@@ -54,9 +54,16 @@
         # (take 'inference' field of config if exists?)
         # Set "default" translation options on empty cfgfile
         self.config.training.num_workers = 0
+        is_seq2seq = model.encoder is not None and model.decoder is not None
+        if not is_seq2seq:
+            if "insert_mask_before_placeholder" in self.config.transforms:
+                self.response_patterns = self.config.transforms_configs.insert_mask_before_placeholder.response_patterns
+            else:
+                self.response_patterns = None
+
         predict_config = PredictConfig(
             model_path=["dummy"],
-            src=self.config.data["valid"].path_src,
+            src="dummy",
             compute_dtype=self.config.training.compute_dtype,
             beam_size=1,
             transforms=self.config.transforms,
@@ -64,20 +71,36 @@
             model=model_config,
             tgt_file_prefix=self.config.transforms_configs.prefix.tgt_prefix != "",
             gpu_ranks=[gpu_rank],
+            batch_type=self.config.training.batch_type,
+            batch_size=self.config.training.batch_size,
         )
         scorer = GNMTGlobalScorer.from_config(predict_config)
-        translator = Translator.from_config(  # we need to review opt/config stuff in translator
-            model,
-            self.vocabs,
-            predict_config,
-            model_config,
-            device_id=gpu_rank,
-            global_scorer=scorer,
-            report_align=predict_config.report_align,
-            report_score=False,
-            logger=None,
-        )
+
+        if is_seq2seq:
+            predictor = Translator.from_config(  # we need to review opt/config stuff in translator
+                model,
+                self.vocabs,
+                predict_config,
+                model_config,
+                device_id=gpu_rank,
+                global_scorer=scorer,
+                report_align=predict_config.report_align,
+                report_score=False,
+                logger=None,
+            )
+        else:
+            predictor = GeneratorLM.from_config(
+                model,
+                self.vocabs,
+                predict_config,
+                model_config,
+                device_id=gpu_rank,
+                global_scorer=scorer,
+                report_align=predict_config.report_align,
+                report_score=False,
+                logger=None,
+            )
 
         # ################### #
         # Validation iterator #
         # ################### #
@@ -85,15 +108,28 @@
         # Reinstantiate the validation iterator
 
         # Retrieve raw references and sources
-        with codecs.open(self.config.data["valid"].path_tgt, "r", encoding="utf-8") as f:
-            raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
         with codecs.open(self.config.data["valid"].path_src, "r", encoding="utf-8") as f:
             raw_srcs = [line.strip("\n") for line in f if line.strip("\n")]
 
+        if not is_seq2seq and self.response_patterns is not None:
+            prompts, answers = [], []
+            for i, _raw_src in enumerate(raw_srcs):
+                for _pattern in self.response_patterns:
+                    if len(_raw_src.split(_pattern)) == 2:
+                        prompt, answer = _raw_src.split(_pattern)
+                        prompts.append(prompt + _pattern)
+                        answers.append(answer)
+            raw_srcs = prompts
+            raw_refs = answers
+        else:
+            with codecs.open(self.config.data["valid"].path_tgt, "r", encoding="utf-8") as f:
+                raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
+
         infer_iter = build_dynamic_dataset_iter(
             predict_config,
             self.transforms,
-            translator.vocabs,
+            predictor.vocabs,
+            src=raw_srcs,
             task=CorpusTask.INFER,
             tgt="",  # This force to clear the target side (needed when using tgt_file_prefix)
             device_id=gpu_rank,
@@ -102,7 +138,7 @@
         # ########### #
         # Predictions #
         # ########### #
-        _, _, preds = translator._predict(
+        _, _, preds = predictor._predict(
             infer_iter,
             transform=infer_iter.transforms,
             attn_debug=predict_config.attn_debug,
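
Note on the new masking logic in eole/utils/loss.py: the cumsum/max trick in ignore_prompt is terse, so here is a minimal standalone sketch of its behavior (not eole code). It assumes a (batch, seq_len) target tensor in which the transforms above have already placed one padding token right before the answer span, and uses an illustrative padding_idx of 1:

import torch

padding_idx = 1  # illustrative value; eole takes this from the vocab

# Toy target: prompt tokens [5, 6, 7], one pad marker inserted before the
# answer (by insert_mask_before_placeholder + tokenize_id), then answer
# tokens [8, 9]. Shape: (batch=1, seq_len=6).
tgt = torch.tensor([[5, 6, 7, 1, 8, 9]])

# Running count of pad tokens seen so far along the sequence.
mask = (tgt == padding_idx).cumsum(dim=1)             # [[0, 0, 0, 1, 1, 1]]
# True from the last pad token onward, i.e. on the answer span.
mask = mask >= mask.max(dim=1, keepdim=True).values   # [[F, F, F, T, T, T]]

# Zero out the prompt, then write padding_idx at the prompt positions.
tgt = tgt * mask.int() + padding_idx * (1 - mask.int())
print(tgt)  # tensor([[1, 1, 1, 1, 8, 9]])

Since the loss criterion ignores padding_idx, only the answer tokens [8, 9] then contribute to the validation loss.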