7 changes: 6 additions & 1 deletion onmt/bin/build_vocab.py
@@ -32,11 +32,13 @@ def build_vocab_main(opts):
transforms = make_transforms(opts, transforms_cls, fields)

logger.info(f"Counter vocab from {opts.n_sample} samples.")
src_counter, tgt_counter = build_vocab(
src_counter, tgt_counter, src_feats_counter = build_vocab(
opts, transforms, n_sample=opts.n_sample)

logger.info(f"Counters src:{len(src_counter)}")
logger.info(f"Counters tgt:{len(tgt_counter)}")
for k, v in src_feats_counter["src_feats"].items():
logger.info(f"Counters {k}:{len(v)}")

def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -52,6 +54,9 @@ def save_counter(counter, save_path):
else:
save_counter(src_counter, opts.src_vocab)
save_counter(tgt_counter, opts.tgt_vocab)

for k, v in src_feats_counter["src_feats"].items():
save_counter(v, opts.src_feats_vocab[k])


def _get_parser():
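For reference, a minimal sketch of the shapes these additions assume: the per-feature counters come back nested under a `src_feats` key, and `opts.src_feats_vocab` is expected to map each feature name to the path its vocabulary is written to (feature name and path below are hypothetical):

```python
from collections import Counter, defaultdict

# Shape of the third value returned by build_vocab() after this change.
src_feats_counter = {"src_feats": defaultdict(Counter)}
src_feats_counter["src_feats"]["pos"].update(["DET", "NOUN", "NOUN"])

# Assumed shape of opts.src_feats_vocab in build_vocab mode: feature name -> output path.
src_feats_vocab = {"pos": "data/pos.vocab"}

for name, counter in src_feats_counter["src_feats"].items():
    print(f"Counters {name}:{len(counter)}")  # mirrors the logging added above
    # save_counter(counter, src_feats_vocab[name]) would then write the vocab file
```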
61 changes: 45 additions & 16 deletions onmt/inputters/corpus.py
@@ -7,7 +7,7 @@
from torchtext.data import Dataset as TorchtextDataset, \
Example as TorchtextExample

from collections import Counter
from collections import Counter, defaultdict
from contextlib import contextmanager

import multiprocessing as mp
@@ -74,6 +74,9 @@ def _process(item, is_train):
maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
if 'align' in maybe_example:
maybe_example['align'] = ' '.join(maybe_example['align'])
if 'src_feats' in maybe_example:
for k in maybe_example['src_feats'].keys():
maybe_example['src_feats'][k] = ' '.join(maybe_example['src_feats'][k])
return maybe_example

def _maybe_add_dynamic_dict(self, example, fields):
@@ -107,23 +110,30 @@ def __call__(self, bucket):
class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

def __init__(self, name, src, tgt, align=None):
def __init__(self, name, src, tgt, align=None, src_feats=None):
"""Initialize src & tgt side file path."""
self.id = name
self.src = src
self.tgt = tgt
self.align = align
self.src_feats = src_feats

def load(self, offset=0, stride=1):
"""
Load file and iterate by lines.
`offset` and `stride` allow to iterate only on every
`stride` example, starting from `offset`.
"""
if self.src_feats:
features_files = [open(feat_path, mode='rb') for feat_path in self.src_feats.values()]
else:
features_files = []
with exfile_open(self.src, mode='rb') as fs,\
exfile_open(self.tgt, mode='rb') as ft,\
exfile_open(self.align, mode='rb') as fa:
for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
for i, (sline, tline, align, *features) in enumerate(zip(fs, ft, fa, *features_files)):
if (i % stride) == offset:
sline = sline.decode('utf-8')
tline = tline.decode('utf-8')
@@ -133,12 +143,18 @@ def load(self, offset=0, stride=1):
}
if align is not None:
example['align'] = align.decode('utf-8')
if features:
example["src_feats"] = dict()
for feat_name, feat in zip(self.src_feats.keys(), features):
example["src_feats"][feat_name] = feat.decode("utf-8")
yield example
for f in features_files:
f.close()

def __str__(self):
cls_name = type(self).__name__
return '{}({}, {}, align={})'.format(
cls_name, self.src, self.tgt, self.align)
return '{}({}, {}, align={}, src_feats={})'.format(
cls_name, self.src, self.tgt, self.align, self.src_feats)
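A hedged usage sketch of the extended corpus class (paths and the feature name are made up; each feature file is expected to hold one whitespace-separated feature line per source line):

```python
from onmt.inputters.corpus import ParallelCorpus

# Hypothetical paths; src_feats maps feature name -> feature file path.
corpus = ParallelCorpus(
    "train", "data/train.src", "data/train.tgt",
    src_feats={"pos": "data/train.pos"})

for ex in corpus.load():
    # ex["src"] / ex["tgt"] are the decoded source/target lines;
    # ex["src_feats"]["pos"] is the matching decoded feature line.
    print(ex["src"].strip(), "|", ex["src_feats"]["pos"].strip())
```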


def get_corpora(opts, is_train=False):
@@ -150,7 +166,8 @@
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
corpus_dict["path_align"])
corpus_dict["path_align"],
corpus_dict["src_feats"])
else:
if CorpusName.VALID in opts.data.keys():
corpora_dict[CorpusName.VALID] = ParallelCorpus(
@@ -193,6 +210,9 @@ def _tokenize(self, stream):
example['src'], example['tgt'] = src, tgt
if 'align' in example:
example['align'] = example['align'].strip('\n').split()
if 'src_feats' in example:
for k in example['src_feats'].keys():
example['src_feats'][k] = example['src_feats'][k].strip('\n').split()
yield example
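To illustrate this step, a made-up example before and after tokenization (feature name is hypothetical):

```python
# As yielded by ParallelCorpus.load():
example = {"src": "the cat sat\n", "tgt": "le chat\n",
           "src_feats": {"pos": "DET NOUN VERB\n"}}

# After _tokenize(), each feature line is stripped and split just like src/tgt:
# example["src_feats"]["pos"] == ["DET", "NOUN", "VERB"]
```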

def _transform(self, stream):
@@ -284,8 +304,11 @@ def write_files_from_queues(sample_path, queues):

def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
"""Build vocab on (strided) subpart of the data."""
sub_counter_src = Counter()
sub_counter_tgt = Counter()
sub_counter_src_feats = {'src_feats': defaultdict(Counter)}
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
@@ -298,6 +321,9 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
build_sub_vocab.queues[c_name][offset].put("blank")
continue
src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
if 'src_feats' in maybe_example:
for feat_name, feat_line in maybe_example["src_feats"].items():
sub_counter_src_feats['src_feats'][feat_name].update(feat_line.split(' '))
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))
if opts.dump_samples:
@@ -309,7 +335,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
break
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
return sub_counter_src, sub_counter_tgt
return sub_counter_src, sub_counter_tgt, sub_counter_src_feats


def init_pool(queues):
@@ -333,6 +359,7 @@ def build_vocab(opts, transforms, n_sample=3):
corpora = get_corpora(opts, is_train=True)
counter_src = Counter()
counter_tgt = Counter()
counter_src_feats = {'src_feats': defaultdict(Counter)}
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
@@ -345,17 +372,19 @@
args=(sample_path, queues),
daemon=True)
write_process.start()
with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
for sub_counter_src, sub_counter_tgt in p.imap(
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
# NOTE: the multiprocessing pool is bypassed for now; the sub-vocab is built single-threaded.
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
sub_counter_src, sub_counter_tgt, sub_counter_src_feats = func(0)
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
counter_src_feats.update(sub_counter_src_feats)
if opts.dump_samples:
write_process.join()
return counter_src, counter_tgt
return counter_src, counter_tgt, counter_src_feats


def save_transformed_sample(opts, transforms, n_sample=3):
12 changes: 9 additions & 3 deletions onmt/inputters/fields.py
@@ -9,10 +9,10 @@

def _get_dynamic_fields(opts):
# NOTE: not support nfeats > 0 yet
src_nfeats = 0
tgt_nfeats = 0
#src_nfeats = 0
tgt_nfeats = None #0
with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
fields = get_fields('text', src_nfeats, tgt_nfeats,
fields = get_fields('text', opts.src_feats_vocab, tgt_nfeats,
dynamic_dict=opts.copy_attn,
src_truncate=opts.src_seq_length_trunc,
tgt_truncate=opts.tgt_seq_length_trunc,
@@ -33,6 +33,12 @@ def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
opts.src_vocab, 'src', counters,
min_freq=opts.src_words_min_frequency)

if opts.src_feats_vocab:
for feat_name, filepath in opts.src_feats_vocab.items():
_, _ = _load_vocab(
filepath, feat_name, counters,
min_freq=0)

if opts.tgt_vocab:
_tgt_vocab, _tgt_vocab_size = _load_vocab(
opts.tgt_vocab, 'tgt', counters,
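A sketch of what each feature vocabulary file loaded here is assumed to contain, following the same format as the regular vocab files (tags and counts are made up):

```python
# Hypothetical content of the file given as opts.src_feats_vocab["pos"]:
# one <word> or <word>\t<count> per line, as for src/tgt vocab files.
with open("pos.vocab", "w", encoding="utf-8") as f:
    f.write("NOUN\t1200\nVERB\t950\nDET\t800\nADJ\t400\n")
# _load_vocab("pos.vocab", "pos", counters, min_freq=0) then fills counters["pos"].
```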
45 changes: 31 additions & 14 deletions onmt/inputters/text_dataset.py
@@ -171,20 +171,37 @@ def text_fields(**kwargs):
eos = kwargs.get("eos", DefaultTokens.EOS)
truncate = kwargs.get("truncate", None)
fields_ = []
feat_delim = u"│" if n_feats > 0 else None
for i in range(n_feats + 1):
name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
tokenize = partial(
_feature_tokenize,
layer=i,
truncate=truncate,
feat_delim=feat_delim)
use_len = i == 0 and include_lengths
feat = Field(
init_token=bos, eos_token=eos,
pad_token=pad, tokenize=tokenize,
include_lengths=use_len)
fields_.append((name, feat))

feat_delim = None  # features are read from separate files, so no inline "│" delimiter

# Base field
tokenize = partial(
_feature_tokenize,
layer=None,
truncate=truncate,
feat_delim=feat_delim)
feat = Field(
init_token=bos, eos_token=eos,
pad_token=pad, tokenize=tokenize,
include_lengths=include_lengths)
fields_.append((base_name, feat))

# Feats fields
if n_feats:
for feat_name in n_feats.keys():
tokenize = partial(
_feature_tokenize,
layer=None,
truncate=truncate,
feat_delim=feat_delim)
feat = Field(
init_token=bos, eos_token=eos,
pad_token=pad, tokenize=tokenize,
include_lengths=False)
fields_.append((feat_name, feat))

assert fields_[0][0] == base_name # sanity check
field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:])
return field
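Roughly, the field structure text_fields() now builds when `n_feats` is the `--src_feats_vocab` mapping, sketched with the legacy torchtext `Field` already used in this module (feature name and special tokens are illustrative):

```python
from torchtext.data import Field  # legacy torchtext API used by OpenNMT-py 2.x

# The base "src" field keeps include_lengths; each feature gets its own Field
# without lengths, and TextMultiField wraps them all under the base name.
base = Field(init_token="<s>", eos_token="</s>", pad_token="<blank>",
             include_lengths=True)
pos = Field(init_token="<s>", eos_token="</s>", pad_token="<blank>",
            include_lengths=False)
fields_ = [("src", base), ("pos", pos)]
# -> TextMultiField("src", base, [("pos", pos)])
```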
5 changes: 5 additions & 0 deletions onmt/opts.py
@@ -132,6 +132,11 @@ def _add_dynamic_fields_opts(parser, build_vocab_only=False):
group.add("-share_vocab", "--share_vocab", action="store_true",
help="Share source and target vocabulary.")

group.add("-src_feats_vocab", "--src_feats_vocab",
help=("List of paths to save" if build_vocab_only else "List of paths to")
+ " src features vocabulary files. "
"Files format: one <word> or <word>\t<count> per line.")

if not build_vocab_only:
group.add("-src_vocab_size", "--src_vocab_size",
type=int, default=50000,
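Elsewhere in this diff the option is consumed as a mapping (`opts.src_feats_vocab[k]`, `.items()`), so a sketch of the parsed value it expects, typically supplied via the YAML config (feature names and paths are hypothetical):

```python
# Expected shape of opts.src_feats_vocab once parsed: feature name -> vocab path
# (written by build_vocab, read again at training time).
opts_src_feats_vocab = {
    "pos": "data/pos.vocab",
    "ner": "data/ner.vocab",
}
```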
3 changes: 3 additions & 0 deletions onmt/train_single.py
@@ -58,6 +58,9 @@ def main(opt, fields, transforms_cls, checkpoint, device_id,
"""Start training on `device_id`."""
# NOTE: It's important that ``opt`` has been validated and updated
# at this point.

configure_process(opt, device_id)
init_logger(opt.log_file)

97 changes: 97 additions & 0 deletions onmt/transforms/features.py
@@ -0,0 +1,97 @@
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform
import re
from collections import defaultdict


@register_transform(name='filterfeats')
class FilterFeatsTransform(Transform):
"""Filter out examples with a mismatch between source and features."""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
pass

def _parse_opts(self):
pass

def apply(self, example, is_train=False, stats=None, **kwargs):
"""Return None if mismatch"""

if 'src_feats' not in example:
# Do nothing
return example

for feat_name, feat_values in example['src_feats'].items():
if len(example['src']) != len(feat_values):
logger.warning(f"Skipping example due to mismatch between source and feature {feat_name}")
return None
return example

def _repr_args(self):
return ''
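A small illustration of the filter's behaviour (examples are made up; it is assumed the base Transform only stores `opts` at construction time):

```python
from onmt.transforms.features import FilterFeatsTransform

filterfeats = FilterFeatsTransform(opts=None)  # assumption: base Transform only stores opts

ok = {"src": ["the", "cat"], "src_feats": {"pos": ["DET", "NOUN"]}}
bad = {"src": ["the", "cat"], "src_feats": {"pos": ["DET"]}}

assert filterfeats.apply(ok) is ok     # lengths match: example passes through
assert filterfeats.apply(bad) is None  # length mismatch: example is dropped
```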


@register_transform(name='inferfeats')
class InferFeatsTransform(Transform):
"""Infer features for subword tokenization."""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
pass

def _parse_opts(self):
pass

def apply(self, example, is_train=False, stats=None, **kwargs):

if "src_feats" not in example:
# Do nothing
return example

feats_i = 0
inferred_feats = defaultdict(list)
for subword in example["src"]:
next_ = False
for k, v in example["src_feats"].items():
# TODO: what about custom placeholders??

# Placeholders
if re.match(r'⦅\w+⦆', subword):
inferred_feat = "N"

# Punctuation only
elif not re.sub(r'(\W)+', '', subword).strip():
inferred_feat = "N"

# Joiner annotate
elif re.search("■", subword):
inferred_feat = v[feats_i]

# Whole word
else:
inferred_feat = v[feats_i]
next_ = True

inferred_feats[k].append(inferred_feat)

if next_:
feats_i += 1

# Check all features have been consumed
for k, v in example["src_feats"].items():
assert feats_i == len(v), f'Not all features consumed for {k}'

for k, v in inferred_feats.items():
example["src_feats"][k] = inferred_feats[k]
return example

def _repr_args(self):
return ''
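And a sketch of the intended behaviour on joiner-annotated subwords (tokens, tags and the feature name are made up; '■' is the joiner character the transform looks for):

```python
from onmt.transforms.features import InferFeatsTransform

inferfeats = InferFeatsTransform(opts=None)  # same construction assumption as above

ex = {
    # "international trade" after subword splitting with joiner annotation
    "src": ["inter■", "national", "trade"],
    # one feature value per original word
    "src_feats": {"pos": ["ADJ", "NOUN"]},
}
out = inferfeats.apply(ex)
# Joiner-marked pieces reuse the current word's feature, the final piece of a
# word consumes it, and punctuation-only tokens would get the filler "N":
assert out["src_feats"]["pos"] == ["ADJ", "ADJ", "NOUN"]
```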