Skip to content

Commit 705b94f

Browse files
author
Ander Corral
committed
Fixed some issues
1 parent f5b1eef commit 705b94f

File tree

8 files changed

+64
-12
lines changed

8 files changed

+64
-12
lines changed

docs/source/FAQ.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ A C C C C A A B
498498
**Notes**
499499
- Prior tokenization is not necessary, features will be inferred by using the `FeatInferTransform` transform.
500500
- `FilterFeatsTransform` and `FeatInferTransform` are required in order to ensure the functionality.
501+
- Not possible to do shared embeddings (at least with `feat_merge: concat` method)
501502
502503
Sample config file:
503504
@@ -529,10 +530,20 @@ feat_merge: "sum"
529530

530531
```
531532
532-
During inference you can pass features by using the `--src_feats` argument.
533+
During inference you can pass features by using the `--src_feats` argument. `src_feats` is expected to be a Python-like dict, mapping each feature name to its data file.
534+
535+
```
536+
{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}
537+
```
533538
534539
**Important note!** During inference, input sentence is expected to be tokenized. Therefore feature inferring should be handled prior to running the translate command. Example:
535540
536541
```bash
537542
python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}"
538543
```
544+
545+
When using the Transformer architecture make sure the following options are appropriately set:
546+
547+
- `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
548+
- `feat_merge`: how to handle feature vectors
549+
- `feat_vec_size` and maybe `feat_vec_exponent`

onmt/inputters/corpus.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def _process(item, is_train):
7575
maybe_example['src'] = {"src": ' '.join(maybe_example['src'])}
7676

7777
# Make features part of src as in MultiTextField
78+
# {'src': {'src': ..., 'feat1': ...., 'feat2': ....}}
7879
if 'src_feats' in maybe_example:
7980
for feat_name, feat_value in maybe_example['src_feats'].items():
8081
maybe_example['src'][feat_name] = ' '.join(feat_value)
@@ -328,12 +329,12 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
328329
if opts.dump_samples:
329330
build_sub_vocab.queues[c_name][offset].put("blank")
330331
continue
331-
src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
332+
src_line, tgt_line = maybe_example['src']['src'], maybe_example['tgt']['tgt']
332333
for feat_name, feat_line in maybe_example["src"].items():
333334
if feat_name != "src":
334335
sub_counter_src_feats[feat_name].update(feat_line.split(' '))
335-
sub_counter_src.update(src_line["src"].split(' '))
336-
sub_counter_tgt.update(tgt_line["tgt"].split(' '))
336+
sub_counter_src.update(src_line.split(' '))
337+
sub_counter_tgt.update(tgt_line.split(' '))
337338
if opts.dump_samples:
338339
build_sub_vocab.queues[c_name][offset].put(
339340
(i, src_line, tgt_line))

onmt/inputters/text_dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ def read(self, sequences, side, features={}):
1717
path to text file or iterable of the actual text data.
1818
side (str): Prefix used in return dict. Usually
1919
``"src"`` or ``"tgt"``.
20+
features: (Dict[str, str or Iterable[str]]):
21+
dictionary mapping feature names with th path to feature
22+
file or iterable of the actual feature data.
2023
2124
Yields:
2225
dictionaries whose keys are the names of fields and whose
@@ -53,6 +56,7 @@ def text_sort_key(ex):
5356
return len(ex.src[0])
5457

5558

59+
# Legacy function. Currently it only truncates input if truncate is set.
5660
# mix this with partial
5761
def _feature_tokenize(
5862
string, layer=0, tok_delim=None, feat_delim=None, truncate=None):

onmt/opts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -761,8 +761,8 @@ def translate_opts(parser):
761761
help="Source sequence to decode (one line per "
762762
"sequence)")
763763
group.add("-src_feats", "--src_feats", required=False,
764-
help="Source sequence features (one line per "
765-
"sequence)")
764+
help="Source sequence features (dict format). "
765+
"Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}")
766766
group.add('--tgt', '-tgt',
767767
help='True target sequence (optional)')
768768
group.add('--tgt_prefix', '-tgt_prefix', action='store_true',

onmt/tests/test_subword_marker.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,18 @@ def test_subword_group_joiner(self):
4141
out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
4242
self.assertEqual(out, true_out)
4343

44-
def test_subword_group_joiner_with_markup(self):
44+
def test_subword_group_joiner_with_case_markup(self):
4545
data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the', 'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■-■', 'working', '■.', '⦅mrk_end_case_region_U⦆'] # noqa: E501
4646
true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
4747
out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
4848
self.assertEqual(out, true_out)
4949

50+
def test_subword_group_joiner_with_new_joiner(self):
51+
data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■', ',', 'according', 'to', 'the', 'logs', '■', ',', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■', '-', '■', 'working', '■', '.', '⦅mrk_end_case_region_U⦆'] # noqa: E501
52+
true_out = [0, 0, 0, 0, 1, 2, 3, 4, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7]
53+
out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
54+
self.assertEqual(out, true_out)
55+
5056
def test_subword_group_naive(self):
5157
data_in = ['however', ',', 'according', 'to', 'the', 'logs', ',', 'she', 'is', 'hard', '-', 'working', '.'] # noqa: E501
5258
true_out = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
@@ -63,6 +69,18 @@ def test_subword_group_spacer(self):
6369
no_dummy_out = subword_map_by_spacer(no_dummy)
6470
self.assertEqual(no_dummy_out, true_out)
6571

72+
def test_subword_group_spacer_with_case_markup(self):
73+
data_in = ['⦅mrk_case_modifier_C⦆', '▁however', ',', '▁according', '▁to', '▁the', '▁logs', ',', '▁⦅mrk_begin_case_region_U⦆', '▁she', '▁is', '▁hard', '-', 'working', '.', '▁⦅mrk_end_case_region_U⦆'] # noqa: E501
74+
true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
75+
out = subword_map_by_spacer(data_in)
76+
self.assertEqual(out, true_out)
77+
78+
def test_subword_group_spacer_with_spacer_new(self):
79+
data_in = ['⦅mrk_case_modifier_C⦆', '▁', 'however', ',', '▁', 'according', '▁', 'to', '▁', 'the', '▁', 'logs', ',', '▁', '⦅mrk_begin_case_region_U⦆', '▁', 'she', '▁', 'is', '▁', 'hard', '-', 'working', '.', '▁', '⦅mrk_end_case_region_U⦆'] # noqa: E501
80+
true_out = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7]
81+
out = subword_map_by_spacer(data_in)
82+
self.assertEqual(out, true_out)
83+
6684

6785
if __name__ == '__main__':
6886
unittest.main()

onmt/transforms/features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,9 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
6363
# Do nothing
6464
return example
6565

66-
# TODO: support joiner_new or spacer_new options. Consistency not ensured currently
67-
6866
if self.reversible_tokenization == "joiner":
6967
word_to_subword_mapping = subword_map_by_joiner(example["src"])
7068
else: #Spacer
71-
# TODO: case markup
7269
word_to_subword_mapping = subword_map_by_spacer(example["src"])
7370

7471
inferred_feats = defaultdict(list)

onmt/translate/translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def translate(
346346
Args:
347347
src: See :func:`self.src_reader.read()`.
348348
tgt: See :func:`self.tgt_reader.read()`.
349+
src_feats: See :func:`self.src_reader.read()`.
349350
batch_size (int): size of examples per mini-batch
350351
attn_debug (bool): enables the attention logging
351352
align_debug (bool): enables the word alignment logging

onmt/utils/alignment.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,29 @@ def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=Sub
134134
return word_group
135135

136136

137-
def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER):
137+
def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER, case_markup=SubwordMarker.CASE_MARKUP):
138138
"""Return word id for each subword token (annotate by spacer)."""
139-
word_group = list(accumulate([int(marker in x) for x in subwords]))
139+
flags = [0] * len(subwords)
140+
for i, tok in enumerate(subwords):
141+
if marker in tok:
142+
if tok.replace(marker, "") in case_markup:
143+
if i < len(subwords)-1:
144+
flags[i] = 1
145+
else:
146+
if i > 0:
147+
previous = subwords[i-1].replace(marker, "")
148+
if previous not in case_markup:
149+
flags[i] = 1
150+
151+
# In case there is a final case_markup when new_spacer is on
152+
for i in range(1,len(subwords)-1):
153+
if subwords[-i] in case_markup:
154+
flags[-i] = 0
155+
elif subwords[-i] == marker:
156+
flags[-i] = 0
157+
break
158+
159+
word_group = list(accumulate(flags))
140160
if word_group[0] == 1: # when dummy prefix is set
141161
word_group = [item - 1 for item in word_group]
142162
return word_group

0 commit comments

Comments
 (0)