|
19 | 19 | import logging |
20 | 20 | import math |
21 | 21 | from collections import OrderedDict |
22 | | -from typing import Dict, Iterable, Iterator, List, Tuple, Union |
| 22 | +from typing import Dict, Generator, Iterable, Iterator, List, Tuple, Union |
23 | 23 |
|
24 | 24 | import numpy |
25 | 25 |
|
|
28 | 28 |
|
29 | 29 | __all__ = ["Dataset", "RandomDataset", "DataLoader"] |
30 | 30 |
|
31 | | - |
32 | 31 | _LOGGER = logging.getLogger(__name__) |
33 | 32 |
|
34 | 33 |
|
| 34 | +# A utility class that loads data in batches for a fixed number of iterations |
| 35 | + |
| 36 | + |
| 37 | +class _BatchLoader: |
| 38 | + __slots__ = [ |
| 39 | + "_data", |
| 40 | + "_batch_size", |
| 41 | + "_was_wrapped_originally", |
| 42 | + "_iterations", |
| 43 | + "_batch_buffer", |
| 44 | + "_batch_template", |
| 45 | + "_batches_created", |
| 46 | + ] |
| 47 | + |
| 48 | + def __init__( |
| 49 | + self, |
| 50 | + data: Iterable[Union[numpy.ndarray, List[numpy.ndarray]]], |
| 51 | + batch_size: int, |
| 52 | + iterations: int, |
| 53 | + ): |
| 54 | + self._data = data |
| 55 | + self._was_wrapped_originally = type(self._data[0]) is list |
| 56 | + if not self._was_wrapped_originally: |
| 57 | + self._data = [self._data] |
| 58 | + self._batch_size = batch_size |
| 59 | + self._iterations = iterations |
| 60 | + if batch_size <= 0 or iterations <= 0: |
| 61 | + raise ValueError( |
| 62 | + f"Both batch size and number of iterations should be positive, " |
| 63 | + f"supplied values (batch_size, iterations):{(batch_size, iterations)}" |
| 64 | + ) |
| 65 | + |
| 66 | + self._batch_buffer = [] |
| 67 | + self._batch_template = self._init_batch_template() |
| 68 | + self._batches_created = 0 |
| 69 | + |
| 70 | + def __iter__(self) -> Generator[List[numpy.ndarray], None, None]: |
| 71 | + yield from self._multi_input_batch_generator() |
| 72 | + |
| 73 | + @property |
| 74 | + def _buffer_is_full(self) -> bool: |
| 75 | + return len(self._batch_buffer) == self._batch_size |
| 76 | + |
| 77 | + @property |
| 78 | + def _all_batches_loaded(self) -> bool: |
| 79 | + return self._batches_created >= self._iterations |
| 80 | + |
| 81 | + def _multi_input_batch_generator( |
| 82 | + self, |
| 83 | + ) -> Generator[List[numpy.ndarray], None, None]: |
| 84 | + # A generator for with each element of the form |
| 85 | + # [[(batch_size, features_a), (batch_size, features_b), ...]] |
| 86 | + while not self._all_batches_loaded: |
| 87 | + yield from self._batch_generator(source=self._data) |
| 88 | + |
| 89 | + def _batch_generator(self, source) -> Generator[List[numpy.ndarray], None, None]: |
| 90 | + # batches from source |
| 91 | + for sample in source: |
| 92 | + self._batch_buffer.append(sample) |
| 93 | + if self._buffer_is_full: |
| 94 | + _batch = self._make_batch() |
| 95 | + yield _batch |
| 96 | + self._batch_buffer = [] |
| 97 | + self._batches_created += 1 |
| 98 | + if self._all_batches_loaded: |
| 99 | + break |
| 100 | + |
| 101 | + def _init_batch_template( |
| 102 | + self, |
| 103 | + ) -> Iterable[Union[List[numpy.ndarray], numpy.ndarray]]: |
| 104 | + # A placeholder for batches |
| 105 | + return [ |
| 106 | + numpy.ascontiguousarray( |
| 107 | + numpy.zeros((self._batch_size, *_input.shape), dtype=_input.dtype) |
| 108 | + ) |
| 109 | + for _input in self._data[0] |
| 110 | + ] |
| 111 | + |
| 112 | + def _make_batch(self) -> Iterable[Union[numpy.ndarray, List[numpy.ndarray]]]: |
| 113 | + # Copy contents of buffer to batch placeholder |
| 114 | + # and return A list of numpy array(s) representing the batch |
| 115 | + |
| 116 | + batch = [ |
| 117 | + numpy.stack([sample[idx] for sample in self._batch_buffer], out=template) |
| 118 | + for idx, template in enumerate(self._batch_template) |
| 119 | + ] |
| 120 | + |
| 121 | + if not self._was_wrapped_originally: |
| 122 | + # unwrap outer list |
| 123 | + batch = batch[0] |
| 124 | + return batch |
| 125 | + |
| 126 | + |
35 | 127 | class Dataset(Iterable): |
36 | 128 | """ |
37 | 129 | A numpy dataset implementation |
@@ -76,6 +168,22 @@ def data(self) -> List[Union[numpy.ndarray, Dict[str, numpy.ndarray]]]: |
76 | 168 | """ |
77 | 169 | return self._data |
78 | 170 |
|
| 171 | + def iter_batches( |
| 172 | + self, batch_size: int, iterations: int |
| 173 | + ) -> Generator[List[numpy.ndarray], None, None]: |
| 174 | + """ |
| 175 | + A function to iterate over data in batches |
| 176 | +
|
| 177 | + :param batch_size: non-negative integer representing the size of each |
| 178 | + :param iterations: non-negative integer representing |
| 179 | + the number of batches to return |
| 180 | + :returns: A generator for batches, each batch is enclosed in a list |
| 181 | + Each batch is of the form [(batch_size, *feature_shape)] |
| 182 | + """ |
| 183 | + return _BatchLoader( |
| 184 | + data=self.data, batch_size=batch_size, iterations=iterations |
| 185 | + ) |
| 186 | + |
79 | 187 |
|
80 | 188 | class RandomDataset(Dataset): |
81 | 189 | """ |
|
0 commit comments