7 changes: 6 additions & 1 deletion onmt/bin/build_vocab.py
@@ -32,11 +32,13 @@ def build_vocab_main(opts):
transforms = make_transforms(opts, transforms_cls, fields)

logger.info(f"Counter vocab from {opts.n_sample} samples.")
- src_counter, tgt_counter = build_vocab(
+ src_counter, tgt_counter, src_feats_counter = build_vocab(
opts, transforms, n_sample=opts.n_sample)

logger.info(f"Counters src:{len(src_counter)}")
logger.info(f"Counters tgt:{len(tgt_counter)}")
+ for feat_name, feat_counter in src_feats_counter.items():
+     logger.info(f"Counters {feat_name}:{len(feat_counter)}")

def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -52,6 +54,9 @@ def save_counter(counter, save_path):
else:
save_counter(src_counter, opts.src_vocab)
save_counter(tgt_counter, opts.tgt_vocab)

+ for k, v in src_feats_counter.items():
+     save_counter(v, opts.src_feats_vocab[k])


def _get_parser():
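With this change, build_vocab returns a third value: one Counter per source feature. A minimal sketch of the resulting flow, assuming opts.src_feats_vocab maps each feature name to an output path (hypothetical names):

    # src_feats_counter is a dict such as {"pos": Counter({'NOUN': 42, ...})}
    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample)
    for feat_name, feat_counter in src_feats_counter.items():
        # e.g. opts.src_feats_vocab == {"pos": "data/vocab.src.pos"}
        save_counter(feat_counter, opts.src_feats_vocab[feat_name])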
1 change: 1 addition & 0 deletions onmt/constants.py
@@ -22,6 +22,7 @@ class CorpusName(object):
class SubwordMarker(object):
SPACER = '▁'
JOINER = '■'
+ CASE_MARKUP = ["⦅mrk_case_modifier_C⦆", "⦅mrk_begin_case_region_U⦆", "⦅mrk_end_case_region_U⦆"]


class ModelTask(object):
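CASE_MARKUP collects the case placeholders emitted by the OpenNMT Tokenizer's case_markup mode, so downstream code can recognize and skip them when lining features up with tokens. A tiny illustration (the token sequence is made up):

    from onmt.constants import SubwordMarker

    tokens = "⦅mrk_case_modifier_C⦆ hello ■world".split()
    content = [t for t in tokens if t not in SubwordMarker.CASE_MARKUP]
    # content == ['hello', '■world']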
46 changes: 37 additions & 9 deletions onmt/inputters/corpus.py
@@ -7,7 +7,7 @@
from torchtext.data import Dataset as TorchtextDataset, \
Example as TorchtextExample

- from collections import Counter
+ from collections import Counter, defaultdict
from contextlib import contextmanager

import multiprocessing as mp
@@ -74,6 +74,9 @@ def _process(item, is_train):
maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
if 'align' in maybe_example:
maybe_example['align'] = ' '.join(maybe_example['align'])
+ if 'src_feats' in maybe_example:
+     for k in maybe_example['src_feats'].keys():
+         maybe_example['src_feats'][k] = ' '.join(maybe_example['src_feats'][k])
return maybe_example

def _maybe_add_dynamic_dict(self, example, fields):
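After _process, an example is a flat dict of whitespace-joined strings, with each source feature kept token-parallel to src. With one feature it could look like this (values made up):

    maybe_example = {
        'src': 'however ■, this is a test',
        'tgt': 'sin embargo ■, esto es una prueba',
        'src_feats': {'pos': 'ADV PUNCT PRON VERB DET NOUN'},
    }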
@@ -107,23 +110,32 @@ def __call__(self, bucket):
class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

- def __init__(self, name, src, tgt, align=None):
+ def __init__(self, name, src, tgt, align=None, src_feats=None):
"""Initialize src & tgt side file path."""
self.id = name
self.src = src
self.tgt = tgt
self.align = align
+ self.src_feats = src_feats

def load(self, offset=0, stride=1):
"""
Load file and iterate by lines.
`offset` and `stride` allow to iterate only on every
`stride` example, starting from `offset`.
"""
+ if self.src_feats:
+     features_names = []
+     features_files = []
+     for feat_name, feat_path in self.src_feats.items():
+         features_names.append(feat_name)
+         features_files.append(open(feat_path, mode='rb'))
+ else:
+     features_files = []
with exfile_open(self.src, mode='rb') as fs,\
exfile_open(self.tgt, mode='rb') as ft,\
exfile_open(self.align, mode='rb') as fa:
- for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
+ for i, (sline, tline, align, *features) in enumerate(zip(fs, ft, fa, *features_files)):
if (i % stride) == offset:
sline = sline.decode('utf-8')
tline = tline.decode('utf-8')
@@ -133,12 +145,18 @@ def load(self, offset=0, stride=1):
}
if align is not None:
example['align'] = align.decode('utf-8')
+ if features:
+     example["src_feats"] = dict()
+     for j, feat in enumerate(features):
+         example["src_feats"][features_names[j]] = feat.decode("utf-8")
yield example
+ for f in features_files:
+     f.close()

def __str__(self):
cls_name = type(self).__name__
- return '{}({}, {}, align={})'.format(
-     cls_name, self.src, self.tgt, self.align)
+ return '{}({}, {}, align={}, src_feats={})'.format(
+     cls_name, self.src, self.tgt, self.align, self.src_feats)
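A minimal sketch of iterating a corpus with one feature file (hypothetical paths; each feature file must be line-aligned with the source):

    corpus = ParallelCorpus(
        "demo", "train.src", "train.tgt",
        src_feats={"pos": "train.src.pos"})
    for ex in corpus.load():
        src_line = ex["src"]               # raw source line
        pos_line = ex["src_feats"]["pos"]  # parallel feature line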


@@ -150,7 +168,8 @@ def get_corpora(opts, is_train=False):
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
corpus_dict["path_align"])
corpus_dict["path_align"],
corpus_dict["src_feats"])
else:
if CorpusName.VALID in opts.data.keys():
corpora_dict[CorpusName.VALID] = ParallelCorpus(
@@ -193,6 +212,9 @@ def _tokenize(self, stream):
example['src'], example['tgt'] = src, tgt
if 'align' in example:
example['align'] = example['align'].strip('\n').split()
+ if 'src_feats' in example:
+     for k in example['src_feats'].keys():
+         example['src_feats'][k] = example['src_feats'][k].strip('\n').split()
yield example

def _transform(self, stream):
@@ -286,6 +308,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
"""Build vocab on (strided) subpart of the data."""
sub_counter_src = Counter()
sub_counter_tgt = Counter()
+ sub_counter_src_feats = defaultdict(Counter)
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
@@ -298,6 +321,9 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
build_sub_vocab.queues[c_name][offset].put("blank")
continue
src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
+ if 'src_feats' in maybe_example:
+     for feat_name, feat_line in maybe_example["src_feats"].items():
+         sub_counter_src_feats[feat_name].update(feat_line.split(' '))
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))
if opts.dump_samples:
@@ -309,7 +335,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
break
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
- return sub_counter_src, sub_counter_tgt
+ return sub_counter_src, sub_counter_tgt, sub_counter_src_feats


def init_pool(queues):
@@ -333,6 +359,7 @@ def build_vocab(opts, transforms, n_sample=3):
corpora = get_corpora(opts, is_train=True)
counter_src = Counter()
counter_tgt = Counter()
+ counter_src_feats = defaultdict(Counter)
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
@@ -349,13 +376,14 @@ def build_vocab(opts, transforms, n_sample=3):
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
- for sub_counter_src, sub_counter_tgt in p.imap(
+ for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
+ counter_src_feats.update(sub_counter_src_feats)
if opts.dump_samples:
write_process.join()
- return counter_src, counter_tgt
+ return counter_src, counter_tgt, counter_src_feats


def save_transformed_sample(opts, transforms, n_sample=3):
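One subtlety: counter_src_feats is a defaultdict(Counter), so dict.update replaces each per-feature Counter with the last worker's Counter rather than summing counts across workers. Accumulating the way counter_src and counter_tgt do would look like this sketch:

    for feat_name, feat_counter in sub_counter_src_feats.items():
        counter_src_feats[feat_name].update(feat_counter)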
13 changes: 9 additions & 4 deletions onmt/inputters/fields.py
@@ -8,11 +8,10 @@


def _get_dynamic_fields(opts):
- # NOTE: not support nfeats > 0 yet
- src_nfeats = 0
- tgt_nfeats = 0
+ # NOTE: tgt feats are not supported yet
+ tgt_feats = None
with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
- fields = get_fields('text', src_nfeats, tgt_nfeats,
+ fields = get_fields('text', opts.src_feats_vocab, tgt_feats,
dynamic_dict=opts.copy_attn,
src_truncate=opts.src_seq_length_trunc,
tgt_truncate=opts.tgt_seq_length_trunc,
@@ -33,6 +32,12 @@ def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
opts.src_vocab, 'src', counters,
min_freq=opts.src_words_min_frequency)

+ if opts.src_feats_vocab:
+     for feat_name, filepath in opts.src_feats_vocab.items():
+         _, _ = _load_vocab(
+             filepath, feat_name, counters,
+             min_freq=0)

if opts.tgt_vocab:
_tgt_vocab, _tgt_vocab_size = _load_vocab(
opts.tgt_vocab, 'tgt', counters,
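opts.src_feats_vocab is consumed here as a mapping from feature name to a vocabulary file in the same format as the src/tgt vocab files; a sketch of the expected shape (hypothetical path):

    opts.src_feats_vocab = {"pos": "data/vocab.src.pos"}
    # each file holds one <token> (or <token>\t<count>) per line;
    # min_freq=0 keeps every listed feature value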
12 changes: 6 additions & 6 deletions onmt/inputters/inputter.py
@@ -111,8 +111,8 @@ def get_task_spec_tokens(data_task, pad, bos, eos):

def get_fields(
src_data_type,
- n_src_feats,
- n_tgt_feats,
+ src_feats,
+ tgt_feats,
pad=DefaultTokens.PAD,
bos=DefaultTokens.BOS,
eos=DefaultTokens.EOS,
@@ -125,11 +125,11 @@
"""
Args:
src_data_type: type of the source input. Options are [text].
- n_src_feats (int): the number of source features (not counting tokens)
+ src_feats (dict): source feature names
to create a :class:`torchtext.data.Field` for. (If
``src_data_type=="text"``, these fields are stored together
as a ``TextMultiField``).
- n_tgt_feats (int): See above.
+ tgt_feats (dict): See above.
pad (str): Special pad symbol. Used on src and tgt side.
bos (str): Special beginning of sequence symbol. Only relevant
for tgt.
@@ -158,7 +158,7 @@
task_spec_tokens = get_task_spec_tokens(data_task, pad, bos, eos)

src_field_kwargs = {
"n_feats": n_src_feats,
"feats": src_feats,
"include_lengths": True,
"pad": task_spec_tokens["src"]["pad"],
"bos": task_spec_tokens["src"]["bos"],
@@ -169,7 +169,7 @@
fields["src"] = fields_getters[src_data_type](**src_field_kwargs)

tgt_field_kwargs = {
"n_feats": n_tgt_feats,
"feats": tgt_feats,
"include_lengths": False,
"pad": task_spec_tokens["tgt"]["pad"],
"bos": task_spec_tokens["tgt"]["bos"],
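A sketch of the updated call, assuming a single source feature named "pos" (only the dict keys matter here; they name the sub-fields):

    fields = get_fields(
        'text',
        {"pos": "data/vocab.src.pos"},  # src_feats
        None)                           # tgt_feats: not supported yet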
48 changes: 32 additions & 16 deletions onmt/inputters/text_dataset.py
@@ -152,7 +152,7 @@ def text_fields(**kwargs):

Args:
base_name (str): Name associated with the field.
- n_feats (int): Number of word level feats (not counting the tokens)
+ feats (dict): word-level feature names
include_lengths (bool): Optionally return the sequence lengths.
pad (str, optional): Defaults to ``"<blank>"``.
bos (str or NoneType, optional): Defaults to ``"<s>"``.
@@ -163,28 +163,44 @@
TextMultiField
"""

n_feats = kwargs["n_feats"]
feats = kwargs["feats"]
include_lengths = kwargs["include_lengths"]
base_name = kwargs["base_name"]
pad = kwargs.get("pad", DefaultTokens.PAD)
bos = kwargs.get("bos", DefaultTokens.BOS)
eos = kwargs.get("eos", DefaultTokens.EOS)
truncate = kwargs.get("truncate", None)
fields_ = []
feat_delim = u"│" if n_feats > 0 else None
for i in range(n_feats + 1):
name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
tokenize = partial(
_feature_tokenize,
layer=i,
truncate=truncate,
feat_delim=feat_delim)
use_len = i == 0 and include_lengths
feat = Field(
init_token=bos, eos_token=eos,
pad_token=pad, tokenize=tokenize,
include_lengths=use_len)
fields_.append((name, feat))

+ feat_delim = None  # u"│" if n_feats > 0 else None

+ # Base field
+ tokenize = partial(
+     _feature_tokenize,
+     layer=None,
+     truncate=truncate,
+     feat_delim=feat_delim)
+ feat = Field(
+     init_token=bos, eos_token=eos,
+     pad_token=pad, tokenize=tokenize,
+     include_lengths=include_lengths)
+ fields_.append((base_name, feat))

+ # Feats fields
+ if feats:
+     for feat_name in feats.keys():
+         # Legacy tokenizer arguments; not strictly necessary here
+         tokenize = partial(
+             _feature_tokenize,
+             layer=None,
+             truncate=truncate,
+             feat_delim=feat_delim)
+         feat = Field(
+             init_token=bos, eos_token=eos,
+             pad_token=pad, tokenize=tokenize,
+             include_lengths=False)
+         fields_.append((feat_name, feat))

assert fields_[0][0] == base_name # sanity check
field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:])
return field
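Passing feats as a dict yields a TextMultiField whose base field comes first, followed by one sub-field per feature named after the feature rather than the old src_feat_N scheme. A small sketch, assuming the TextMultiField internals stay unchanged:

    field = text_fields(base_name="src", feats={"pos": None},
                        include_lengths=True)
    # [name for name, _ in field.fields] == ['src', 'pos']
    # field.base_field is the only sub-field tracking lengths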
5 changes: 5 additions & 0 deletions onmt/opts.py
@@ -132,6 +132,11 @@ def _add_dynamic_fields_opts(parser, build_vocab_only=False):
group.add("-share_vocab", "--share_vocab", action="store_true",
help="Share source and target vocabulary.")

group.add("-src_feats_vocab", "--src_feats_vocab",
help=("List of paths to save" if build_vocab_only else "List of paths to")
+ " src features vocabulary files. "
"Files format: one <word> or <word>\t<count> per line.")

if not build_vocab_only:
group.add("-src_vocab_size", "--src_vocab_size",
type=int, default=50000,
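Despite the "List of paths" wording, the option is consumed as a mapping from feature name to vocabulary path (see the opts.src_feats_vocab.items() loops above); in a YAML config it could look like this (hypothetical path):

    # equivalent YAML:
    #   src_feats_vocab:
    #     pos: data/vocab.src.pos
    opts.src_feats_vocab = {"pos": "data/vocab.src.pos"}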
3 changes: 3 additions & 0 deletions onmt/train_single.py
@@ -58,6 +58,9 @@ def main(opt, fields, transforms_cls, checkpoint, device_id,
"""Start training on `device_id`."""
# NOTE: It's important that ``opt`` has been validated and updated
# at this point.

+ #import pdb
+ #pdb.set_trace()
configure_process(opt, device_id)
init_logger(opt.log_file)
