1 change: 1 addition & 0 deletions data/data_features/src-test-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
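Note: in the *-with-feats.txt files, `│` joins each token to its feature label. A minimal parsing sketch (illustrative only, not code from this PR):

    # Split a feature-annotated line into parallel token/feature streams.
    line = "she│C is│B a│A hard-working.│B"
    tokens, feats = zip(*(tok.split("│") for tok in line.split(" ")))
    print(tokens)  # ('she', 'is', 'a', 'hard-working.')
    print(feats)   # ('C', 'B', 'A', 'B')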
1 change: 0 additions & 1 deletion data/data_features/src-test.feat0

This file was deleted.

3 changes: 3 additions & 0 deletions data/data_features/src-train-with-feats.txt
@@ -0,0 +1,3 @@
however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
however,│A according│B to│C the│D logs,│E
she│C is│B a│A hard-working.│B
3 changes: 0 additions & 3 deletions data/data_features/src-train.feat0

This file was deleted.

1 change: 1 addition & 0 deletions data/data_features/src-val-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 0 additions & 1 deletion data/data_features/src-val.feat0

This file was deleted.

1 change: 1 addition & 0 deletions data/data_features/tgt-test-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 1 addition & 0 deletions data/data_features/tgt-test.txt
@@ -0,0 +1 @@
she is a hard-working.
3 changes: 3 additions & 0 deletions data/data_features/tgt-train-with-feats.txt
@@ -0,0 +1,3 @@
however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
however,│A according│B to│C the│D logs,│E
she│C is│B a│A hard-working.│B
1 change: 1 addition & 0 deletions data/data_features/tgt-val-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
20 changes: 20 additions & 0 deletions data/features_configs/source_and_target_features.yaml
@@ -0,0 +1,20 @@
# Corpus opts:
data:
    corpus_1:
        path_src: data/data_features/src-train-with-feats.txt
        path_tgt: data/data_features/tgt-train-with-feats.txt
        transforms: [inferfeats]
    corpus_2:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train.txt
        transforms: [inferfeats]
    valid:
        path_src: data/data_features/src-val-with-feats.txt
        path_tgt: data/data_features/tgt-val-with-feats.txt
        transforms: [inferfeats]

# Feats options
n_src_feats: 1
n_tgt_feats: 1
src_feats_defaults: "0"
tgt_feats_defaults: "1"
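Note: a config like this is presumably consumed through the standard entry points, e.g. `onmt_build_vocab -config data/features_configs/source_and_target_features.yaml -src_vocab vocab.src -tgt_vocab vocab.tgt -n_sample -1` (vocab paths illustrative); the per-feature vocab files it writes are read back by `build_vocab` in onmt/inputters/inputter.py below.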
18 changes: 18 additions & 0 deletions data/features_configs/source_features_only.yaml
@@ -0,0 +1,18 @@
# Corpus opts:
data:
    corpus_1:
        path_src: data/data_features/src-train-with-feats.txt
        path_tgt: data/data_features/tgt-train.txt
        transforms: [inferfeats]
    corpus_2:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train.txt
        transforms: [inferfeats]
    valid:
        path_src: data/data_features/src-val-with-feats.txt
        path_tgt: data/data_features/tgt-val.txt
        transforms: [inferfeats]

# Feats options
n_src_feats: 1
src_feats_defaults: "0"
18 changes: 18 additions & 0 deletions data/features_configs/target_features_only.yaml
@@ -0,0 +1,18 @@
# Corpus opts:
data:
    corpus_1:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train-with-feats.txt
        transforms: [inferfeats]
    corpus_2:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train.txt
        transforms: [inferfeats]
    valid:
        path_src: data/data_features/src-val.txt
        path_tgt: data/data_features/tgt-val-with-feats.txt
        transforms: [inferfeats]

# Feats options
n_tgt_feats: 1
tgt_feats_defaults: "0"
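Note: the `src_feats_defaults`/`tgt_feats_defaults` options cover corpora without inline annotations (corpus_2 in these configs); the inferfeats transform presumably tags every token with the default value. An illustrative sketch of that behavior, assuming that is all the default does:

    # Hypothetical illustration: fill in the default feature for a bare line.
    line = "she is a hard-working."
    tgt_feats_defaults = "0"
    annotated = " ".join(f"{tok}│{tgt_feats_defaults}" for tok in line.split())
    # annotated == "she│0 is│0 a│0 hard-working.│0"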
11 changes: 0 additions & 11 deletions data/features_data.yaml

This file was deleted.

88 changes: 56 additions & 32 deletions onmt/bin/build_vocab.py
@@ -9,7 +9,7 @@
from onmt.inputters.text_utils import process
from onmt.transforms import make_transforms, get_transforms_cls
from onmt.constants import CorpusName, CorpusTask
from collections import Counter, defaultdict
from collections import Counter
import multiprocessing as mp


@@ -42,19 +42,26 @@ def write_files_from_queues(sample_path, queues):

# Just for debugging purposes
# It appends features to subwords when dumping to file
def append_features_to_example(example, features):
    ex_toks = example.split(' ')
    feat_toks = features.split(' ')
    toks = [f"{subword}│{feat}" for subword, feat in
            zip(ex_toks, feat_toks)]
    return " ".join(toks)
def append_features_to_text(text, features):
    text_tok = text.split(' ')
    feats_tok = [x.split(' ') for x in features]

    pretty_toks = []
    for tok, *feats in zip(text_tok, *feats_tok):
        feats = '│'.join(feats)
        if feats:
            pretty_toks.append(f"{tok}│{feats}")
        else:
            pretty_toks.append(tok)
    return " ".join(pretty_toks)


def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
    """Build vocab on (strided) subpart of the data."""
    sub_counter_src = Counter()
    sub_counter_tgt = Counter()
    sub_counter_src_feats = defaultdict(Counter)
    sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
    sub_counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
    datasets_iterables = build_corpora_iters(
        corpora, transforms, opts.data,
        skip_empty_level=opts.skip_empty_level,
@@ -68,28 +75,35 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
            if opts.dump_samples:
                build_sub_vocab.queues[c_name][offset].put("blank")
                continue
            src_line, tgt_line = (maybe_example['src']['src'],
                                  maybe_example['tgt']['tgt'])
            src_line_pretty = src_line
            for feat_name, feat_line in maybe_example["src"].items():
                if feat_name not in ["src", "src_original"]:
                    sub_counter_src_feats[feat_name].update(
                        feat_line.split(' '))
                    if opts.dump_samples:
                        src_line_pretty = append_features_to_example(
                            src_line_pretty, feat_line)
            src_line = maybe_example['src']['src']
            tgt_line = maybe_example['tgt']['tgt']
            src_feats_lines = maybe_example['src']['feats']
            tgt_feats_lines = maybe_example['tgt']['feats']

            sub_counter_src.update(src_line.split(' '))
            sub_counter_tgt.update(tgt_line.split(' '))
            # use a separate index so the running example index i is not shadowed
            for k in range(opts.n_src_feats):
                sub_counter_src_feats[k].update(src_feats_lines[k].split(' '))
            for k in range(opts.n_tgt_feats):
                sub_counter_tgt_feats[k].update(tgt_feats_lines[k].split(' '))

            if opts.dump_samples:
                src_pretty_line = append_features_to_text(
                    src_line, src_feats_lines)
                tgt_pretty_line = append_features_to_text(
                    tgt_line, tgt_feats_lines)
                build_sub_vocab.queues[c_name][offset].put(
                    (i, src_line_pretty, tgt_line))
                    (i, src_pretty_line, tgt_pretty_line))
            if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("break")
                break
        if opts.dump_samples:
            build_sub_vocab.queues[c_name][offset].put("break")
    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
    return (sub_counter_src,
            sub_counter_tgt,
            sub_counter_src_feats,
            sub_counter_tgt_feats)
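# Each worker thus returns one Counter for src tokens, one for tgt tokens,
# and one Counter per source/target feature; the parent process merges them.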


def init_pool(queues):
@@ -113,7 +127,8 @@ def build_vocab(opts, transforms, n_sample=3):
    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
    counter_src = Counter()
    counter_tgt = Counter()
    counter_src_feats = defaultdict(Counter)
    counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
    counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
    from functools import partial
    queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
                       for i in range(opts.num_threads)]
@@ -130,14 +145,18 @@ def build_vocab(opts, transforms, n_sample=3):
        func = partial(
            build_sub_vocab, corpora, transforms,
            opts, n_sample, opts.num_threads)
        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
                func, range(0, opts.num_threads)):
        for (sub_counter_src, sub_counter_tgt,
             sub_counter_src_feats, sub_counter_tgt_feats) \
                in p.imap(func, range(0, opts.num_threads)):
            counter_src.update(sub_counter_src)
            counter_tgt.update(sub_counter_tgt)
            counter_src_feats.update(sub_counter_src_feats)
            for i in range(opts.n_src_feats):
                counter_src_feats[i].update(sub_counter_src_feats[i])
            for i in range(opts.n_tgt_feats):
                counter_tgt_feats[i].update(sub_counter_tgt_feats[i])
    if opts.dump_samples:
        write_process.join()
    return counter_src, counter_tgt, counter_src_feats
    return counter_src, counter_tgt, counter_src_feats, counter_tgt_feats


def build_vocab_main(opts):
@@ -163,13 +182,16 @@ def build_vocab_main(opts):
    transforms = make_transforms(opts, transforms_cls, None)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
    (src_counter, tgt_counter,
     src_feats_counter, tgt_feats_counter) = build_vocab(
        opts, transforms, n_sample=opts.n_sample)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    for feat_name, feat_counter in src_feats_counter.items():
        logger.info(f"Counters {feat_name}:{len(feat_counter)}")
    logger.info(f"Counters src: {len(src_counter)}")
    logger.info(f"Counters tgt: {len(tgt_counter)}")
    for i, feat_counter in enumerate(src_feats_counter):
        logger.info(f"Counters src feat_{i}: {len(feat_counter)}")
    for i, feat_counter in enumerate(tgt_feats_counter):
        logger.info(f"Counters tgt feat_{i}: {len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -186,8 +208,10 @@ def save_counter(counter, save_path):
    save_counter(src_counter, opts.src_vocab)
    save_counter(tgt_counter, opts.tgt_vocab)

    for k, v in src_feats_counter.items():
        save_counter(v, opts.src_feats_vocab[k])
    for i, c in enumerate(src_feats_counter):
        save_counter(c, f"{opts.src_vocab}_feat{i}")
    for i, c in enumerate(tgt_feats_counter):
        save_counter(c, f"{opts.tgt_vocab}_feat{i}")


def _get_parser():
29 changes: 23 additions & 6 deletions onmt/decoders/ensemble.py
@@ -99,14 +99,31 @@ def forward(self, hidden, attn=None, src_map=None):
        by averaging distributions from models in the ensemble.
        All models in the ensemble must share a target vocabulary.
        """
        distributions = torch.stack(
            [mg(h) if attn is None else mg(h, attn, src_map)
             for h, mg in zip(hidden, self.model_generators)]
        )

        distributions, feats_distributions = [], []
        n_feats = len(self.model_generators[0].feats_generators)
        for h, mg in zip(hidden, self.model_generators):
            scores, feats_scores = \
                (mg(h) if attn is None else mg(h, attn, src_map))
            distributions.append(scores)
            feats_distributions.append(feats_scores)

        distributions = torch.stack(distributions)

        # stack the i-th feature distribution of every ensemble model
        stacked_feats_distributions = []
        for i in range(n_feats):
            stacked_feats_distributions.append(
                torch.stack([feats_distribution[i]
                             for feats_distribution in feats_distributions]))

        if self._raw_probs:
            return torch.log(torch.exp(distributions).mean(0))
            return (torch.log(torch.exp(distributions).mean(0)),
                    [torch.log(torch.exp(d).mean(0))
                     for d in stacked_feats_distributions])
        else:
            return distributions.mean(0)
            return (distributions.mean(0),
                    [d.mean(0) for d in stacked_feats_distributions])
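# With _raw_probs the ensemble averages in probability space,
# log(mean(exp(log_probs))), for tokens and features alike; otherwise
# the raw output scores are averaged directly.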


class EnsembleModel(NMTModel):
54 changes: 37 additions & 17 deletions onmt/inputters/inputter.py
@@ -34,11 +34,11 @@ def build_vocab(opt, specials):
""" Build vocabs dict to be stored in the checkpoint
based on vocab files having each line [token, count]
Args:
opt: src_vocab, tgt_vocab, src_feats_vocab
opt: src_vocab, tgt_vocab, n_src_feats, n_tgt_feats
Return:
vocabs: {'src': pyonmttok.Vocab, 'tgt': pyonmttok.Vocab,
'src_feats' : {'feat0': pyonmttok.Vocab,
'feat1': pyonmttok.Vocab, ...},
'src_feats' : [pyonmttok.Vocab, pyonmttok.Vocab, ...],
'tgt_feats' : [pyonmttok.Vocab, pyonmttok.Vocab, ...],
'data_task': seq2seq or lm
}
"""
@@ -85,10 +85,10 @@ def _pad_vocab_to_multiple(vocab, multiple):
                                        opt.vocab_size_multiple)
    vocabs['tgt'] = tgt_vocab

    if opt.src_feats_vocab:
        src_feats = {}
        for feat_name, filepath in opt.src_feats_vocab.items():
            src_f_vocab = _read_vocab_file(filepath, 1)
    if opt.n_src_feats > 0:
        src_feats_vocabs = []
        for i in range(opt.n_src_feats):
            src_f_vocab = _read_vocab_file(f"{opt.src_vocab}_feat{i}", 1)
            src_f_vocab = pyonmttok.build_vocab_from_tokens(
                src_f_vocab,
                maximum_size=0,
@@ -101,8 +101,27 @@ def _pad_vocab_to_multiple(vocab, multiple):
            if opt.vocab_size_multiple > 1:
                src_f_vocab = _pad_vocab_to_multiple(src_f_vocab,
                                                     opt.vocab_size_multiple)
            src_feats[feat_name] = src_f_vocab
        vocabs['src_feats'] = src_feats
            src_feats_vocabs.append(src_f_vocab)
        vocabs["src_feats"] = src_feats_vocabs

    if opt.n_tgt_feats > 0:
        tgt_feats_vocabs = []
        for i in range(opt.n_tgt_feats):
            tgt_f_vocab = _read_vocab_file(f"{opt.tgt_vocab}_feat{i}", 1)
            tgt_f_vocab = pyonmttok.build_vocab_from_tokens(
                tgt_f_vocab,
                maximum_size=0,
                minimum_frequency=1,
                special_tokens=[DefaultTokens.UNK,
                                DefaultTokens.PAD,
                                DefaultTokens.BOS,
                                DefaultTokens.EOS])
            tgt_f_vocab.default_id = tgt_f_vocab[DefaultTokens.UNK]
            if opt.vocab_size_multiple > 1:
                tgt_f_vocab = _pad_vocab_to_multiple(tgt_f_vocab,
                                                     opt.vocab_size_multiple)
            tgt_feats_vocabs.append(tgt_f_vocab)
        vocabs["tgt_feats"] = tgt_feats_vocabs

    vocabs['data_task'] = opt.data_task

@@ -146,10 +165,11 @@ def vocabs_to_dict(vocabs):
    vocabs_dict['src'] = vocabs['src'].ids_to_tokens
    vocabs_dict['tgt'] = vocabs['tgt'].ids_to_tokens
    if 'src_feats' in vocabs.keys():
        vocabs_dict['src_feats'] = {}
        for feat in vocabs['src_feats'].keys():
            vocabs_dict['src_feats'][feat] = \
                vocabs['src_feats'][feat].ids_to_tokens
        vocabs_dict['src_feats'] = [feat_vocab.ids_to_tokens
                                    for feat_vocab in vocabs['src_feats']]
    if 'tgt_feats' in vocabs.keys():
        vocabs_dict['tgt_feats'] = [feat_vocab.ids_to_tokens
                                    for feat_vocab in vocabs['tgt_feats']]
    vocabs_dict['data_task'] = vocabs['data_task']
    return vocabs_dict
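# vocabs_to_dict and dict_to_vocabs are intended as inverses, so feature
# vocabs survive a checkpoint save/load roundtrip as plain token lists.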

@@ -167,9 +187,9 @@ def dict_to_vocabs(vocabs_dict):
    else:
        vocabs['tgt'] = pyonmttok.build_vocab_from_tokens(vocabs_dict['tgt'])
    if 'src_feats' in vocabs_dict.keys():
        vocabs['src_feats'] = {}
        for feat in vocabs_dict['src_feats'].keys():
            vocabs['src_feats'][feat] = \
        vocabs['src_feats'] = []
        for feat_vocab in vocabs_dict['src_feats']:
            vocabs['src_feats'].append(
                pyonmttok.build_vocab_from_tokens(
                    vocabs_dict['src_feats'][feat])
                    feat_vocab))