1 change: 0 additions & 1 deletion eole/predict/generator.py
@@ -197,7 +197,6 @@ def _score_target(self, batch, enc_out, src_len):
log_probs, attn = self._decode_and_generate(
src,
None,
batch,
src_len=src_len,
)

2 changes: 1 addition & 1 deletion eole/predict/inference.py
@@ -436,7 +436,7 @@ def _process_bucket(bucket_predictions):
batch_data = self.predict_batch(batch, attn_debug)

predictions = prediction_builder.from_batch(batch_data)
is_seq2seq = hasattr(self.model, "encoder") and hasattr(self.model, "decoder")
is_seq2seq = self.model.encoder is not None and self.model.decoder is not None
if (
is_seq2seq
and self._tgt_sep_idx != self._tgt_unk_idx
3 changes: 2 additions & 1 deletion eole/transforms/insert_mask_before_placeholder.py
@@ -38,7 +38,8 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
if response is not None:
_src = "".join([prompt, response])
example["src"] = _src.split(" ")
example["tgt"] = _src.split(" ")
if example["tgt"] is not None:
example["tgt"] = _src.split(" ")
else:
logger.info("The mask_before could not be inserted")
return example
3 changes: 3 additions & 0 deletions eole/transforms/tokenize_id.py
@@ -92,6 +92,9 @@ def tokenize_string(self, string, side="src", is_train=False):
kwargs = {"max_length": self.max_length, "truncation": True}
else:
kwargs = {}
string = string.replace(DefaultTokens.SEP, "\n").replace(
DefaultTokens.MASK_BEFORE, self.tokenizers[side].pad_token
)
Contributor: shouldn't we handle it the same way for other tokenizers?

Contributor Author: I have only tested for that one (with eurollm) at this point; an error will be raised with the others.

tokens = self.tokenizers[side].encode(string, **kwargs)
return tokens

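One possible way to address the question above, sketched as an untested assumption (the fallback handling is hypothetical; DefaultTokens and self.tokenizers are the names already used in this diff): guard the replacement so a tokenizer without a pad token fails with a clear message instead of str.replace receiving None.

pad_token = getattr(self.tokenizers[side], "pad_token", None)
if pad_token is None:
    # Hypothetical guard: some tokenizers define no pad_token, and replacing
    # MASK_BEFORE with None would raise a TypeError inside str.replace.
    raise ValueError(f"Tokenizer for side '{side}' has no pad_token to map MASK_BEFORE to")
string = string.replace(DefaultTokens.SEP, "\n").replace(DefaultTokens.MASK_BEFORE, pad_token)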
5 changes: 2 additions & 3 deletions eole/utils/loss.py
@@ -255,9 +255,8 @@ def ignore_prompt(self, batch):
batch: The current batch.
"""
# Create a mask with zeros at prompt positions and ones at answer positions.
mask = batch["src"].squeeze(dim=-1) == self.padding_idx
mask = torch.cumsum(mask.int(), 1)
# Apply the mask on the target side.
mask = (batch["tgt"].squeeze(dim=-1) == self.padding_idx).cumsum(dim=1)
mask = mask >= mask.max(dim=1, keepdim=True).values
batch["tgt"] *= mask.int()
# Put the padding token index at the prompt positions.
batch["tgt"] += self.padding_idx * (1 - mask.int())
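A minimal toy illustration of what the new target-side masking does (the padding index and token ids below are invented for illustration; the real code operates on batch["tgt"] with a trailing feature dimension that is squeezed out first):

import torch

padding_idx = 3                        # assumed pad/mask token id
tgt = torch.tensor([[5, 7, 3, 9, 2]])  # prompt = [5, 7], mask token, answer = [9, 2]

mask = (tgt == padding_idx).cumsum(dim=1)             # [[0, 0, 1, 1, 1]]
mask = mask >= mask.max(dim=1, keepdim=True).values   # [[False, False, True, True, True]]
tgt = tgt * mask.int() + padding_idx * (1 - mask.int())
# tgt is now [[3, 3, 3, 9, 2]]: prompt positions collapse to padding_idx,
# so they are ignored by the loss, while the answer tokens are kept.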
78 changes: 57 additions & 21 deletions eole/utils/scoring_utils.py
@@ -1,6 +1,6 @@
import codecs
import os
from eole.predict import GNMTGlobalScorer, Translator
from eole.predict import GNMTGlobalScorer, Translator, GeneratorLM
from eole.config.run import (
PredictConfig,
) # probably should be done differently, but might work for now
@@ -42,58 +42,94 @@ def translate(self, model, gpu_rank, step):
preds (list): Detokenized predictions
texts_ref (list): Detokenized target sentences
"""
# ########## #
# Translator #
# ########## #
# ######### #
# Predictor #
# ######### #

# Build translator from options
# Build predictor from options
model_config = self.config.model
model_config._validate_model_config()

# This is somewhat broken and we shall remove or improve
# (take 'inference' field of config if exists?)
# Set "default" translation options on empty cfgfile
self.config.training.num_workers = 0
is_seq2seq = model.encoder is not None and model.decoder is not None
if not is_seq2seq:
if "insert_mask_before_placeholder" in self.config.transforms:
self.response_patterns = self.config.transforms_configs.insert_mask_before_placeholder.response_patterns
else:
self.response_patterns = None

predict_config = PredictConfig(
model_path=["dummy"],
src=self.config.data["valid"].path_src,
src="dummy",
compute_dtype=self.config.training.compute_dtype,
beam_size=1,
transforms=self.config.transforms,
transforms_configs=self.config.transforms_configs,
model=model_config,
tgt_file_prefix=self.config.transforms_configs.prefix.tgt_prefix != "",
gpu_ranks=[gpu_rank],
batch_type=self.config.training.batch_type,
batch_size=self.config.training.batch_size,
)

scorer = GNMTGlobalScorer.from_config(predict_config)
translator = Translator.from_config( # we need to review opt/config stuff in translator
model,
self.vocabs,
predict_config,
model_config,
device_id=gpu_rank,
global_scorer=scorer,
report_align=predict_config.report_align,
report_score=False,
logger=None,
)

if is_seq2seq:
predictor = Translator.from_config( # we need to review opt/config stuff in translator
model,
self.vocabs,
predict_config,
model_config,
device_id=gpu_rank,
global_scorer=scorer,
report_align=predict_config.report_align,
report_score=False,
logger=None,
)
else:
predictor = GeneratorLM.from_config(
model,
self.vocabs,
predict_config,
model_config,
device_id=gpu_rank,
global_scorer=scorer,
report_align=predict_config.report_align,
report_score=False,
logger=None,
)
Comment on lines +80 to +103
Member: Maybe cleaner to just define a predictor_class in the condition, and call predictor_class(*) once, since they should have the same signature.
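A minimal sketch of that suggestion, assuming Translator.from_config and GeneratorLM.from_config do share the signature used in the two branches above:

predictor_class = Translator if is_seq2seq else GeneratorLM
predictor = predictor_class.from_config(
    model,
    self.vocabs,
    predict_config,
    model_config,
    device_id=gpu_rank,
    global_scorer=scorer,
    report_align=predict_config.report_align,
    report_score=False,
    logger=None,
)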


# ################### #
# Validation iterator #
# ################### #

# Reinstantiate the validation iterator
# Retrieve raw references and sources
with codecs.open(self.config.data["valid"].path_tgt, "r", encoding="utf-8") as f:
raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
with codecs.open(self.config.data["valid"].path_src, "r", encoding="utf-8") as f:
raw_srcs = [line.strip("\n") for line in f if line.strip("\n")]

if not is_seq2seq and self.response_patterns is not None:
prompts, answers = [], []
for i, _raw_src in enumerate(raw_srcs):
for _pattern in self.response_patterns:
if len(_raw_src.split(_pattern)) == 2:
prompt, answer = _raw_src.split(_pattern)
prompts.append(prompt + _pattern)
answers.append(answer)
raw_srcs = prompts
raw_refs = answers
else:
with codecs.open(self.config.data["valid"].path_tgt, "r", encoding="utf-8") as f:
raw_refs = [line.strip("\n") for line in f if line.strip("\n")]

infer_iter = build_dynamic_dataset_iter(
predict_config,
self.transforms,
translator.vocabs,
predictor.vocabs,
src=raw_srcs,
task=CorpusTask.INFER,
tgt="", # This force to clear the target side (needed when using tgt_file_prefix)
device_id=gpu_rank,
@@ -102,7 +138,7 @@ def translate(self, model, gpu_rank, step):
# ########### #
# Predictions #
# ########### #
_, _, preds = translator._predict(
_, _, preds = predictor._predict(
infer_iter,
transform=infer_iter.transforms,
attn_debug=predict_config.attn_debug,