diff --git a/eole/bin/tools/LM_scoring.py b/eole/bin/tools/LM_scoring.py
index 37dafe469..5eb2a9033 100644
--- a/eole/bin/tools/LM_scoring.py
+++ b/eole/bin/tools/LM_scoring.py
@@ -8,12 +8,12 @@
 from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from eole.utils.loss import LossCompute
 from eole.constants import DefaultTokens, CorpusTask
-from eole.transforms import get_transforms_cls
+from eole.transforms import get_transforms_cls, make_transforms
+from eole.models.model import BaseModel
 from argparse import ArgumentParser
 from eole.bin import BaseBin, register_bin
 from eole.config.cli import add_model
-from eole.config import get_non_default_values
 from eole.config.run import PredictConfig
 
 """
@@ -21,18 +21,13 @@
 For this purpose we use the same pipeline as the validation of a file
 Below is an example of settings of a config.yaml file
 
-model: lm-de.news2021_step_100000.pt
-src: newstest2014-ref.de
-tgt: newstest2014-ref.de
-transforms: [onmt_tokenize]
-batch_size: 16
-gpu: 0
-src_subword_type: bpe
-src_subword_model: subwords.en_de.bpe
-src_eoletok_kwargs: '{"mode": "aggressive"}'
-tgt_subword_type: bpe
-tgt_subword_model: subwords.en_de.bpe
-tgt_eoletok_kwargs: '{"mode": "aggressive"}'
+verbose: false
+world_size: 1
+gpu_ranks: [0]
+# use symlinks to last saved step
+model_path: data/wikitext/wikitext-103-raw-v1/run/model-lm
+src: data/wikitext/wikitext-103-raw-v1/lm_input.txt
+output: data/wikitext/wikitext-103-raw-v1/lm_pred.txt
 
 Output is the data and tab separated score
 use the -output setting for preds + scores
@@ -44,13 +39,7 @@ class LMScoring(BaseBin):
 
     @classmethod
     def add_args(cls, parser):
-        parser.add_argument(
-            "-config",
-            "--config",
-            "-c",
-            required=False,
-            help="Path of main YAML config file.",
-        )
+        parser.add_argument("-config", "--config", "-c", required=False, help="Path of main YAML config file.")
 
     @classmethod
     def run(cls, args):
@@ -62,9 +51,6 @@ def run(cls, args):
             config = {}
         _parser = ArgumentParser()
         add_model(_parser, PredictConfig)
-        defaults = vars(_parser.parse_args([]))
-        stuff_to_update = get_non_default_values(args, defaults)
-        config.update(stuff_to_update)
         config = PredictConfig(**config)
         init_logger(config.log_file)
         set_random_seed(config.seed, False)
@@ -75,26 +61,27 @@ def run(cls, args):
         if len(config.gpu_ranks) > 1:
             logger.warning(f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used.")
 
-        vocabs, model, model_opt = config.model.model_class.load_test_model(config)
+        vocabs, model, model_opt = BaseModel.load_test_model(config, device.index)
         pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD)
-        padding_idx = vocabs["tgt"][pad_token]
+        padding_idx = vocabs["tgt"].tokens_to_ids[pad_token]
         criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="none")
         valid_loss = LossCompute(
             criterion,
             model.generator,
             tgt_shift_index=0,
-            lambda_coverage=model_opt.lambda_coverage,
-            lambda_align=model_opt.lambda_align,
+            lambda_coverage=model_opt.decoder.lambda_coverage,
+            lambda_align=model_opt.decoder.lambda_align,
+            vocabs=vocabs,
         )
         valid_loss.to(device)
 
         transforms_cls = get_transforms_cls(config._all_transform)
+        transforms_cls = make_transforms(config, transforms_cls, vocabs)
+
+        # if tgt is not set in the inference config file, use src
+        if config.tgt is None:
+            config.tgt = config.src
         infer_iter = build_dynamic_dataset_iter(
-            args,
-            transforms_cls,
-            vocabs,
-            task=CorpusTask.INFER,
-            device_id=config.gpu,
+            config, transforms_cls, vocabs, task=CorpusTask.INFER, device_id=device.index
         )
         model.to(device)
@@ -110,30 +97,28 @@ def run(cls, args):
             src = batch["src"]
             src_len = batch["srclen"]
             # print(batch)
-            outputs, attns = model(src, None, src_len, with_align=False)
+            outputs, attns, _ = model(src, None, src_len, with_align=False)
             # Compute and retrieve the loss for EACH sentence
-            loss, _ = valid_loss(batch, outputs, attns)
+            loss, _, _ = valid_loss(batch, outputs, attns)
             loss = loss.view(batch_size, -1)  # (B, T)
-            losspertoken = loss.sum(1) / batch["tgt"][:, 1:, 0].ne(padding_idx).sum(1)
+            losspertoken = loss.sum(1) / batch["tgt"][:, 1:].ne(padding_idx).sum(1)
             ppl = torch.exp(losspertoken)
             cumul_loss += loss.sum().item()
-            cumul_length += batch["tgt"][:, 1:, 0].ne(padding_idx).sum().cpu()
+            cumul_length += batch["tgt"][:, 1:].ne(padding_idx).sum().cpu()
             # Now we need to rearrange the batch of ppl
             # in the original order with indices
             sent_ppl_orig = ppl.gather(
                 0,
                 torch.tensor(
-                    sorted(
-                        range(len(batch["cid_line_number"])),
-                        key=lambda k: batch["cid_line_number"][k],
-                    ),
+                    sorted(range(len(batch["cid_line_number"])), key=lambda k: batch["cid_line_number"][k]),
                     device=ppl.device,
                 ),
             )
             for j in range(batch_size):
                 ppl_file.write(str(sent_ppl_orig[j].item()) + "\n")
 
         logger.info(
-            "Loss: %.2f Tokens: %d Corpus PPL: %.2f" % (cumul_loss, cumul_length, np.exp(cumul_loss / cumul_length))
+            "Loss: %.2f Tokens: %d Corpus PPL: %.2f"
+            % (cumul_loss / cumul_length.item(), cumul_length, np.exp(cumul_loss / cumul_length))
        )
        ppl_file.close()
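Note on the loss bookkeeping in the last hunk: the logging line previously printed the raw summed loss next to the corpus PPL; it now prints the per-token loss (cumul_loss / cumul_length), which is exactly the quantity the corpus PPL is the exponential of. Below is a minimal self-contained sketch of the same arithmetic, with toy tensors and a made-up padding index standing in for the real batch; it relies on the fact that CrossEntropyLoss with reduction="none" and ignore_index yields 0 at padding positions, so summing over the full time axis and dividing by the non-pad count gives the mean NLL per sentence:

    import math
    import torch

    # Toy stand-ins for one batch; 1 plays the padding index here. With
    # ignore_index=1 and reduction="none", pad positions contribute 0 loss.
    loss = torch.tensor([[2.0, 1.0, 0.0], [1.5, 0.5, 0.0]])  # (B=2, T=3) unreduced NLL
    tgt = torch.tensor([[4, 7, 1], [5, 9, 1]])               # shifted targets, 1 = pad

    n_tokens = tgt.ne(1).sum(1)              # non-pad tokens per sentence -> [2, 2]
    loss_per_token = loss.sum(1) / n_tokens  # mean NLL per sentence -> [1.5, 1.0]
    sent_ppl = torch.exp(loss_per_token)     # per-sentence perplexity

    cumul_loss = loss.sum().item()           # corpus-level accumulators
    cumul_length = tgt.ne(1).sum().item()
    corpus_ppl = math.exp(cumul_loss / cumul_length)  # exp(5.0 / 4) ~ 3.49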
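The ppl.gather(...) call exists because the dynamic iterator returns examples in a resorted order (for batching efficiency) while batch["cid_line_number"] records each example's original position in the input file; sorting indices by that list and gathering puts the scores back in file order before they are written out. A minimal sketch with made-up values:

    import torch

    # Perplexities for a batch the iterator resorted; these three examples
    # originally sat at lines 2, 0 and 1 of the input file (toy values).
    ppl = torch.tensor([12.3, 4.5, 8.9])
    cid_line_number = [2, 0, 1]

    # Permutation that visits the batch in original-file order.
    order = sorted(range(len(cid_line_number)), key=lambda k: cid_line_number[k])
    sent_ppl_orig = ppl.gather(0, torch.tensor(order, device=ppl.device))
    # order == [1, 2, 0]; sent_ppl_orig == tensor([4.5000, 8.9000, 12.3000])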