3 changes: 3 additions & 0 deletions data/data_features/src-train.feat0
@@ -0,0 +1,3 @@
A A A A B A A A C
A B C D E
C B A B
3 changes: 3 additions & 0 deletions data/data_features/src-train.txt
@@ -0,0 +1,3 @@
however, according to the logs, she is a hard-working.
however, according to the logs,
she is a hard-working.
1 change: 1 addition & 0 deletions data/data_features/src-val.feat0
@@ -0,0 +1 @@
C B A B
1 change: 1 addition & 0 deletions data/data_features/src-val.txt
@@ -0,0 +1 @@
she is a hard-working.
3 changes: 3 additions & 0 deletions data/data_features/tgt-train.txt
@@ -0,0 +1,3 @@
however, according to the logs, she is a hard-working.
however, according to the logs,
she is a hard-working.
1 change: 1 addition & 0 deletions data/data_features/tgt-val.txt
@@ -0,0 +1 @@
she is a hard-working.
11 changes: 11 additions & 0 deletions data/features_data.yaml
@@ -0,0 +1,11 @@
# Corpus opts:
data:
    corpus_1:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train.txt
        src_feats:
            feat0: data/data_features/src-train.feat0
        transforms: [filterfeats, inferfeats]
    valid:
        path_src: data/data_features/src-val.txt
        path_tgt: data/data_features/tgt-val.txt
51 changes: 51 additions & 0 deletions docs/source/FAQ.md
@@ -477,3 +477,54 @@ Training options to perform vocabulary update are:
* `-update_vocab`: set this option
* `-reset_optim`: set the value to "states"
* `-train_from`: checkpoint path


## How can I use source word features?

Extra information can be added to the words in the source sentences by defining word features.

Features should be defined in a separate file, using blank spaces as separators, with each row corresponding to a source sentence. An example of the input files:

data.src
```
however, according to the logs, she is hard-working.
```

feat0.txt
```
A C C C C A A B
```

**Notes**
- Prior tokenization is not necessary; features will be inferred using the `FeatInferTransform` transform.
- `FilterFeatsTransform` and `FeatInferTransform` are required for word features to work properly.
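
Each feature file must stay line- and token-aligned with its source file: one row per source sentence and one label per blank-space-separated word. A quick alignment check along these lines (an illustrative sketch, reusing the `data.src` and `feat0.txt` files from the example above) is:

```
# Sketch: check that every source line has exactly as many whitespace
# tokens as its feature line has labels.
with open("data.src", encoding="utf-8") as src_file, \
        open("feat0.txt", encoding="utf-8") as feat_file:
    for i, (src_line, feat_line) in enumerate(zip(src_file, feat_file), 1):
        n_words = len(src_line.split())
        n_labels = len(feat_line.split())
        if n_words != n_labels:
            print(f"line {i}: {n_words} words but {n_labels} feature labels")
```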

Sample config file:

```
data:
    dummy:
        path_src: data/train/data.src
        path_tgt: data/train/data.tgt
        src_feats:
            feat_0: data/train/data.src.feat_0
            feat_1: data/train/data.src.feat_1
        transforms: [filterfeats, onmt_tokenize, inferfeats, filtertoolong]
        weight: 1
    valid:
        path_src: data/valid/data.src
        path_tgt: data/valid/data.tgt
        src_feats:
            feat_0: data/valid/data.src.feat_0
            feat_1: data/valid/data.src.feat_1
        transforms: [filterfeats, onmt_tokenize, inferfeats]

# Vocab opts
src_vocab: exp/data.vocab.src
tgt_vocab: exp/data.vocab.tgt
src_feats_vocab:
    feat_0: exp/data.vocab.feat_0
    feat_1: exp/data.vocab.feat_1
feat_merge: "sum"
```
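
During training and vocabulary building, each feature file is read line by line in parallel with the source and target files, so every example carries a `src_feats` dict keyed by feature name. The snippet below is a simplified sketch of that pairing, not the actual loader (which is `ParallelCorpus.load` in `onmt/inputters/corpus.py`, extended in this PR); the paths reuse the small test corpus added under `data/data_features/`:

```
# Sketch: pair source, target and feature lines the way the corpus loader
# does, yielding one example dict per sentence.
src_feats = {"feat0": "data/data_features/src-train.feat0"}

feat_names = list(src_feats)
feat_files = [open(path, encoding="utf-8") for path in src_feats.values()]

with open("data/data_features/src-train.txt", encoding="utf-8") as fs, \
        open("data/data_features/tgt-train.txt", encoding="utf-8") as ft:
    for sline, tline, *feat_lines in zip(fs, ft, *feat_files):
        example = {
            "src": sline.strip(),
            "tgt": tline.strip(),
            "src_feats": dict(zip(feat_names, (f.strip() for f in feat_lines))),
        }
        print(example)

for f in feat_files:
    f.close()
```

Once the config is in place, vocabularies, including one vocabulary file per feature at the paths given under `src_feats_vocab`, can be built with the usual `onmt_build_vocab -config <config> -n_sample -1` before training.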
7 changes: 6 additions & 1 deletion onmt/bin/build_vocab.py
@@ -32,11 +32,13 @@ def build_vocab_main(opts):
transforms = make_transforms(opts, transforms_cls, fields)

logger.info(f"Counter vocab from {opts.n_sample} samples.")
src_counter, tgt_counter = build_vocab(
src_counter, tgt_counter, src_feats_counter = build_vocab(
opts, transforms, n_sample=opts.n_sample)

logger.info(f"Counters src:{len(src_counter)}")
logger.info(f"Counters tgt:{len(tgt_counter)}")
for feat_name, feat_counter in src_feats_counter.items():
logger.info(f"Counters {feat_name}:{len(feat_counter)}")

def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -52,6 +54,9 @@ def save_counter(counter, save_path):
else:
save_counter(src_counter, opts.src_vocab)
save_counter(tgt_counter, opts.tgt_vocab)

for k, v in src_feats_counter.items():
save_counter(v, opts.src_feats_vocab[k])


def _get_parser():
1 change: 1 addition & 0 deletions onmt/constants.py
@@ -22,6 +22,7 @@ class CorpusName(object):
class SubwordMarker(object):
SPACER = '▁'
JOINER = '■'
CASE_MARKUP = ["⦅mrk_case_modifier_C⦆", "⦅mrk_begin_case_region_U⦆", "⦅mrk_end_case_region_U⦆"]


class ModelTask(object):
46 changes: 37 additions & 9 deletions onmt/inputters/corpus.py
@@ -7,7 +7,7 @@
from torchtext.data import Dataset as TorchtextDataset, \
Example as TorchtextExample

from collections import Counter
from collections import Counter, defaultdict
from contextlib import contextmanager

import multiprocessing as mp
@@ -74,6 +74,9 @@ def _process(item, is_train):
maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
if 'align' in maybe_example:
maybe_example['align'] = ' '.join(maybe_example['align'])
if 'src_feats' in maybe_example:
for k in maybe_example['src_feats'].keys():
maybe_example['src_feats'][k] = ' '.join(maybe_example['src_feats'][k])
return maybe_example

def _maybe_add_dynamic_dict(self, example, fields):
@@ -107,23 +110,32 @@ def __call__(self, bucket):
class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

def __init__(self, name, src, tgt, align=None):
def __init__(self, name, src, tgt, align=None, src_feats=None):
"""Initialize src & tgt side file path."""
self.id = name
self.src = src
self.tgt = tgt
self.align = align
self.src_feats = src_feats

def load(self, offset=0, stride=1):
"""
Load file and iterate by lines.
`offset` and `stride` allow to iterate only on every
`stride` example, starting from `offset`.
"""
if self.src_feats:
features_names = []
features_files = []
for feat_name, feat_path in self.src_feats.items():
features_names.append(feat_name)
features_files.append(open(feat_path, mode='rb'))
else:
features_files = []
with exfile_open(self.src, mode='rb') as fs,\
exfile_open(self.tgt, mode='rb') as ft,\
exfile_open(self.align, mode='rb') as fa:
for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
for i, (sline, tline, align, *features) in enumerate(zip(fs, ft, fa, *features_files)):
if (i % stride) == offset:
sline = sline.decode('utf-8')
tline = tline.decode('utf-8')
@@ -133,12 +145,18 @@ def load(self, offset=0, stride=1):
}
if align is not None:
example['align'] = align.decode('utf-8')
if features:
example["src_feats"] = dict()
for j, feat in enumerate(features):
example["src_feats"][features_names[j]] = feat.decode("utf-8")
yield example
for f in features_files:
f.close()

def __str__(self):
cls_name = type(self).__name__
return '{}({}, {}, align={})'.format(
cls_name, self.src, self.tgt, self.align)
return '{}({}, {}, align={}, src_feats={})'.format(
cls_name, self.src, self.tgt, self.align, self.src_feats)


def get_corpora(opts, is_train=False):
@@ -150,7 +168,8 @@
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
corpus_dict["path_align"])
corpus_dict["path_align"],
corpus_dict["src_feats"])
else:
if CorpusName.VALID in opts.data.keys():
corpora_dict[CorpusName.VALID] = ParallelCorpus(
@@ -193,6 +212,9 @@ def _tokenize(self, stream):
example['src'], example['tgt'] = src, tgt
if 'align' in example:
example['align'] = example['align'].strip('\n').split()
if 'src_feats' in example:
for k in example['src_feats'].keys():
example['src_feats'][k] = example['src_feats'][k].strip('\n').split()
yield example

def _transform(self, stream):
@@ -286,6 +308,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
"""Build vocab on (strided) subpart of the data."""
sub_counter_src = Counter()
sub_counter_tgt = Counter()
sub_counter_src_feats = defaultdict(Counter)
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
@@ -298,6 +321,9 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
build_sub_vocab.queues[c_name][offset].put("blank")
continue
src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
if 'src_feats' in maybe_example:
for feat_name, feat_line in maybe_example["src_feats"].items():
sub_counter_src_feats[feat_name].update(feat_line.split(' '))
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))
if opts.dump_samples:
@@ -309,7 +335,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
break
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
return sub_counter_src, sub_counter_tgt
return sub_counter_src, sub_counter_tgt, sub_counter_src_feats


def init_pool(queues):
Expand All @@ -333,6 +359,7 @@ def build_vocab(opts, transforms, n_sample=3):
corpora = get_corpora(opts, is_train=True)
counter_src = Counter()
counter_tgt = Counter()
counter_src_feats = defaultdict(Counter)
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
@@ -349,13 +376,14 @@ def build_vocab(opts, transforms, n_sample=3):
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
for sub_counter_src, sub_counter_tgt in p.imap(
for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
counter_src_feats.update(sub_counter_src_feats)
if opts.dump_samples:
write_process.join()
return counter_src, counter_tgt
return counter_src, counter_tgt, counter_src_feats


def save_transformed_sample(opts, transforms, n_sample=3):
13 changes: 9 additions & 4 deletions onmt/inputters/fields.py
@@ -8,11 +8,10 @@


def _get_dynamic_fields(opts):
# NOTE: not support nfeats > 0 yet
src_nfeats = 0
tgt_nfeats = 0
# NOTE: not support tgt feats yet
tgt_feats = None
with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
fields = get_fields('text', src_nfeats, tgt_nfeats,
fields = get_fields('text', opts.src_feats_vocab, tgt_feats,
dynamic_dict=opts.copy_attn,
src_truncate=opts.src_seq_length_trunc,
tgt_truncate=opts.tgt_seq_length_trunc,
@@ -33,6 +32,12 @@ def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
opts.src_vocab, 'src', counters,
min_freq=opts.src_words_min_frequency)

if opts.src_feats_vocab:
for feat_name, filepath in opts.src_feats_vocab.items():
_, _ = _load_vocab(
filepath, feat_name, counters,
min_freq=0)

if opts.tgt_vocab:
_tgt_vocab, _tgt_vocab_size = _load_vocab(
opts.tgt_vocab, 'tgt', counters,
12 changes: 6 additions & 6 deletions onmt/inputters/inputter.py
@@ -111,8 +111,8 @@ def get_task_spec_tokens(data_task, pad, bos, eos):

def get_fields(
src_data_type,
n_src_feats,
n_tgt_feats,
src_feats,
tgt_feats,
pad=DefaultTokens.PAD,
bos=DefaultTokens.BOS,
eos=DefaultTokens.EOS,
@@ -125,11 +125,11 @@ def get_fields(
"""
Args:
src_data_type: type of the source input. Options are [text].
n_src_feats (int): the number of source features (not counting tokens)
src_feats (Optional[Dict]): dictionary of source feature names for which
to create a :class:`torchtext.data.Field`. (If
``src_data_type=="text"``, these fields are stored together
as a ``TextMultiField``).
n_tgt_feats (int): See above.
tgt_feats (Optional[Dict]): See above.
pad (str): Special pad symbol. Used on src and tgt side.
bos (str): Special beginning of sequence symbol. Only relevant
for tgt.
@@ -158,7 +158,7 @@
task_spec_tokens = get_task_spec_tokens(data_task, pad, bos, eos)

src_field_kwargs = {
"n_feats": n_src_feats,
"feats": src_feats,
"include_lengths": True,
"pad": task_spec_tokens["src"]["pad"],
"bos": task_spec_tokens["src"]["bos"],
@@ -169,7 +169,7 @@
fields["src"] = fields_getters[src_data_type](**src_field_kwargs)

tgt_field_kwargs = {
"n_feats": n_tgt_feats,
"feats": tgt_feats,
"include_lengths": False,
"pad": task_spec_tokens["tgt"]["pad"],
"bos": task_spec_tokens["tgt"]["bos"],