
Commit d5d3c74

customizable transform statistics (#2059)
1 parent 6078fd3 commit d5d3c74

File tree: 11 files changed, +217 −97 lines


onmt/bin/train.py

Lines changed: 1 addition & 1 deletion

@@ -139,7 +139,7 @@ def train(opt):
                 opt, fields, transforms_cls, stride=nb_gpu, offset=device_id)
             producer = mp.Process(target=batch_producer,
                                   args=(train_iter, queues[device_id],
-                                        semaphore, opt,),
+                                        semaphore, opt, device_id),
                                   daemon=True)
             producers.append(producer)
             producers[device_id].start()
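
Passing device_id into batch_producer lets each producer process know which GPU shard it serves, which pairs with the new shared -verbose option (see onmt/opts.py below) to restrict data-loading logs to the first shard by default. The body of batch_producer is not part of this diff, so the gating below is a minimal sketch of the assumed behaviour, not the actual implementation:

import logging

def batch_producer(generator_to_serve, queue, semaphore, opt, device_id):
    """Sketch of the producer loop; only the logging gate is the point here."""
    # assumed: silence INFO logs on all but the first shard unless -verbose
    # asks every process to report its loading statistics
    if not getattr(opt, "verbose", False) and device_id != 0:
        logging.getLogger().setLevel(logging.WARNING)
    for batch in generator_to_serve:
        semaphore.acquire()  # wait until the consumer frees a queue slot
        queue.put(batch)     # hand the batch over to the GPU process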

onmt/inputters/corpus.py

Lines changed: 15 additions & 22 deletions

@@ -123,7 +123,6 @@ def load(self, offset=0, stride=1):
         with exfile_open(self.src, mode='rb') as fs,\
                 exfile_open(self.tgt, mode='rb') as ft,\
                 exfile_open(self.align, mode='rb') as fa:
-            logger.info(f"Loading {repr(self)}...")
             for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
                 if (i % stride) == offset:
                     sline = sline.decode('utf-8')
@@ -136,7 +135,7 @@ def load(self, offset=0, stride=1):
                     example['align'] = align.decode('utf-8')
                 yield example

-    def __repr__(self):
+    def __str__(self):
         cls_name = type(self).__name__
         return '{}({}, {}, align={})'.format(
             cls_name, self.src, self.tgt, self.align)
@@ -169,19 +168,17 @@ class ParallelCorpusIterator(object):

     Args:
         corpus (ParallelCorpus): corpus to iterate;
-        transform (Transform): transforms to be applied to corpus;
-        infinitely (bool): True to iterate endlessly;
+        transform (TransformPipe): transforms to be applied to corpus;
         skip_empty_level (str): security level when encountering empty lines;
         stride (int): iterate corpus with this line stride;
         offset (int): iterate corpus with this line offset.
     """

-    def __init__(self, corpus, transform, infinitely=False,
+    def __init__(self, corpus, transform,
                  skip_empty_level='warning', stride=1, offset=0):
         self.cid = corpus.id
         self.corpus = corpus
         self.transform = transform
-        self.infinitely = infinitely
         if skip_empty_level not in ['silent', 'warning', 'error']:
             raise ValueError(
                 f"Invalid argument skip_empty_level={skip_empty_level}")
@@ -208,8 +205,11 @@ def _transform(self, stream):
                 yield item
         report_msg = self.transform.stats()
         if report_msg != '':
-            logger.info("Transform statistics for {}:\n{}".format(
-                self.cid, report_msg))
+            logger.info(
+                "* Transform statistics for {}({:.2f}%):\n{}\n".format(
+                    self.cid, 100/self.stride, report_msg
+                )
+            )

     def _add_index(self, stream):
         for i, item in enumerate(stream):
@@ -227,24 +227,17 @@ def _add_index(self, stream):
                 continue
             yield item

-    def _iter_corpus(self):
+    def __iter__(self):
         corpus_stream = self.corpus.load(
-            stride=self.stride, offset=self.offset)
+            stride=self.stride, offset=self.offset
+        )
         tokenized_corpus = self._tokenize(corpus_stream)
         transformed_corpus = self._transform(tokenized_corpus)
         indexed_corpus = self._add_index(transformed_corpus)
         yield from indexed_corpus

-    def __iter__(self):
-        if self.infinitely:
-            while True:
-                _iter = self._iter_corpus()
-                yield from _iter
-        else:
-            yield from self._iter_corpus()
-

-def build_corpora_iters(corpora, transforms, corpora_info, is_train=False,
+def build_corpora_iters(corpora, transforms, corpora_info,
                         skip_empty_level='warning', stride=1, offset=0):
     """Return `ParallelCorpusIterator` for all corpora defined in opts."""
     corpora_iters = dict()
@@ -256,7 +249,7 @@ def build_corpora_iters(corpora, transforms, corpora_info,
         transform_pipe = TransformPipe.build_from(corpus_transform)
         logger.info(f"{c_id}'s transforms: {str(transform_pipe)}")
         corpus_iter = ParallelCorpusIterator(
-            corpus, transform_pipe, infinitely=is_train,
+            corpus, transform_pipe,
             skip_empty_level=skip_empty_level, stride=stride, offset=offset)
         corpora_iters[c_id] = corpus_iter
     return corpora_iters
@@ -294,7 +287,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
    sub_counter_src = Counter()
    sub_counter_tgt = Counter()
    datasets_iterables = build_corpora_iters(
-        corpora, transforms, opts.data, is_train=False,
+        corpora, transforms, opts.data,
        skip_empty_level=opts.skip_empty_level,
        stride=stride, offset=offset)
    for c_name, c_iter in datasets_iterables.items():
@@ -380,7 +373,7 @@ def save_transformed_sample(opts, transforms, n_sample=3):

    corpora = get_corpora(opts, is_train=True)
    datasets_iterables = build_corpora_iters(
-        corpora, transforms, opts.data, is_train=False,
+        corpora, transforms, opts.data,
        skip_empty_level=opts.skip_empty_level)
    sample_path = os.path.join(
        os.path.dirname(opts.save_data), CorpusName.SAMPLE)
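
The new report header states what share of the corpus each worker actually saw: with line striding, a worker at a given offset reads one line in every stride, i.e. 100/stride percent of the corpus. A two-line check of the arithmetic (the corpus name is illustrative):

stride = 2  # e.g. two GPU workers, each iterating every other line
print("* Transform statistics for {}({:.2f}%):".format("corpus_1", 100 / stride))
# -> * Transform statistics for corpus_1(50.00%):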

onmt/inputters/dynamic_iterator.py

Lines changed: 15 additions & 6 deletions

@@ -6,6 +6,7 @@
 from onmt.inputters.corpus import get_corpora, build_corpora_iters,\
     DatasetAdapter
 from onmt.transforms import make_transforms
+from onmt.utils.logging import logger


 class MixingStrategy(object):
@@ -47,13 +48,22 @@ class WeightedMixer(MixingStrategy):

     def __init__(self, iterables, weights):
         super().__init__(iterables, weights)
-        self._iterators = {
-            ds_name: iter(generator)
-            for ds_name, generator in self.iterables.items()
-        }
+        self._iterators = {}
+        self._counts = {}
+        for ds_name in self.iterables.keys():
+            self._reset_iter(ds_name)
+
+    def _logging(self):
+        """Report corpora loading statistics."""
+        msgs = []
+        for ds_name, ds_count in self._counts.items():
+            msgs.append(f"\t\t\t* {ds_name}: {ds_count}")
+        logger.info("Weighted corpora loaded so far:\n"+"\n".join(msgs))

     def _reset_iter(self, ds_name):
         self._iterators[ds_name] = iter(self.iterables[ds_name])
+        self._counts[ds_name] = self._counts.get(ds_name, 0) + 1
+        self._logging()

     def _iter_datasets(self):
         for ds_name, ds_weight in self.weights.items():
@@ -144,8 +154,7 @@ def from_opts(cls, corpora, transforms, fields, opts, is_train,

     def _init_datasets(self):
         datasets_iterables = build_corpora_iters(
-            self.corpora, self.transforms,
-            self.corpora_info, self.is_train,
+            self.corpora, self.transforms, self.corpora_info,
             skip_empty_level=self.skip_empty_level,
             stride=self.stride, offset=self.offset)
         self.dataset_adapter = DatasetAdapter(self.fields, self.is_train)
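
_reset_iter now doubles as the bookkeeping hook: it runs once per dataset at construction and again whenever an iterator is exhausted, so _counts[ds_name] equals the number of times that corpus has been (re)loaded. A self-contained sketch of that behaviour; the mixing loop is simplified here and is not the actual WeightedMixer:

class CountingMixerSketch:
    """Toy stand-in for WeightedMixer's reload counting (illustration only)."""

    def __init__(self, iterables):
        self.iterables = iterables
        self._iterators = {}
        self._counts = {}
        for ds_name in iterables:
            self._reset_iter(ds_name)  # the first load counts as 1

    def _reset_iter(self, ds_name):
        self._iterators[ds_name] = iter(self.iterables[ds_name])
        self._counts[ds_name] = self._counts.get(ds_name, 0) + 1
        print("Weighted corpora loaded so far:", dict(self._counts))

    def next_from(self, ds_name):
        try:
            return next(self._iterators[ds_name])
        except StopIteration:
            self._reset_iter(ds_name)  # corpus exhausted: reload and count
            return next(self._iterators[ds_name])

mixer = CountingMixerSketch({"big": [1, 2, 3], "small": [9]})
assert [mixer.next_from("small") for _ in range(3)] == [9, 9, 9]
assert mixer._counts["small"] == 3  # reloaded twice after the initial load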

onmt/opts.py

Lines changed: 5 additions & 3 deletions

@@ -26,6 +26,10 @@ def _add_logging_opts(parser, is_train=True):
               action=StoreLoggingLevelAction,
               choices=StoreLoggingLevelAction.CHOICES,
               default="0")
+    group.add('--verbose', '-verbose', action="store_true",
+              help='Print data loading and statistics for all processes '
+              '(default: only log the first process shard)' if is_train
+              else 'Print scores and predictions for each sentence')

     if is_train:
         group.add('--report_every', '-report_every', type=int, default=50,
@@ -44,8 +48,6 @@ def _add_logging_opts(parser, is_train=True):
                   "This is also the name of the run.")
     else:
         # Options only during inference
-        group.add('--verbose', '-verbose', action="store_true",
-                  help='Print scores and predictions for each sentence')
         group.add('--attn_debug', '-attn_debug', action="store_true",
                   help='Print best attn for each word')
         group.add('--align_debug', '-align_debug', action="store_true",
@@ -75,7 +77,7 @@ def _add_dynamic_corpus_opts(parser, build_vocab_only=False):
               help="Security level when encountering empty examples."
                    "silent: silently ignore/skip empty example;"
                    "warning: warning when ignore/skip empty example;"
-                   "error: raise error & stop excution when encouter empty.)")
+                   "error: raise error & stop execution when encountering empty.")
     group.add("-transforms", "--transforms", default=[], nargs="+",
               choices=AVAILABLE_TRANSFORMS.keys(),
               help="Default transform pipeline to apply to data. "
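
-verbose is now registered in the shared _add_logging_opts rather than under inference-only options, so its meaning depends on the tool: during training it asks every process shard to print data loading and transform statistics (by default only the first shard logs them), while for translation it keeps its previous behaviour of printing scores and predictions per sentence. For example, to get statistics from every shard of a multi-GPU run, something like `onmt_train -config config.yaml -verbose` should do (config path illustrative).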

onmt/tests/test_transform.py

Lines changed: 47 additions & 1 deletion

@@ -5,7 +5,12 @@
 import yaml
 import math
 from argparse import Namespace
-from onmt.transforms import get_transforms_cls, get_specials, make_transforms
+from onmt.transforms import (
+    get_transforms_cls,
+    get_specials,
+    make_transforms,
+    TransformPipe,
+)
 from onmt.transforms.bart import BARTNoising


@@ -51,6 +56,47 @@ def test_transform_specials(self):
         self.assertEqual(specials, specials_expected)


+    def test_transform_pipe(self):
+        # 1. Init first transform in the pipe
+        prefix_cls = get_transforms_cls(["prefix"])["prefix"]
+        corpora = yaml.safe_load("""
+            trainset:
+                path_src: data/src-train.txt
+                path_tgt: data/tgt-train.txt
+                transforms: [prefix, filtertoolong]
+                weight: 1
+                src_prefix: "⦅_pf_src⦆"
+                tgt_prefix: "⦅_pf_tgt⦆"
+        """)
+        opt = Namespace(data=corpora, seed=-1)
+        prefix_transform = prefix_cls(opt)
+        prefix_transform.warm_up()
+        # 2. Init second transform in the pipe
+        filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"]
+        opt = Namespace(src_seq_length=4, tgt_seq_length=4)
+        filter_transform = filter_cls(opt)
+        # 3. Sequentially combine them into a transform pipe
+        transform_pipe = TransformPipe.build_from(
+            [prefix_transform, filter_transform]
+        )
+        ex = {
+            "src": ["Hello", ",", "world", "."],
+            "tgt": ["Bonjour", "le", "monde", "."],
+        }
+        # 4. Apply the transform pipe to the example
+        ex_after = transform_pipe.apply(
+            copy.deepcopy(ex), corpus_name="trainset"
+        )
+        # 5. The example exceeds the length limit after the pipe, thus is filtered
+        self.assertIsNone(ex_after)
+        # 6. Transform statistics registered (here for filtertoolong)
+        self.assertTrue(len(transform_pipe.statistics.observables) > 0)
+        msg = transform_pipe.statistics.report()
+        self.assertIsNotNone(msg)
+        # 7. After reporting, statistics become empty for a fresh start
+        self.assertTrue(len(transform_pipe.statistics.observables) == 0)
+
+
 class TestMiscTransform(unittest.TestCase):
     def test_prefix(self):
         prefix_cls = get_transforms_cls(["prefix"])["prefix"]
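
The test relies on transform_pipe.statistics exposing an observables container and a report() that empties it. That container lives in onmt/transforms/transform.py and is not shown in this diff, so the sketch below is only the interface the test implies, not the real class:

class TransformStatisticsSketch:
    """Assumed shape of the pipe's statistics container (illustration only)."""

    def __init__(self):
        self.observables = {}

    def update(self, observable):
        # merge with an already-registered observable of the same type,
        # otherwise register it as a new entry
        name = type(observable).__name__
        if name in self.observables:
            self.observables[name].update(observable)
        else:
            self.observables[name] = observable

    def report(self):
        # render the accumulated statistics, then reset for a fresh start
        msg = "\n".join(str(obs) for obs in self.observables.values())
        self.observables = {}
        return msg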

onmt/transforms/misc.py

Lines changed: 13 additions & 2 deletions

@@ -1,6 +1,17 @@
 from onmt.utils.logging import logger
 from onmt.transforms import register_transform
-from .transform import Transform
+from .transform import Transform, ObservableStats
+
+
+class FilterTooLongStats(ObservableStats):
+    """Running statistics for FilterTooLongTransform."""
+    __slots__ = ["filtered"]
+
+    def __init__(self):
+        self.filtered = 1
+
+    def update(self, other: "FilterTooLongStats"):
+        self.filtered += other.filtered


 @register_transform(name='filtertoolong')
@@ -28,7 +39,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
         if (len(example['src']) > self.src_seq_length or
                 len(example['tgt']) > self.tgt_seq_length):
             if stats is not None:
-                stats.filter_too_long()
+                stats.update(FilterTooLongStats())
             return None
         else:
             return example
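
FilterTooLongStats starts filtered at 1 because each instance records a single filtered example; merging via update then sums the counts. The ObservableStats base class it extends lives in onmt/transforms/transform.py and is not part of this diff; a hedged sketch of the interface the subclasses rely on:

class ObservableStats:
    """Assumed base interface: a named, mergeable, printable statistic."""

    __slots__ = []

    def name(self) -> str:
        # default: report under the subclass name
        return type(self).__name__

    def update(self, other: "ObservableStats"):
        raise NotImplementedError

    def __str__(self) -> str:
        # render the fields declared in the subclass's __slots__
        return "{}({})".format(
            self.name(),
            ", ".join(
                f"{field}={getattr(self, field)}" for field in self.__slots__
            ),
        )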

onmt/transforms/sampling.py

Lines changed: 46 additions & 4 deletions

@@ -3,7 +3,7 @@
 import numpy as np
 from onmt.constants import DefaultTokens
 from onmt.transforms import register_transform
-from .transform import Transform
+from .transform import Transform, ObservableStats


 class HammingDistanceSampling(object):
@@ -44,6 +44,20 @@ def _set_seed(self, seed):
         random.seed(seed)


+class SwitchOutStats(ObservableStats):
+    """Running statistics for counting tokens being switched out."""
+
+    __slots__ = ["changed", "total"]
+
+    def __init__(self, changed: int, total: int):
+        self.changed = changed
+        self.total = total
+
+    def update(self, other: "SwitchOutStats"):
+        self.changed += other.changed
+        self.total += other.total
+
+
 @register_transform(name='switchout')
 class SwitchOutTransform(HammingDistanceSamplingTransform):
     """
@@ -81,7 +95,7 @@ def _switchout(self, tokens, vocab, stats=None):
         for i in chosen_indices:
             tokens[i] = self._sample_replace(vocab, reject=tokens[i])
         if stats is not None:
-            stats.switchout(n_switchout=n_chosen, n_total=len(tokens))
+            stats.update(SwitchOutStats(n_chosen, len(tokens)))
         return tokens

     def apply(self, example, is_train=False, stats=None, **kwargs):
@@ -98,6 +112,20 @@ def _repr_args(self):
         return '{}={}'.format('switchout_temperature', self.temperature)


+class TokenDropStats(ObservableStats):
+    """Running statistics for counting tokens being dropped."""
+
+    __slots__ = ["dropped", "total"]
+
+    def __init__(self, dropped: int, total: int):
+        self.dropped = dropped
+        self.total = total
+
+    def update(self, other: "TokenDropStats"):
+        self.dropped += other.dropped
+        self.total += other.total
+
+
 @register_transform(name='tokendrop')
 class TokenDropTransform(HammingDistanceSamplingTransform):
     """Randomly drop tokens from a sentence."""
@@ -126,7 +154,7 @@ def _token_drop(self, tokens, stats=None):
         out = [tok for (i, tok) in enumerate(tokens)
                if i not in chosen_indices]
         if stats is not None:
-            stats.token_drop(n_dropped=n_chosen, n_total=n_items)
+            stats.update(TokenDropStats(n_chosen, n_items))
         return out

     def apply(self, example, is_train=False, stats=None, **kwargs):
@@ -141,6 +169,20 @@ def _repr_args(self):
         return '{}={}'.format('tokendrop_temperature', self.temperature)


+class TokenMaskStats(ObservableStats):
+    """Running statistics for counting tokens being masked."""
+
+    __slots__ = ["masked", "total"]
+
+    def __init__(self, masked: int, total: int):
+        self.masked = masked
+        self.total = total
+
+    def update(self, other: "TokenMaskStats"):
+        self.masked += other.masked
+        self.total += other.total
+
+
 @register_transform(name='tokenmask')
 class TokenMaskTransform(HammingDistanceSamplingTransform):
     """Randomly mask tokens from the src sentence."""
@@ -175,7 +217,7 @@ def _token_mask(self, tokens, stats=None):
         for i in chosen_indices:
             tokens[i] = self.MASK_TOK
         if stats is not None:
-            stats.token_mask(n_masked=n_chosen, n_total=len(tokens))
+            stats.update(TokenMaskStats(n_chosen, len(tokens)))
         return tokens

     def apply(self, example, is_train=False, stats=None, **kwargs):
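
All three observables follow the same merge pattern, so per-example counts aggregate into corpus-level totals inside the pipe's statistics container. A quick standalone check of that accumulation, using SwitchOutStats as defined above:

stats = SwitchOutStats(changed=2, total=10)       # first example: 2 of 10 tokens switched
stats.update(SwitchOutStats(changed=1, total=8))  # second example merged in
assert (stats.changed, stats.total) == (3, 18)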
