# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import OrderedDict
from typing import Dict, Optional, Sequence

import numpy as np

from . import FairseqDataset, LanguagePairDataset

logger = logging.getLogger(__name__)


class RoundRobinZipDatasets(FairseqDataset):
"""Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.

    Shorter datasets are repeated in a round-robin fashion to match the length
    of the longest one.

    Args:
        datasets (Dict[str, ~fairseq.data.FairseqDataset]): a dictionary
            mapping keys to :class:`~fairseq.data.FairseqDataset` instances.
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass through batches from *datasets[eval_key]*.
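
    A minimal usage sketch (the dictionary keys and the ``en_de_dataset`` /
    ``en_fr_dataset`` objects below are hypothetical placeholders)::

        zipped = RoundRobinZipDatasets(
            OrderedDict([("en-de", en_de_dataset), ("en-fr", en_fr_dataset)])
        )
        zipped.ordered_indices()  # must be called before indexing
        sample = zipped[0]        # OrderedDict with one item per dataset key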
"""

    def __init__(self, datasets, eval_key=None):
super().__init__()
if isinstance(datasets, dict):
datasets = OrderedDict(datasets)
assert isinstance(datasets, OrderedDict)
assert datasets, "Can't make a RoundRobinZipDatasets out of nothing"
for dataset in datasets.values():
assert isinstance(dataset, FairseqDataset)
self.datasets = datasets
self.eval_key = eval_key
self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k]))
self.longest_dataset = datasets[self.longest_dataset_key]
        self._ordered_indices: Optional[Dict[str, Sequence[int]]] = None
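        # Filled in lazily by ordered_indices(); _map_index() asserts it has
        # been populated before any sample is fetched.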

    def _map_index(self, key, index):
assert (
self._ordered_indices is not None
), "Must call RoundRobinZipDatasets.ordered_indices() first"
o = self._ordered_indices[key]
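        # Wrap around with modulo so shorter datasets repeat in round-robin
        # order, e.g. (hypothetical sizes) len(o) == 3 and index == 7 -> o[1].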
return o[index % len(o)]

    def __getitem__(self, index):
if self.eval_key is None:
return OrderedDict(
[
(key, dataset[self._map_index(key, index)])
for key, dataset in self.datasets.items()
]
)
else:
            # at evaluation time it's useful to pass through batches from a single key
return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]

    def __len__(self):
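        # Length follows the longest dataset (its filtered length once
        # ordered_indices()/filter_indices_by_size() have run); shorter
        # datasets simply wrap around via _map_index().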
if self._ordered_indices is not None:
return len(self._ordered_indices[self.longest_dataset_key])
return len(self.longest_dataset)

    def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
if len(samples) == 0:
return None
if self.eval_key is None:
return OrderedDict(
[
(key, dataset.collater([sample[key] for sample in samples]))
for key, dataset in self.datasets.items()
]
)
else:
            # at evaluation time it's useful to pass through batches from a single key
return self.datasets[self.eval_key].collater(samples)

    def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
# TODO make it configurable whether to use max() or sum() here
return max(
dataset.num_tokens(self._map_index(key, index))
for key, dataset in self.datasets.items()
)

    def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return {
key: dataset.size(self._map_index(key, index))
for key, dataset in self.datasets.items()
}

    def ordered_indices(self):
"""Ordered indices for batching."""
if self._ordered_indices is None:
# Call the underlying dataset's ordered_indices() here, so that we
# get the same random ordering as we would have from using the
# underlying sub-datasets directly.
self._ordered_indices = OrderedDict(
[
(key, dataset.ordered_indices())
for key, dataset in self.datasets.items()
]
)
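        # The indices returned here index into the zipped dataset itself; they
        # are remapped to each sub-dataset's own ordering in _map_index().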
return np.arange(len(self))

    def filter_indices_by_size(self, indices, max_positions=None):
"""
Filter each sub-dataset independently, then update the round robin to work
on the filtered sub-datasets.
"""

        def _deep_until_language_pair(dataset):
if isinstance(dataset, LanguagePairDataset):
return dataset
if hasattr(dataset, "tgt_dataset"):
return _deep_until_language_pair(dataset.tgt_dataset)
if hasattr(dataset, "dataset"):
return _deep_until_language_pair(dataset.dataset)
raise Exception(f"Don't know how to unwrap this dataset: {dataset}")
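
        # For example, a wrapper that exposes the wrapped dataset through a
        # ``dataset`` (or ``tgt_dataset``) attribute is peeled one level at a
        # time until a LanguagePairDataset is reached; anything else raises.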
if not isinstance(max_positions, dict):
max_positions = {k: max_positions for k in self.datasets.keys()}
ignored_some = False
for key, dataset in self.datasets.items():
dataset = _deep_until_language_pair(dataset)
self._ordered_indices[key], ignored = dataset.filter_indices_by_size(
self._ordered_indices[key], max_positions[key]
)
if len(ignored) > 0:
ignored_some = True
logger.warning(
f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
)
        # Since _ordered_indices is modified in place, it is no longer possible
        # to return the ignored indices themselves; the warning above should be
        # enough to debug size filtering. Ideally we would receive
        # ignore_invalid_inputs here so that a proper error could be raised.
return (np.arange(len(self)), [0] if ignored_some else [])

    @property
def supports_prefetch(self):
return all(
getattr(dataset, "supports_prefetch", False)
for dataset in self.datasets.values()
)

    def prefetch(self, indices):
for key, dataset in self.datasets.items():
dataset.prefetch([self._map_index(key, index) for index in indices])