diff --git a/onmt/bin/preprocess.py b/onmt/bin/preprocess.py
index 1d48a20109..699d6c456d 100755
--- a/onmt/bin/preprocess.py
+++ b/onmt/bin/preprocess.py
@@ -247,10 +247,9 @@ def preprocess(opt):
     src_nfeats = 0
     tgt_nfeats = 0
-    for src, tgt in zip(opt.train_src, opt.train_tgt):
-        src_nfeats += count_features(src) if opt.data_type == 'text' \
-            else 0
-        tgt_nfeats += count_features(tgt)  # tgt always text so far
+    src_nfeats = count_features(opt.train_src[0]) if opt.data_type == 'text' \
+        else 0
+    tgt_nfeats = count_features(opt.train_tgt[0])  # tgt always text so far
     logger.info(" * number of source features: %d." % src_nfeats)
     logger.info(" * number of target features: %d." % tgt_nfeats)
diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py
index d4c43342c0..ab61fce38c 100644
--- a/onmt/inputters/inputter.py
+++ b/onmt/inputters/inputter.py
@@ -582,20 +582,44 @@ def _pool(data, batch_size, batch_size_fn, batch_size_multiple,
             yield b


-class OrderedIterator(torchtext.data.Iterator):
+class OnmtBatch(torchtext.data.Batch):
+    def __init__(self, data=None, dataset=None,
+                 device=None, feat_no_time_shift=False):
+        super(OnmtBatch, self).__init__(data, dataset, device)
+        # we need to shift target features if needed
+        if not feat_no_time_shift:
+            if hasattr(self, 'tgt') and self.tgt.size(-1) > 1:
+                # tokens: [len x batch x 1]
+                tokens = self.tgt[:, :, 0].unsqueeze(-1)
+                # feats: [len x batch x num_feats]
+                feats = self.tgt[:, :, 1:]
+                # shift feats one step to the right
+                feats = torch.cat((
+                    feats[-1, :, :].unsqueeze(0),
+                    feats[:-1, :, :]
+                ))
+                # build back target tensor
+                self.tgt = torch.cat((
+                    tokens,
+                    feats
+                ), dim=-1)
+
+
+class OrderedIterator(torchtext.data.Iterator):
     def __init__(self,
                  dataset,
                  batch_size,
                  pool_factor=1,
                  batch_size_multiple=1,
                  yield_raw_example=False,
+                 feat_no_time_shift=False,
                  **kwargs):
         super(OrderedIterator, self).__init__(dataset, batch_size, **kwargs)
         self.batch_size_multiple = batch_size_multiple
         self.yield_raw_example = yield_raw_example
         self.dataset = dataset
         self.pool_factor = pool_factor
+        self.feat_no_time_shift = feat_no_time_shift

     def create_batches(self):
         if self.train:
@@ -627,7 +651,7 @@ def __iter__(self):
         """
         Extended version of the definition in torchtext.data.Iterator.
         Added yield_raw_example behaviour to yield a torchtext.data.Example
-        instead of a torchtext.data.Batch object.
+        instead of an OnmtBatch object.
""" while True: self.init_epoch() @@ -648,10 +672,11 @@ def __iter__(self): if self.yield_raw_example: yield minibatch[0] else: - yield torchtext.data.Batch( + yield OnmtBatch( minibatch, self.dataset, - self.device) + self.device, + feat_no_time_shift=self.feat_no_time_shift) if not self.repeat: return @@ -683,6 +708,7 @@ def __init__(self, self.sort_key = temp_dataset.sort_key self.random_shuffler = RandomShuffler() self.pool_factor = opt.pool_factor + self.feat_no_time_shift = opt.feat_no_time_shift del temp_dataset def _iter_datasets(self): @@ -709,9 +735,10 @@ def __iter__(self): self.random_shuffler, self.pool_factor): minibatch = sorted(minibatch, key=self.sort_key, reverse=True) - yield torchtext.data.Batch(minibatch, - self.iterables[0].dataset, - self.device) + yield OnmtBatch(minibatch, + self.iterables[0].dataset, + self.device, + feat_no_time_shift=self.feat_no_time_shift) class DatasetLazyIter(object): @@ -729,7 +756,8 @@ class DatasetLazyIter(object): def __init__(self, dataset_paths, fields, batch_size, batch_size_fn, batch_size_multiple, device, is_train, pool_factor, - repeat=True, num_batches_multiple=1, yield_raw_example=False): + repeat=True, num_batches_multiple=1, feat_no_time_shift=False, + yield_raw_example=False): self._paths = dataset_paths self.fields = fields self.batch_size = batch_size @@ -741,6 +769,7 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn, self.num_batches_multiple = num_batches_multiple self.yield_raw_example = yield_raw_example self.pool_factor = pool_factor + self.feat_no_time_shift = feat_no_time_shift def _iter_dataset(self, path): logger.info('Loading dataset from %s' % path) @@ -758,7 +787,8 @@ def _iter_dataset(self, path): sort=False, sort_within_batch=True, repeat=False, - yield_raw_example=self.yield_raw_example + yield_raw_example=self.yield_raw_example, + feat_no_time_shift=self.feat_no_time_shift ) for batch in cur_iter: self.dataset = cur_iter.dataset @@ -852,7 +882,8 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False): opt.pool_factor, repeat=not opt.single_pass, num_batches_multiple=max(opt.accum_count) * opt.world_size, - yield_raw_example=multi) + yield_raw_example=multi, + feat_no_time_shift=opt.feat_no_time_shift) def build_dataset_iter_multiple(train_shards, fields, opt): diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 43cd9731c8..7cacd9e4b3 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -13,8 +13,7 @@ from onmt.decoders import str2dec -from onmt.modules import Embeddings, VecEmbedding, CopyGenerator -from onmt.modules.util_class import Cast +from onmt.modules import Embeddings, VecEmbedding, Generator from onmt.utils.misc import use_gpu from onmt.utils.logging import logger from onmt.utils.parse import ArgumentParser @@ -88,6 +87,35 @@ def build_decoder(opt, embeddings): return str2dec[dec_type].from_opt(opt, embeddings) +def build_generator(model_opt, fields, decoder): + gen_sizes = [len(field[1].vocab) for field in fields['tgt'].fields] + if model_opt.share_decoder_embeddings: + rnn_sizes = ([model_opt.rnn_size - + (model_opt.feat_vec_size * (len(gen_sizes) - 1))] + + [model_opt.feat_vec_size] * (len(gen_sizes) - 1)) + else: + rnn_sizes = [model_opt.rnn_size] * len(gen_sizes) + + if model_opt.generator_function == "sparsemax": + gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) + else: + gen_func = nn.LogSoftmax(dim=-1) + + tgt_base_field = fields["tgt"].base_field + pad_idx = 
+    generator = Generator(rnn_sizes, gen_sizes, gen_func,
+                          shared=model_opt.share_decoder_embeddings,
+                          copy_attn=model_opt.copy_attn,
+                          pad_idx=pad_idx)
+
+    if model_opt.share_decoder_embeddings:
+        # share the weights
+        for gen, emb in zip(generator.generators, decoder.embeddings.emb_luts):
+            gen[0].weight = emb.weight
+
+    return generator
+
+
 def load_test_model(opt, model_path=None):
     if model_path is None:
         model_path = opt.models[0]
@@ -172,26 +200,7 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
     model = onmt.models.NMTModel(encoder, decoder)

     # Build Generator.
-    if not model_opt.copy_attn:
-        if model_opt.generator_function == "sparsemax":
-            gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
-        else:
-            gen_func = nn.LogSoftmax(dim=-1)
-        generator = nn.Sequential(
-            nn.Linear(model_opt.dec_rnn_size,
-                      len(fields["tgt"].base_field.vocab)),
-            Cast(torch.float32),
-            gen_func
-        )
-        if model_opt.share_decoder_embeddings:
-            generator[0].weight = decoder.embeddings.word_lut.weight
-    else:
-        tgt_base_field = fields["tgt"].base_field
-        vocab_size = len(tgt_base_field.vocab)
-        pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
-        generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)
-        if model_opt.share_decoder_embeddings:
-            generator.linear.weight = decoder.embeddings.word_lut.weight
+    generator = build_generator(model_opt, fields, decoder)

     # Load the model states from checkpoint or initialize them.
     if checkpoint is not None:
diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py
index 763ac8448a..646217fa69 100644
--- a/onmt/modules/__init__.py
+++ b/onmt/modules/__init__.py
@@ -3,6 +3,7 @@
 from onmt.modules.gate import context_gate_factory, ContextGate
 from onmt.modules.global_attention import GlobalAttention
 from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention
+from onmt.modules.generator import Generator
 from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \
     CopyGeneratorLossCompute
 from onmt.modules.multi_headed_attn import MultiHeadedAttention
@@ -13,6 +14,6 @@
 __all__ = ["Elementwise", "context_gate_factory", "ContextGate",
            "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator",
-           "CopyGeneratorLoss", "CopyGeneratorLossCompute",
+           "Generator", "CopyGeneratorLoss", "CopyGeneratorLossCompute",
            "MultiHeadedAttention", "Embeddings", "PositionalEncoding",
            "WeightNormConv2d", "AverageAttention", "VecEmbedding"]
diff --git a/onmt/modules/generator.py b/onmt/modules/generator.py
new file mode 100644
index 0000000000..33479a77cc
--- /dev/null
+++ b/onmt/modules/generator.py
@@ -0,0 +1,53 @@
+"""Generator producing distributions over target tokens and word features."""
+import torch
+import torch.nn as nn
+
+from onmt.modules.util_class import Cast
+
+from onmt.modules.copy_generator import CopyGenerator
+
+
+class Generator(nn.Module):
+    def __init__(self, rnn_sizes, gen_sizes, gen_func,
+                 shared=False, copy_attn=False, pad_idx=None):
+        super(Generator, self).__init__()
+        self.generators = nn.ModuleList()
+        self.shared = shared
+        self.rnn_sizes = rnn_sizes
+        self.gen_sizes = gen_sizes
+
+        def simple_generator(rnn_size, gen_size, gen_func):
+            return nn.Sequential(
+                nn.Linear(rnn_size, gen_size),
+                Cast(torch.float32),
+                gen_func)
+
+        # create first generator
+        if copy_attn:
+            self.generators.append(
+                CopyGenerator(rnn_sizes[0], gen_sizes[0], pad_idx))
+        else:
+            self.generators.append(
+                simple_generator(rnn_sizes[0], gen_sizes[0], gen_func))
+
+        # additional generators for features
+        for rnn_size, gen_size in zip(rnn_sizes[1:], gen_sizes[1:]):
+            self.generators.append(
+                simple_generator(rnn_size, gen_size, gen_func))
+
+    def forward(self, dec_out):
+        # if shared_decoder_embeddings, we slice the decoder output
+        if self.shared:
+            outs = []
+            offset = 0
+            for generator, s in zip(self.generators, self.rnn_sizes):
+                sliced_dec_out = dec_out[:, offset:offset+s]
+                out = generator(sliced_dec_out)
+                offset += s
+                outs.append(out)
+            return outs
+        else:
+            return [generator(dec_out) for generator in self.generators]
+
+    def __getitem__(self, i):
+        # keep backward compatibility with the old nn.Sequential generator:
+        # indexing returns layers of the token-level generator
+        return self.generators[0][i]
diff --git a/onmt/opts.py b/onmt/opts.py
index af47f79836..0a023deb17 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -56,6 +56,10 @@ def model_opts(parser):
               help="If -feat_merge_size is not set, feature "
                    "embedding sizes will be set to N^feat_vec_exponent "
                    "where N is the number of values the feature takes.")
+    group.add('--feat_no_time_shift', '-feat_no_time_shift',
+              action='store_true',
+              help="If set, do not shift the target features one step "
+                   "to the right.")

     # Encoder-Decoder Options
     group = parser.add_argument_group('Model- Encoder-Decoder')
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 4328ca52ea..8024e732c8 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -31,10 +31,10 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
             used to save the model

     """
-    tgt_field = dict(fields)["tgt"].base_field
-    train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt)
+    tgt_fields = dict(fields)["tgt"]
+    train_loss = onmt.utils.loss.build_loss_compute(model, tgt_fields, opt)
     valid_loss = onmt.utils.loss.build_loss_compute(
-        model, tgt_field, opt, train=False)
+        model, tgt_fields, opt, train=False)

     trunc_size = opt.truncated_decoder  # Badly named...
     shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index 9e9c89f563..d715b01545 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -93,11 +93,11 @@ def __init__(self, beam_size, batch_size, pad, bos, eos, n_best,
             not stepwise_penalty and self.global_scorer.has_cov_pen)
         self._cov_pen = self.global_scorer.has_cov_pen

-    def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
+    def initialize(self, memory_bank, src_lengths, num_features,
+                   src_map=None, device=None):
         """Initialize for decoding.
         Repeat src objects `beam_size` times.
""" - def fn_map_state(state, dim): return tile(state, self.beam_size, dim=dim) @@ -115,7 +115,7 @@ def fn_map_state(state, dim): self.memory_lengths = tile(src_lengths, self.beam_size) super(BeamSearch, self).initialize( - memory_bank, self.memory_lengths, src_map, device) + memory_bank, self.memory_lengths, num_features, src_map, device) self.best_scores = torch.full( [self.batch_size], -1e10, dtype=torch.float, device=device) self._beam_offset = torch.arange( @@ -135,7 +135,7 @@ def fn_map_state(state, dim): @property def current_predictions(self): - return self.alive_seq[:, -1] + return self.alive_seq[:, :, -1] @property def current_backptr(self): @@ -148,6 +148,19 @@ def batch_offset(self): return self._batch_offset def advance(self, log_probs, attn): + # we need to get the features first + if len(log_probs) > 1: + # we take top 1 for feats + features_id = [] + for logits in log_probs[1:]: + features_id.append(logits.topk(1, dim=-1)[1]) + features_id = torch.cat(features_id, dim=-1) + else: + features_id = None + + # keep only log probs for tokens + log_probs = log_probs[0] + vocab_size = log_probs.size(-1) # using integer division to get an integer _B without casting @@ -174,7 +187,7 @@ def advance(self, log_probs, attn): curr_scores = log_probs / length_penalty # Avoid any direction that would repeat unwanted ngrams - self.block_ngram_repeats(curr_scores) + self.block_ngram_repeats(curr_scores) # TODO check compat with feats # Flatten probs into a list of possibilities. curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size) @@ -192,10 +205,18 @@ def advance(self, log_probs, attn): self.select_indices = self._batch_index.view(_B * self.beam_size) self.topk_ids.fmod_(vocab_size) # resolve true word ids + # Concatenate topk_ids for tokens and feats. + if features_id is not None: + topk_ids = torch.cat(( + self.topk_ids.view(_B * self.beam_size, 1), + features_id), dim=1) + else: + topk_ids = self.topk_ids.view(_B * self.beam_size, 1) + # Append last prediction. self.alive_seq = torch.cat( [self.alive_seq.index_select(0, self.select_indices), - self.topk_ids.view(_B * self.beam_size, 1)], -1) + topk_ids.unsqueeze(-1)], -1) self.maybe_update_forbidden_tokens() @@ -239,7 +260,7 @@ def update_finished(self): # it's faster to not move this back to the original device self.is_finished = self.is_finished.to('cpu') self.top_beam_finished |= self.is_finished[:, 0].eq(1) - predictions = self.alive_seq.view(_B_old, self.beam_size, step) + predictions = self.alive_seq.view(_B_old, self.beam_size, -1, step) attention = ( self.alive_attn.view( step - 1, _B_old, self.beam_size, self.alive_attn.size(-1)) @@ -256,9 +277,12 @@ def update_finished(self): self.best_scores[b] = s self.hypotheses[b].append(( self.topk_scores[i, j], - predictions[i, j, 1:], # Ignore start_token. + predictions[i, j, 0, 1:], # Ignore start_token. attention[:, i, j, :self.memory_lengths[i]] - if attention is not None else None)) + if attention is not None else None, + [predictions[i, 0, 1+k, 1:] + for k in range(self.num_features)] + if predictions.size(-2) > 1 else None)) # End condition is the top beam finished and we can return # n_best hypotheses. 
             if self.ratio > 0:
@@ -271,13 +295,14 @@
                 if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                     best_hyp = sorted(
                         self.hypotheses[b], key=lambda x: x[0], reverse=True)
-                    for n, (score, pred, attn) in enumerate(best_hyp):
+                    for n, (score, pred, attn, feats) in enumerate(best_hyp):
                         if n >= self.n_best:
                             break
                         self.scores[b].append(score)
                         self.predictions[b].append(pred)  # ``(batch, n_best,)``
                         self.attention[b].append(
                             attn if attn is not None else [])
+                        self.features[b].append(feats if feats is not None else [])
                 else:
                     non_finished_batch.append(i)
         non_finished = torch.tensor(non_finished_batch)
@@ -297,7 +322,7 @@
         self._batch_index = self._batch_index.index_select(0, non_finished)
         self.select_indices = self._batch_index.view(_B_new * self.beam_size)
         self.alive_seq = predictions.index_select(0, non_finished) \
-            .view(-1, self.alive_seq.size(-1))
+            .view(-1, self.alive_seq.size(-2), self.alive_seq.size(-1))
         self.topk_scores = self.topk_scores.index_select(0, non_finished)
         self.topk_ids = self.topk_ids.index_select(0, non_finished)
         if self.alive_attn is not None:
diff --git a/onmt/translate/decode_strategy.py b/onmt/translate/decode_strategy.py
index 828e8a4d51..1bff35d4eb 100644
--- a/onmt/translate/decode_strategy.py
+++ b/onmt/translate/decode_strategy.py
@@ -68,6 +68,8 @@ def __init__(self, pad, bos, eos, batch_size, parallel_paths,
         self.predictions = [[] for _ in range(batch_size)]
         self.scores = [[] for _ in range(batch_size)]
         self.attention = [[] for _ in range(batch_size)]
+        # initialize features
+        self.features = [[] for _ in range(batch_size)]

         self.alive_attn = None

@@ -83,7 +85,8 @@ def __init__(self, pad, bos, eos, batch_size, parallel_paths,

         self.done = False

-    def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
+    def initialize(self, memory_bank, src_lengths, num_features,
+                   src_map=None, device=None):
         """DecodeStrategy subclasses should override :func:`initialize()`.

         `initialize` should be called before all actions.
@@ -91,20 +94,23 @@ def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
         """
         if device is None:
             device = torch.device('cpu')
+        # initialize to [batch*beam x num_feats x 1]
         self.alive_seq = torch.full(
-            [self.batch_size * self.parallel_paths, 1], self.bos,
+            [self.batch_size * self.parallel_paths, num_features, 1], self.bos,
             dtype=torch.long, device=device)
         self.is_finished = torch.zeros(
             [self.batch_size, self.parallel_paths],
             dtype=torch.uint8, device=device)
+        self.num_features = num_features - 1  # tokens are not features
         return None, memory_bank, src_lengths, src_map

     def __len__(self):
-        return self.alive_seq.shape[1]
+        return self.alive_seq.shape[-1]

     def ensure_min_length(self, log_probs):
         if len(self) <= self.min_length:
-            log_probs[:, self.eos] = -1e20
+            for probs in log_probs:
+                probs[:, self.eos] = -1e20

     def ensure_max_length(self):
         # add one to account for BOS. Don't account for EOS because hitting
diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py
index 8ebef32e15..331b64858f 100644
--- a/onmt/translate/greedy_search.py
+++ b/onmt/translate/greedy_search.py
@@ -91,7 +91,8 @@ def __init__(self, pad, bos, eos, batch_size, min_length,
         self.keep_topk = keep_topk
         self.topk_scores = None

-    def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
+    def initialize(self, memory_bank, src_lengths, num_features,
+                   src_map=None, device=None):
         """Initialize for decoding."""
         fn_map_state = None

@@ -104,7 +105,7 @@ def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
         self.memory_lengths = src_lengths
         super(GreedySearch, self).initialize(
-            memory_bank, src_lengths, src_map, device)
+            memory_bank, src_lengths, num_features, src_map, device)
         self.select_indices = torch.arange(
             self.batch_size, dtype=torch.long, device=device)
         self.original_batch_idx = torch.arange(
@@ -113,7 +114,7 @@ def initialize(self, memory_bank, src_lengths, src_map=None, device=None):

     @property
     def current_predictions(self):
-        return self.alive_seq[:, -1]
+        return self.alive_seq[:, :, -1]

     @property
     def batch_offset(self):
@@ -132,6 +133,17 @@ def advance(self, log_probs, attn):
             attn (FloatTensor): Shaped ``(1, B, inp_seq_len)``.
         """

+        # we need to get the features first
+        if len(log_probs) > 1:
+            features_id = []
+            for logits in log_probs[1:]:
+                features_id.append(logits.topk(1, dim=-1)[1])
+            features_id = torch.cat(features_id, dim=-1)
+        else:
+            features_id = None
+        # keep only log probs for tokens
+        log_probs = log_probs[0]
+
         self.ensure_min_length(log_probs)
         self.block_ngram_repeats(log_probs)
         topk_ids, self.topk_scores = sample_with_temperature(
@@ -139,7 +151,14 @@ def advance(self, log_probs, attn):

         self.is_finished = topk_ids.eq(self.eos)

-        self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1)
+        if features_id is not None:
+            topk_ids = torch.cat((
+                topk_ids, features_id
+            ), dim=-1)
+
+        self.alive_seq = torch.cat([
+            self.alive_seq,
+            topk_ids.unsqueeze(-1)], -1)
         if self.return_attention:
             if self.alive_attn is None:
                 self.alive_attn = attn
@@ -154,7 +173,11 @@ def update_finished(self):
         for b in finished_batches.view(-1):
             b_orig = self.original_batch_idx[b]
             self.scores[b_orig].append(self.topk_scores[b, 0])
-            self.predictions[b_orig].append(self.alive_seq[b, 1:])
+            self.predictions[b_orig].append(self.alive_seq[b, 0, 1:])
+            # greedy keeps a single hypothesis, so all its features go in slot 0
+            self.features[b_orig] = [[]]
+            for i in range(self.num_features):
+                self.features[b_orig][0].append(self.alive_seq[b, 1+i, 1:])
             self.attention[b_orig].append(
                 self.alive_attn[:, b, :self.memory_lengths[b]]
                 if self.alive_attn is not None else [])
diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py
index 21eeb91e96..1838a98940 100644
--- a/onmt/translate/translation.py
+++ b/onmt/translate/translation.py
@@ -23,7 +23,7 @@ class TranslationBuilder(object):
     """

     def __init__(self, data, fields, n_best=1, replace_unk=False,
-                 has_tgt=False, phrase_table=""):
+                 has_tgt=False, phrase_table="", feat_no_time_shift=False):
         self.data = data
         self.fields = fields
         self._has_text_src = isinstance(
@@ -32,17 +32,34 @@ def __init__(self, data, fields, n_best=1, replace_unk=False,
         self.replace_unk = replace_unk
         self.phrase_table = phrase_table
         self.has_tgt = has_tgt
+        self.feat_no_time_shift = feat_no_time_shift
-
-    def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn):
-        tgt_field = dict(self.fields)["tgt"].base_field
+
+    def _build_target_tokens(self, src, src_vocab, src_raw,
+                             pred, attn, all_feats=None):
+        # feats need to be shifted back one step to the left
+        if all_feats is not None:
+            if not self.feat_no_time_shift:
+                all_feats = [list(feat[1:]) + [feat[0]] for feat in all_feats]
+            pred_iter = zip(pred, *all_feats)
+        else:
+            pred_iter = [(item,) for item in pred]
+        tgt_fields = dict(self.fields)["tgt"]
+        tgt_field = tgt_fields.base_field
         vocab = tgt_field.vocab
+        feats_vocabs = [field.vocab for name, field in tgt_fields.fields[1:]]
         tokens = []

-        for tok in pred:
+        for tok_feats in pred_iter:
+            tok = tok_feats[0]
             if tok < len(vocab):
-                tokens.append(vocab.itos[tok])
+                token = vocab.itos[tok]
             else:
-                tokens.append(src_vocab.itos[tok - len(vocab)])
-            if tokens[-1] == tgt_field.eos_token:
+                token = src_vocab.itos[tok - len(vocab)]
+            if len(tok_feats) > 1:
+                feats = tok_feats[1:]
+                for feat, fv in zip(feats, feats_vocabs):
+                    token += u"│" + fv.itos[feat]
+            tokens.append(token)
+            if token.split(u"│")[0] == tgt_field.eos_token:
                 tokens = tokens[:-1]
                 break
         if self.replace_unk and attn is not None and src is not None:
@@ -63,8 +80,9 @@ def from_batch(self, translation_batch):
                len(translation_batch["predictions"]))
         batch_size = batch.batch_size

-        preds, pred_score, attn, align, gold_score, indices = list(zip(
+        preds, feats, pred_score, attn, align, gold_score, indices = list(zip(
             *sorted(zip(translation_batch["predictions"],
+                        translation_batch["features"],
                         translation_batch["scores"],
                         translation_batch["attention"],
                         translation_batch["alignment"],
@@ -96,7 +114,8 @@ def from_batch(self, translation_batch):
             pred_sents = [self._build_target_tokens(
                 src[:, b] if src is not None else None,
                 src_vocab, src_raw,
-                preds[b][n], attn[b][n])
+                preds[b][n], attn[b][n],
+                feats[b][n] if len(feats[0]) > 0 else None)
                 for n in range(self.n_best)]
             gold_sent = None
             if tgt is not None:
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 7fe4c5c09b..42ad98c8f8 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -130,7 +130,8 @@ def __init__(
         report_align=False,
         report_score=True,
         logger=None,
-        seed=-1):
+        seed=-1,
+        feat_no_time_shift=False):
         self.model = model
         self.fields = fields
         tgt_field = dict(self.fields)["tgt"].base_field
@@ -187,6 +188,8 @@ def __init__(
         self.use_filter_pred = False
         self._filter_pred = None

+        self.feat_no_time_shift = feat_no_time_shift
+
         # for debugging
         self.beam_trace = self.dump_beam != ""
         self.beam_accum = None
@@ -259,7 +262,8 @@ def from_opt(
             report_align=report_align,
             report_score=report_score,
             logger=logger,
-            seed=opt.seed)
+            seed=opt.seed,
+            feat_no_time_shift=vars(opt).get("feat_no_time_shift", False))

     def _log(self, msg):
         if self.logger:
@@ -334,7 +338,7 @@ def translate(
         xlation_builder = onmt.translate.TranslationBuilder(
             data, self.fields, self.n_best, self.replace_unk, tgt,
-            self.phrase_table
+            self.phrase_table, self.feat_no_time_shift
         )

         # Statistics
@@ -646,6 +650,7 @@ def _translate_batch_with_strategy(
         results = {
             "predictions": None,
+            "features": None,
             "scores": None,
             "attention": None,
             "batch": batch,
             "gold_score": self._gold_score(
                 enc_states, batch_size, src)}

         # (2) prep decode_strategy. Possibly repeat src objects.
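+        # NOTE: the per-step input width (token column + feature columns) is
+        # read off the source tensor below; this assumes the target side
+        # carries the same number of word features as the source.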
+        num_features = batch.src[0].size(-1)
         src_map = batch.src_map if use_src_map else None
         fn_map_state, memory_bank, memory_lengths, src_map = \
-            decode_strategy.initialize(memory_bank, src_lengths, src_map)
+            decode_strategy.initialize(
+                memory_bank, src_lengths,
+                num_features, src_map)
         if fn_map_state is not None:
             self.model.decoder.map_state(fn_map_state)

         # (3) Begin decoding step by step:
         for step in range(decode_strategy.max_length):
-            decoder_input = decode_strategy.current_predictions.view(1, -1, 1)
+            decoder_input = decode_strategy.current_predictions\
+                .view(1, -1, num_features)

             log_probs, attn = self._decode_and_generate(
                 decoder_input,
@@ -674,6 +683,7 @@ def _translate_batch_with_strategy(
                 step=step,
                 batch_offset=decode_strategy.batch_offset)

+            # Note: we may have probs over several features
             decode_strategy.advance(log_probs, attn)
             any_finished = decode_strategy.is_finished.any()
             if any_finished:
@@ -702,6 +712,7 @@ def _translate_batch_with_strategy(

         results["scores"] = decode_strategy.scores
         results["predictions"] = decode_strategy.predictions
+        results["features"] = decode_strategy.features
         results["attention"] = decode_strategy.attention
         if self.report_align:
             results["alignment"] = self._align_forward(
diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py
index c48f0d3d21..164d7f2128 100644
--- a/onmt/utils/loss.py
+++ b/onmt/utils/loss.py
@@ -12,7 +12,7 @@
 from onmt.modules.sparse_activations import LogSparsemax


-def build_loss_compute(model, tgt_field, opt, train=True):
+def build_loss_compute(model, tgt_fields, opt, train=True):
     """
     Returns a LossCompute subclass which wraps around an nn.Module subclass
     (such as nn.NLLLoss) which defines the loss criterion. The LossCompute
@@ -22,7 +22,7 @@ def build_loss_compute(model, tgt_fields, opt, train=True):
     for when using a copy mechanism.
     """
     device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu")
-
+    tgt_field = tgt_fields.base_field
     padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token]
     unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token]
@@ -31,33 +31,42 @@ def build_loss_compute(model, tgt_fields, opt, train=True):
         "order to use --lambda_coverage != 0"

     if opt.copy_attn:
-        criterion = onmt.modules.CopyGeneratorLoss(
+        criterions = [onmt.modules.CopyGeneratorLoss(
             len(tgt_field.vocab), opt.copy_attn_force,
             unk_index=unk_idx, ignore_index=padding_idx
-        )
+        )]
     elif opt.label_smoothing > 0 and train:
-        criterion = LabelSmoothingLoss(
+        criterions = [LabelSmoothingLoss(
             opt.label_smoothing, len(tgt_field.vocab),
             ignore_index=padding_idx
-        )
+        )]
     elif isinstance(model.generator[-1], LogSparsemax):
-        criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum')
+        criterions = [SparsemaxLoss(ignore_index=padding_idx, reduction='sum')]
     else:
-        criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum')
+        criterions = [nn.NLLLoss(ignore_index=padding_idx, reduction='sum')]
+
+    # we need to add as many additional criteria as we have features
+    for field in tgt_fields.fields[1:]:
+        padding_idx = field[1].vocab.stoi[field[1].pad_token]
+        criterions.append(
+            nn.NLLLoss(ignore_index=padding_idx, reduction='sum')
+        )

     # if the loss function operates on vectors of raw logits instead of
     # probabilities, only the first part of the generator needs to be
     # passed to the NMTLossCompute. At the moment, the only supported
     # loss function of this kind is the sparsemax loss.
-    use_raw_logits = isinstance(criterion, SparsemaxLoss)
+    use_raw_logits = isinstance(criterions[0], SparsemaxLoss)
+    # TODO make this compatible with target features !!!
     loss_gen = model.generator[0] if use_raw_logits else model.generator
     if opt.copy_attn:
+        # TODO make this compatible with target features...
         compute = onmt.modules.CopyGeneratorLossCompute(
-            criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength,
+            criterions, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength,
             lambda_coverage=opt.lambda_coverage
         )
     else:
         compute = NMTLossCompute(
-            criterion, loss_gen, lambda_coverage=opt.lambda_coverage,
+            criterions, loss_gen, lambda_coverage=opt.lambda_coverage,
             lambda_align=opt.lambda_align)
     compute.to(device)

@@ -83,14 +92,15 @@ class LossComputeBase(nn.Module):
         normalization (str): normalize by "sents" or "tokens"
     """

-    def __init__(self, criterion, generator):
+    def __init__(self, criterions, generator):
         super(LossComputeBase, self).__init__()
-        self.criterion = criterion
+        # We may have several criteria in the case of target word features
+        self.criterions = criterions
         self.generator = generator

     @property
     def padding_idx(self):
-        return self.criterion.ignore_index
+        return self.criterions[0].ignore_index

     def _make_shard_state(self, batch, output, range_, attns=None):
         """
@@ -178,7 +188,8 @@ def _stats(self, loss, scores, target):
         Returns:
             :obj:`onmt.utils.Statistics` : statistics for this batch.
         """
-        pred = scores.max(1)[1]
+        # TODO we need to add some stats for features
+        pred = scores[0].max(1)[1]
         non_padding = target.ne(self.padding_idx)
         num_correct = pred.eq(target).masked_select(non_padding).sum().item()
         num_non_padding = non_padding.sum().item()
@@ -214,7 +225,7 @@ def forward(self, output, target):
         output (FloatTensor): batch_size x n_classes
         target (LongTensor): batch_size
         """
-        model_prob = self.one_hot.repeat(target.size(0), 1)
+        model_prob = self.one_hot.repeat(target.size(0), 1).to(target.device)
         model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
         model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)

@@ -226,9 +237,9 @@ class NMTLossCompute(LossComputeBase):
     """
     Standard NMT Loss Computation.
""" - def __init__(self, criterion, generator, normalization="sents", + def __init__(self, criterions, generator, normalization="sents", lambda_coverage=0.0, lambda_align=0.0): - super(NMTLossCompute, self).__init__(criterion, generator) + super(NMTLossCompute, self).__init__(criterions, generator) self.lambda_coverage = lambda_coverage self.lambda_align = lambda_align @@ -237,6 +248,10 @@ def _make_shard_state(self, batch, output, range_, attns=None): "output": output, "target": batch.tgt[range_[0] + 1: range_[1], :, 0], } + if batch.tgt.size(-1) > 1: + shard_state["features"] = [ + batch.tgt[range_[0] + 1: range_[1], :, i+1] + for i in range(batch.tgt.size(-1) - 1)] if self.lambda_coverage != 0.0: coverage = attns.get("coverage", None) std = attns.get("std", None) @@ -275,15 +290,20 @@ def _make_shard_state(self, batch, output, range_, attns=None): }) return shard_state - def _compute_loss(self, batch, output, target, std_attn=None, - coverage_attn=None, align_head=None, ref_align=None): + def _compute_loss(self, batch, output, target, features=None, + std_attn=None, coverage_attn=None, + align_head=None, ref_align=None): bottled_output = self._bottle(output) scores = self.generator(bottled_output) gtruth = target.view(-1) - - loss = self.criterion(scores, gtruth) + loss = self.criterions[0](scores[0], gtruth) + if features is not None: + for score, crit, feat in zip(scores[1:], + self.criterions[1:], features): + truth = feat.view(-1) + loss += crit(score, truth) if self.lambda_coverage != 0.0: coverage_loss = self._compute_coverage_loss( std_attn=std_attn, coverage_attn=coverage_attn)