|
19 | 19 | import logging |
20 | 20 | import math |
21 | 21 | from collections import OrderedDict |
22 | | -from typing import Dict, Generator, Iterable, Iterator, List, Tuple, Union |
| 22 | +from typing import Dict, Iterable, Iterator, List, Tuple, Union |
23 | 23 |
|
24 | 24 | import numpy |
25 | 25 |
|
|
31 | 31 | _LOGGER = logging.getLogger(__name__) |
32 | 32 |
|
33 | 33 |
|
34 | | -# A utility class to load data in batches for fixed number of iterations |
35 | | - |
36 | | - |
37 | | -class _BatchLoader: |
38 | | - __slots__ = [ |
39 | | - "_data", |
40 | | - "_batch_size", |
41 | | - "_was_wrapped_originally", |
42 | | - "_iterations", |
43 | | - "_batch_buffer", |
44 | | - "_batch_template", |
45 | | - "_batches_created", |
46 | | - ] |
47 | | - |
48 | | - def __init__( |
49 | | - self, |
50 | | - data: Iterable[Union[numpy.ndarray, List[numpy.ndarray]]], |
51 | | - batch_size: int, |
52 | | - iterations: int, |
53 | | - ): |
54 | | - self._data = data |
55 | | - self._was_wrapped_originally = type(self._data[0]) is list |
56 | | - if not self._was_wrapped_originally: |
57 | | - self._data = [self._data] |
58 | | - self._batch_size = batch_size |
59 | | - self._iterations = iterations |
60 | | - if batch_size <= 0 or iterations <= 0: |
61 | | - raise ValueError( |
62 | | - f"Both batch size and number of iterations should be positive, " |
63 | | - f"supplied values (batch_size, iterations):{(batch_size, iterations)}" |
64 | | - ) |
65 | | - |
66 | | - self._batch_buffer = [] |
67 | | - self._batch_template = self._init_batch_template() |
68 | | - self._batches_created = 0 |
69 | | - |
70 | | - def __iter__(self) -> Generator[List[numpy.ndarray], None, None]: |
71 | | - yield from self._multi_input_batch_generator() |
72 | | - |
73 | | - @property |
74 | | - def _buffer_is_full(self) -> bool: |
75 | | - return len(self._batch_buffer) == self._batch_size |
76 | | - |
77 | | - @property |
78 | | - def _all_batches_loaded(self) -> bool: |
79 | | - return self._batches_created >= self._iterations |
80 | | - |
81 | | - def _multi_input_batch_generator( |
82 | | - self, |
83 | | - ) -> Generator[List[numpy.ndarray], None, None]: |
84 | | -# A generator with each element of the form |
85 | | - # [[(batch_size, features_a), (batch_size, features_b), ...]] |
86 | | - while not self._all_batches_loaded: |
87 | | - yield from self._batch_generator(source=self._data) |
88 | | - |
89 | | - def _batch_generator(self, source) -> Generator[List[numpy.ndarray], None, None]: |
90 | | - # batches from source |
91 | | - for sample in source: |
92 | | - self._batch_buffer.append(sample) |
93 | | - if self._buffer_is_full: |
94 | | - _batch = self._make_batch() |
95 | | - yield _batch |
96 | | - self._batch_buffer = [] |
97 | | - self._batches_created += 1 |
98 | | - if self._all_batches_loaded: |
99 | | - break |
100 | | - |
101 | | - def _init_batch_template( |
102 | | - self, |
103 | | - ) -> Iterable[Union[List[numpy.ndarray], numpy.ndarray]]: |
104 | | - # A placeholder for batches |
105 | | - return [ |
106 | | - numpy.ascontiguousarray( |
107 | | - numpy.zeros((self._batch_size, *_input.shape), dtype=_input.dtype) |
108 | | - ) |
109 | | - for _input in self._data[0] |
110 | | - ] |
111 | | - |
112 | | - def _make_batch(self) -> Iterable[Union[numpy.ndarray, List[numpy.ndarray]]]: |
113 | | - # Copy contents of buffer to batch placeholder |
114 | | -# and return a list of numpy array(s) representing the batch |
115 | | - |
116 | | - batch = [ |
117 | | - numpy.stack([sample[idx] for sample in self._batch_buffer], out=template) |
118 | | - for idx, template in enumerate(self._batch_template) |
119 | | - ] |
120 | | - |
121 | | - if not self._was_wrapped_originally: |
122 | | - # unwrap outer list |
123 | | - batch = batch[0] |
124 | | - return batch |
125 | | - |
126 | | - |
127 | 34 | class Dataset(Iterable): |
128 | 35 | """ |
129 | 36 | A numpy dataset implementation |
@@ -168,22 +75,6 @@ def data(self) -> List[Union[numpy.ndarray, Dict[str, numpy.ndarray]]]: |
168 | 75 | """ |
169 | 76 | return self._data |
170 | 77 |
|
171 | | - def iter_batches( |
172 | | - self, batch_size: int, iterations: int |
173 | | - ) -> Generator[List[numpy.ndarray], None, None]: |
174 | | - """ |
175 | | - A function to iterate over data in batches |
176 | | -
|
177 | | - :param batch_size: positive integer representing the size of each batch |
178 | | - :param iterations: positive integer representing |
179 | | - the number of batches to return |
180 | | - :returns: A generator for batches, each batch is enclosed in a list |
181 | | - Each batch is of the form [(batch_size, *feature_shape)] |
182 | | - """ |
183 | | - return _BatchLoader( |
184 | | - data=self.data, batch_size=batch_size, iterations=iterations |
185 | | - ) |
186 | | - |
187 | 78 |
|
188 | 79 | class RandomDataset(Dataset): |
189 | 80 | """ |
@@ -257,21 +148,21 @@ def __init__( |
257 | 148 | iter_steps: int = 0, |
258 | 149 | batch_as_list: bool = False, |
259 | 150 | ): |
260 | | - self._datasets = OrderedDict([(dataset.name, dataset) for dataset in datasets]) |
261 | | - self._batch_size = batch_size |
262 | | - self._iter_steps = iter_steps |
263 | | - self._batch_as_list = batch_as_list |
264 | | - self._num_items = -1 |
265 | | - |
266 | 151 | if len(datasets) < 1: |
267 | 152 | raise ValueError("len(datasets) must be > 0") |
268 | 153 |
|
269 | | - if self._batch_size < 1: |
| 154 | + if batch_size < 1: |
270 | 155 | raise ValueError("batch_size must be > 0") |
271 | 156 |
|
272 | | - if self._iter_steps < -1: |
| 157 | + if iter_steps < -1: |
273 | 158 | raise ValueError("iter_steps must be >= -1") |
274 | 159 |
|
| 160 | + self._datasets = OrderedDict([(dataset.name, dataset) for dataset in datasets]) |
| 161 | + self._batch_size = batch_size |
| 162 | + self._iter_steps = iter_steps |
| 163 | + self._batch_as_list = batch_as_list |
| 164 | + self._num_items = -1 |
| 165 | + |
275 | 166 | for dataset in datasets: |
276 | 167 | num_dataset_items = len(dataset) |
277 | 168 |
|
|
0 commit comments