6 changes: 3 additions & 3 deletions onmt/bin/build_vocab.py
@@ -37,8 +37,8 @@ def build_vocab_main(opts):
 
     logger.info(f"Counters src:{len(src_counter)}")
     logger.info(f"Counters tgt:{len(tgt_counter)}")
-    for k, v in src_feats_counter["src_feats"].items():
-        logger.info(f"Counters {k}:{len(v)}")
+    for feat_name, feat_counter in src_feats_counter.items():
+        logger.info(f"Counters {feat_name}:{len(feat_counter)}")
 
     def save_counter(counter, save_path):
         check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -55,7 +55,7 @@ def save_counter(counter, save_path):
     save_counter(src_counter, opts.src_vocab)
     save_counter(tgt_counter, opts.tgt_vocab)
 
-    for k, v in src_feats_counter["src_feats"].items():
+    for k, v in src_feats_counter.items():
        save_counter(v, opts.src_feats_vocab[k])
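With the "src_feats" nesting gone, the per-feature counters are keyed directly by feature name. A minimal standalone sketch of the resulting structure (illustrative only, not part of the diff; feature names made up):

from collections import Counter, defaultdict

src_feats_counter = defaultdict(Counter)
src_feats_counter["pos"].update("NN VBD DT NN".split(' '))
src_feats_counter["case"].update("C N N N".split(' '))
for feat_name, feat_counter in src_feats_counter.items():
    print(f"Counters {feat_name}:{len(feat_counter)}")
# Counters pos:3
# Counters case:2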
1 change: 1 addition & 0 deletions onmt/constants.py
@@ -22,6 +22,7 @@ class CorpusName(object):
 class SubwordMarker(object):
     SPACER = '▁'
     JOINER = '■'
+    CASE_MARKUP = ["⦅mrk_case_modifier_C⦆", "⦅mrk_begin_case_region_U⦆", "⦅mrk_end_case_region_U⦆"]
 
 
 class ModelTask(object):
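These constants mirror the case-markup placeholders that pyonmttok emits when case_markup is enabled. A rough sketch of where they come from (expected output hand-written, so treat it as an assumption; exact tokens can vary across pyonmttok versions):

import pyonmttok

tokenizer = pyonmttok.Tokenizer("conservative", joiner_annotate=True, case_markup=True)
tokens, _ = tokenizer.tokenize("Hello World")
# tokens ≈ ['⦅mrk_case_modifier_C⦆', 'hello', '⦅mrk_case_modifier_C⦆', 'world']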
37 changes: 18 additions & 19 deletions onmt/inputters/corpus.py
@@ -124,10 +124,12 @@ def load(self, offset=0, stride=1):
         `offset` and `stride` allow to iterate only on every
         `stride` example, starting from `offset`.
         """
-        #import pdb
-        #pdb.set_trace()
         if self.src_feats:
-            features_files = [open(feat_path, mode='rb') for feat_name, feat_path in self.src_feats.items()]
+            features_names = []
+            features_files = []
+            for feat_name, feat_path in self.src_feats.items():
+                features_names.append(feat_name)
+                features_files.append(open(feat_path, mode='rb'))
         else:
             features_files = []
         with exfile_open(self.src, mode='rb') as fs,\
@@ -146,7 +148,7 @@
                     if features:
                         example["src_feats"] = dict()
                         for j, feat in enumerate(features):
-                            example["src_feats"][list(self.src_feats.keys())[j]] = feat.decode("utf-8")
+                            example["src_feats"][features_names[j]] = feat.decode("utf-8")
                 yield example
         for f in features_files:
             f.close()
@@ -304,11 +306,9 @@ def write_files_from_queues(sample_path, queues):
 
 def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
     """Build vocab on (strided) subpart of the data."""
-    #import pdb
-    #pdb.set_trace()
     sub_counter_src = Counter()
     sub_counter_tgt = Counter()
-    sub_counter_src_feats = {'src_feats': defaultdict(Counter)}
+    sub_counter_src_feats = defaultdict(Counter)
     datasets_iterables = build_corpora_iters(
         corpora, transforms, opts.data,
         skip_empty_level=opts.skip_empty_level,
@@ -323,7 +323,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
             src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
             if 'src_feats' in maybe_example:
                 for feat_name, feat_line in maybe_example["src_feats"].items():
-                    sub_counter_src_feats['src_feats'][feat_name].update(feat_line.split(' '))
+                    sub_counter_src_feats[feat_name].update(feat_line.split(' '))
             sub_counter_src.update(src_line.split(' '))
             sub_counter_tgt.update(tgt_line.split(' '))
             if opts.dump_samples:
@@ -359,7 +359,7 @@ def build_vocab(opts, transforms, n_sample=3):
     corpora = get_corpora(opts, is_train=True)
     counter_src = Counter()
     counter_tgt = Counter()
-    counter_src_feats = {'src_feats': defaultdict(Counter)}
+    counter_src_feats = defaultdict(Counter)
     from functools import partial
     queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
                        for i in range(opts.num_threads)]
@@ -372,16 +372,15 @@
                 args=(sample_path, queues),
                 daemon=True)
         write_process.start()
-    #with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
-    func = partial(
-        build_sub_vocab, corpora, transforms,
-        opts, n_sample, opts.num_threads)
-    sub_counter_src, sub_counter_tgt, sub_counter_src_feats = func(0)
-    # for sub_counter_src, sub_counter_tgt in p.imap(
-    #         func, range(0, opts.num_threads)):
-    counter_src.update(sub_counter_src)
-    counter_tgt.update(sub_counter_tgt)
-    counter_src_feats.update(sub_counter_src_feats)
+    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
+        func = partial(
+            build_sub_vocab, corpora, transforms,
+            opts, n_sample, opts.num_threads)
+        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
+                func, range(0, opts.num_threads)):
+            counter_src.update(sub_counter_src)
+            counter_tgt.update(sub_counter_tgt)
+            counter_src_feats.update(sub_counter_src_feats)
     if opts.dump_samples:
         write_process.join()
     return counter_src, counter_tgt, counter_src_feats
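The restored dispatch follows the standard Pool/imap aggregation pattern. A self-contained sketch of just that pattern (toy data and function names, not the OpenNMT-py API):

import multiprocessing as mp
from collections import Counter
from functools import partial

def count_shard(corpus, stride, offset):
    # Count tokens on every `stride`-th line, starting at `offset`.
    c = Counter()
    for i, line in enumerate(corpus):
        if i % stride == offset:
            c.update(line.split(' '))
    return c

if __name__ == '__main__':
    corpus = ["a b", "a c", "b c", "a a"]
    n_threads = 2
    total = Counter()
    with mp.Pool(n_threads) as p:
        func = partial(count_shard, corpus, n_threads)
        for sub in p.imap(func, range(n_threads)):
            total.update(sub)  # Counter.update adds counts across shards
    print(total)  # Counter({'a': 4, 'b': 2, 'c': 2})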
7 changes: 3 additions & 4 deletions onmt/inputters/fields.py
@@ -8,11 +8,10 @@
 
 
 def _get_dynamic_fields(opts):
-    # NOTE: not support nfeats > 0 yet
-    #src_nfeats = 0
-    tgt_nfeats = None #0
+    # NOTE: tgt feats not supported yet
+    tgt_feats = None
     with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
-    fields = get_fields('text', opts.src_feats_vocab, tgt_nfeats,
+    fields = get_fields('text', opts.src_feats_vocab, tgt_feats,
                         dynamic_dict=opts.copy_attn,
                         src_truncate=opts.src_seq_length_trunc,
                         tgt_truncate=opts.tgt_seq_length_trunc,
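For reference, the value forwarded as src_feats is opts.src_feats_vocab, a dict mapping each feature name to its vocab path (names hypothetical):

# opts.src_feats_vocab ≈ {"pos": "data/pos.vocab", "case": "data/case.vocab"}
# get_fields builds one feature field per key; build_vocab saves each
# feature counter to opts.src_feats_vocab[k].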
12 changes: 6 additions & 6 deletions onmt/inputters/inputter.py
@@ -111,8 +111,8 @@ def get_task_spec_tokens(data_task, pad, bos, eos):
 
 def get_fields(
     src_data_type,
-    n_src_feats,
-    n_tgt_feats,
+    src_feats,
+    tgt_feats,
     pad=DefaultTokens.PAD,
     bos=DefaultTokens.BOS,
     eos=DefaultTokens.EOS,
@@ -125,11 +125,11 @@ def get_fields(
     """
     Args:
         src_data_type: type of the source input. Options are [text].
-        n_src_feats (int): the number of source features (not counting tokens)
+        src_feats (dict): source feature names
            to create a :class:`torchtext.data.Field` for. (If
            ``src_data_type=="text"``, these fields are stored together
            as a ``TextMultiField``).
-        n_tgt_feats (int): See above.
+        tgt_feats (dict): See above.
        pad (str): Special pad symbol. Used on src and tgt side.
        bos (str): Special beginning of sequence symbol. Only relevant
            for tgt.
@@ -158,7 +158,7 @@
     task_spec_tokens = get_task_spec_tokens(data_task, pad, bos, eos)
 
     src_field_kwargs = {
-        "n_feats": n_src_feats,
+        "feats": src_feats,
         "include_lengths": True,
         "pad": task_spec_tokens["src"]["pad"],
         "bos": task_spec_tokens["src"]["bos"],
@@ -169,7 +169,7 @@
     fields["src"] = fields_getters[src_data_type](**src_field_kwargs)
 
     tgt_field_kwargs = {
-        "n_feats": n_tgt_feats,
+        "feats": tgt_feats,
         "include_lengths": False,
         "pad": task_spec_tokens["tgt"]["pad"],
         "bos": task_spec_tokens["tgt"]["bos"],
11 changes: 5 additions & 6 deletions onmt/inputters/text_dataset.py
@@ -152,7 +152,7 @@ def text_fields(**kwargs):
 
     Args:
         base_name (str): Name associated with the field.
-        n_feats (int): Number of word level feats (not counting the tokens)
+        feats (dict): Word-level feats
         include_lengths (bool): Optionally return the sequence lengths.
         pad (str, optional): Defaults to ``"<blank>"``.
         bos (str or NoneType, optional): Defaults to ``"<s>"``.
@@ -163,7 +163,7 @@ def text_fields(**kwargs):
         TextMultiField
     """
 
-    n_feats = kwargs["n_feats"]
+    feats = kwargs["feats"]
     include_lengths = kwargs["include_lengths"]
     base_name = kwargs["base_name"]
     pad = kwargs.get("pad", DefaultTokens.PAD)
@@ -187,10 +187,9 @@ def text_fields(**kwargs):
         fields_.append((base_name, feat))
 
     # Feats fields
-    #for i in range(n_feats + 1):
-    if n_feats:
-        for feat_name in n_feats.keys():
-            #name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
+    if feats:
+        for feat_name in feats.keys():
            # Legacy function, it is not really necessary
            tokenize = partial(
                _feature_tokenize,
                layer=None,
60 changes: 32 additions & 28 deletions onmt/transforms/features.py
@@ -1,10 +1,13 @@
 from onmt.utils.logging import logger
 from onmt.transforms import register_transform
 from .transform import Transform, ObservableStats
+from onmt.constants import DefaultTokens, SubwordMarker
+from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
 import re
 from collections import defaultdict
+
 
 
 @register_transform(name='filterfeats')
 class FilterFeatsTransform(Transform):
     """Filter out examples with a mismatch between source and features."""
@@ -48,49 +51,50 @@ def add_options(cls, parser):
         pass
 
     def _parse_opts(self):
-        pass
+        super()._parse_opts()
+        logger.info("Parsed pyonmttok kwargs for src: {}".format(
+            self.opts.src_onmttok_kwargs))
+        self.src_onmttok_kwargs = self.opts.src_onmttok_kwargs
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
 
         if "src_feats" not in example:
             # Do nothing
             return example
 
-        feats_i = 0
-        inferred_feats = defaultdict(list)
-        for subword in example["src"]:
-            next_ = False
-            for k, v in example["src_feats"].items():
+        # TODO: what about custom placeholders??
+        joiner = self.src_onmttok_kwargs["joiner"] if "joiner" in self.src_onmttok_kwargs else SubwordMarker.JOINER
+        case_markup = SubwordMarker.CASE_MARKUP if "case_markup" in self.src_onmttok_kwargs else []
+        # TODO: support joiner_new or spacer_new options. Consistency not ensured currently
 
-                # Placeholders
-                if re.match(r'⦅\w+⦆', subword):
-                    inferred_feat = "N"
+        if "joiner_annotate" in self.src_onmttok_kwargs:
+            word_to_subword_mapping = subword_map_by_joiner(example["src"], marker=joiner, case_markup=case_markup)
+        elif "spacer_annotate" in self.src_onmttok_kwargs:
+            # TODO: case markup
+            word_to_subword_mapping = subword_map_by_spacer(example["src"], marker=joiner)
+        else:
+            # TODO: support not reversible tokenization
+            raise Exception("InferFeats transform does not currently work without either joiner_annotate or spacer_annotate")
 
-                # Punctuation only
-                elif not re.sub(r'(\W)+', '', subword).strip():
-                    inferred_feat = "N"
+        inferred_feats = defaultdict(list)
+        for subword, word_id in zip(example["src"], word_to_subword_mapping):
+            for feat_name, feat_values in example["src_feats"].items():
 
-                # Joiner annotate
-                elif re.search("■", subword):
-                    inferred_feat = v[feats_i]
+                # If case markup placeholder
+                if subword in case_markup:
+                    inferred_feat = "<null>"
 
+                # Punctuation only (assumes joiner is also some punctuation token)
+                elif not re.sub(r'(\W)+', '', subword).strip():
+                    inferred_feat = "<null>"
 
                 # Whole word
                 else:
-                    inferred_feat = v[feats_i]
-                    next_ = True
+                    inferred_feat = feat_values[word_id]
 
-                inferred_feats[k].append(inferred_feat)
-
-                if next_:
-                    feats_i += 1
+                inferred_feats[feat_name].append(inferred_feat)
 
-        # Check all features have been consumed
-        for k, v in example["src_feats"].items():
-            assert feats_i == len(v), f'Not all features consumed for {k}'
+        for feat_name, feat_values in inferred_feats.items():
+            example["src_feats"][feat_name] = inferred_feats[feat_name]
 
-        for k, v in inferred_feats.items():
-            example["src_feats"][k] = inferred_feats[k]
         return example
 
     def _repr_args(self):
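A hand-traced example of the new mapping-based inference (subwords and feature values made up for illustration):

# src subwords: ["⦅mrk_case_modifier_C⦆", "however■", ",", "he■", "llo"]
# word ids:     [0, 0, 0, 1, 1]    (word_to_subword_mapping via subword_map_by_joiner)
# word feats:   ["A", "B"]         (one value per source word)
# inferred:     ["<null>", "A", "<null>", "B", "B"]
# Case-markup placeholders and punctuation-only subwords get "<null>";
# every other subword inherits the feature value of the word it belongs to.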
4 changes: 2 additions & 2 deletions onmt/utils/alignment.py
@@ -120,11 +120,11 @@ def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'):
     return " ".join(word_align)
 
 
-def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER):
+def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=[]):
     """Return word id for each subword token (annotate by joiner)."""
     flags = [0] * len(subwords)
     for i, tok in enumerate(subwords):
-        if tok.endswith(marker):
+        if tok.endswith(marker) or tok in case_markup:
             flags[i] = 1
         if tok.startswith(marker):
             assert i >= 1 and flags[i-1] != 1, \
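A quick illustration of the extended mapping (expected output inferred from the flag logic above, assuming SubwordMarker.CASE_MARKUP is imported from onmt.constants):

subwords = ["⦅mrk_case_modifier_C⦆", "he■", "llo", "world"]
print(subword_map_by_joiner(subwords, case_markup=SubwordMarker.CASE_MARKUP))
# -> [0, 0, 0, 1]   (the markup token is grouped with the word it modifies)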