1 change: 0 additions & 1 deletion eole/predict/generator.py
@@ -197,7 +197,6 @@ def _score_target(self, batch, enc_out, src_len):
log_probs, attn = self._decode_and_generate(
src,
None,
batch,
src_len=src_len,
)

2 changes: 1 addition & 1 deletion eole/predict/inference.py
@@ -436,7 +436,7 @@ def _process_bucket(bucket_predictions):
batch_data = self.predict_batch(batch, attn_debug)

predictions = prediction_builder.from_batch(batch_data)
is_seq2seq = hasattr(self.model, "encoder") and hasattr(self.model, "decoder")
is_seq2seq = self.model.encoder is not None and self.model.decoder is not None
if (
is_seq2seq
and self._tgt_sep_idx != self._tgt_unk_idx
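Why the explicit `None` comparison: a minimal sketch (not eole code; the class layout is assumed for illustration) of how an attribute-presence test can report a decoder-only model as seq2seq when the attribute exists but is set to `None`:

```python
# Illustrative only: a decoder-only model that still *has* an encoder attribute.
class ToyDecoderOnly:
    def __init__(self):
        self.encoder = None      # attribute exists, but holds no encoder module
        self.decoder = object()  # stand-in for a real decoder

m = ToyDecoderOnly()
print(hasattr(m, "encoder") and hasattr(m, "decoder"))   # True  -> wrongly treated as seq2seq
print(m.encoder is not None and m.decoder is not None)   # False -> correctly not seq2seq
```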
3 changes: 2 additions & 1 deletion eole/transforms/insert_mask_before_placeholder.py
@@ -38,7 +38,8 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
if response is not None:
_src = "".join([prompt, response])
example["src"] = _src.split(" ")
example["tgt"] = _src.split(" ")
if example["tgt"] is not None:
example["tgt"] = _src.split(" ")
else:
logger.info("The mask_before could not be inserted")
return example
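A minimal sketch of the guarded behaviour introduced above, with a made-up example dict: `tgt` is only mirrored from the joined source when a target is actually present (training), and left as `None` otherwise (prediction):

```python
def mirror_src_into_tgt(example, joined):
    # Mimics the new guard: only overwrite "tgt" when the example carries a target.
    example["src"] = joined.split(" ")
    if example["tgt"] is not None:
        example["tgt"] = joined.split(" ")
    return example

print(mirror_src_into_tgt({"src": [], "tgt": []}, "prompt answer"))    # tgt mirrored
print(mirror_src_into_tgt({"src": [], "tgt": None}, "prompt answer"))  # tgt stays None
```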
3 changes: 3 additions & 0 deletions eole/transforms/tokenize_id.py
@@ -92,6 +92,9 @@ def tokenize_string(self, string, side="src", is_train=False):
kwargs = {"max_length": self.max_length, "truncation": True}
else:
kwargs = {}
string = string.replace(DefaultTokens.SEP, "\n").replace(
DefaultTokens.MASK_BEFORE, self.tokenizers[side].pad_token
)
Contributor:
shouldn't we handle it the same way for other tokenizers ?

Contributor (Author):
I have only tested for that one (with eurollm) at this point; an error will be raised with the others.

tokens = self.tokenizers[side].encode(string, **kwargs)
return tokens

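A rough sketch of what the new substitution is meant to achieve before the HuggingFace tokenizer encodes the string. The placeholder spellings below are assumptions, not the actual `DefaultTokens` values; the point is that the eole separator becomes a real newline and the mask-before placeholder becomes the tokenizer's pad token, which the loss can later locate to mask the prompt:

```python
# Assumed placeholder spellings, for illustration only.
SEP = "<sep>"                   # stands in for DefaultTokens.SEP
MASK_BEFORE = "<mask_before>"   # stands in for DefaultTokens.MASK_BEFORE
pad_token = "<pad>"             # stands in for self.tokenizers[side].pad_token

string = "System prompt<sep>User question<mask_before>Expected answer"
string = string.replace(SEP, "\n").replace(MASK_BEFORE, pad_token)
print(string)
# System prompt
# User question<pad>Expected answer
```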
8 changes: 5 additions & 3 deletions eole/utils/loss.py
@@ -255,9 +255,11 @@ def ignore_prompt(self, batch):
batch: The current batch.
"""
# Create a mask with zeros at prompt positions and ones at answer positions.
mask = batch["src"].squeeze(dim=-1) == self.padding_idx
mask = torch.cumsum(mask.int(), 1)
# Apply the mask on the target side.
mask = batch["tgt"].squeeze(dim=-1) == self.padding_idx
mask = mask.cumsum(dim=1)
row_max = mask.max(dim=1, keepdim=True).values
mask = torch.where(mask < row_max, 0, mask)
mask = torch.where(mask >= row_max, 1, mask)
batch["tgt"] *= mask.int()
# Put the padding token index at the prompt positions.
batch["tgt"] += self.padding_idx * (1 - mask.int())
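A toy walk-through of the new masking above, assuming a single-row target where one padding token marks the prompt/answer boundary (the token ids are made up):

```python
import torch

padding_idx = 1
# prompt tokens [5, 6, 7], boundary pad, answer tokens [8, 9]
tgt = torch.tensor([[5, 6, 7, 1, 8, 9]])

mask = tgt == padding_idx                        # [[F, F, F, T, F, F]]
mask = mask.cumsum(dim=1)                        # [[0, 0, 0, 1, 1, 1]]
row_max = mask.max(dim=1, keepdim=True).values   # [[1]] -> count of pads seen per row
mask = torch.where(mask < row_max, 0, mask)      # zero everything before the last pad
mask = torch.where(mask >= row_max, 1, mask)     # clamp the remainder to 1
tgt = tgt * mask.int() + padding_idx * (1 - mask.int())
print(tgt)  # tensor([[1, 1, 1, 1, 8, 9]]) -> prompt replaced by padding_idx, answer kept
```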
70 changes: 55 additions & 15 deletions eole/utils/scoring_utils.py
@@ -1,6 +1,6 @@
import codecs
import os
from eole.predict import GNMTGlobalScorer, Translator
from eole.predict import GNMTGlobalScorer, Translator, GeneratorLM
from eole.config.run import (
PredictConfig,
) # probably should be done differently, but might work for now
@@ -54,9 +54,17 @@ def translate(self, model, gpu_rank, step):
# (take 'inference' field of config if exists?)
# Set "default" translation options on empty cfgfile
self.config.training.num_workers = 0
is_seq2seq = model.encoder is not None and model.decoder is not None
print("is_seq2seq", is_seq2seq)
if not is_seq2seq:
if "insert_mask_before_placeholder" in self.config.transforms:
self.response_patterns = self.config.transforms_configs.insert_mask_before_placeholder.response_patterns
else:
self.response_patterns = None

predict_config = PredictConfig(
model_path=["dummy"],
src=self.config.data["valid"].path_src,
src="dummy",
compute_dtype=self.config.training.compute_dtype,
beam_size=1,
transforms=self.config.transforms,
@@ -67,33 +75,60 @@ def translate(self, model, gpu_rank, step):
)

scorer = GNMTGlobalScorer.from_config(predict_config)
translator = Translator.from_config( # we need to review opt/config stuff in translator
model,
self.vocabs,
predict_config,
model_config,
device_id=gpu_rank,
global_scorer=scorer,
report_align=predict_config.report_align,
report_score=False,
logger=None,
)

if is_seq2seq:
translator = Translator.from_config( # we need to review opt/config stuff in translator
model,
self.vocabs,
predict_config,
model_config,
device_id=gpu_rank,
global_scorer=scorer,
report_align=predict_config.report_align,
report_score=False,
logger=None,
)
else:
translator = GeneratorLM.from_config( # we need to review opt/config stuff in translator
model,
self.vocabs,
predict_config,
model_config,
device_id=gpu_rank,
global_scorer=scorer,
report_align=predict_config.report_align,
report_score=False,
logger=None,
)

# ################### #
# Validation iterator #
# ################### #

# Reinstantiate the validation iterator
# Retrieve raw references and sources
with codecs.open(self.config.data["valid"].path_tgt, "r", encoding="utf-8") as f:
raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
with codecs.open(self.config.data["valid"].path_src, "r", encoding="utf-8") as f:
raw_srcs = [line.strip("\n") for line in f if line.strip("\n")]

if not is_seq2seq and self.response_patterns is not None:
prompts, answers = [], []
for i, _raw_src in enumerate(raw_srcs):
for _pattern in self.response_patterns:
if len(_raw_src.split(_pattern)) == 2:
prompt, answer = _raw_src.split(_pattern)
prompts.append(prompt + _pattern)
answers.append(answer)
raw_srcs = prompts
raw_refs = answers
else:
with codecs.open(self.config.data["valid"].path_tgt, "r", encoding="utf-8") as f:
raw_refs = [line.strip("\n") for line in f if line.strip("\n")]

infer_iter = build_dynamic_dataset_iter(
predict_config,
self.transforms,
translator.vocabs,
src=raw_srcs,
task=CorpusTask.INFER,
tgt="", # This force to clear the target side (needed when using tgt_file_prefix)
device_id=gpu_rank,
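Before the last hunk, a hedged sketch of the prompt/answer split performed above for decoder-only scoring: each raw validation line is cut on a response pattern, the prompt (including the pattern) is fed to the generator, and the remainder becomes the reference. The pattern string and line below are invented for illustration:

```python
# Hypothetical response pattern and validation line.
response_patterns = ["Response: "]
raw_srcs = ["Summarize: The cat sat on the mat. Response: A cat sat on a mat."]

prompts, answers = [], []
for raw_src in raw_srcs:
    for pattern in response_patterns:
        parts = raw_src.split(pattern)
        if len(parts) == 2:
            prompt, answer = parts
            prompts.append(prompt + pattern)  # the prompt keeps the pattern itself
            answers.append(answer)

print(prompts)  # ['Summarize: The cat sat on the mat. Response: ']
print(answers)  # ['A cat sat on a mat.']
```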
@@ -102,6 +137,11 @@ def translate(self, model, gpu_rank, step):
# ########### #
# Predictions #
# ########### #
if not is_seq2seq:
translator.id_tokenization = True
# In PredictionBuilder, the case "id_tokenization = False" (default) is not properly handled
# and the apply_reverse method of the huggingface_tokenize transform
# does not handle lists of strings (only lists of integers).
_, _, preds = translator._predict(
infer_iter,
transform=infer_iter.transforms,