Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 084f39e

Browse files
authored
Refactor: Remove _BatchLoader, add tests for DataLoader (#103)
* Refactor: Remove _BatchLoader, add tests for DataLoader * Add: descriptions and return types for test fixtures
1 parent 5ca44f8 commit 084f39e

File tree

2 files changed

+121
-302
lines changed

2 files changed

+121
-302
lines changed

src/sparsezoo/utils/data.py

Lines changed: 9 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import math
2121
from collections import OrderedDict
22-
from typing import Dict, Generator, Iterable, Iterator, List, Tuple, Union
22+
from typing import Dict, Iterable, Iterator, List, Tuple, Union
2323

2424
import numpy
2525

@@ -31,99 +31,6 @@
3131
_LOGGER = logging.getLogger(__name__)
3232

3333

34-
# A utility class to load data in batches for a fixed number of iterations
35-
36-
37-
class _BatchLoader:
38-
__slots__ = [
39-
"_data",
40-
"_batch_size",
41-
"_was_wrapped_originally",
42-
"_iterations",
43-
"_batch_buffer",
44-
"_batch_template",
45-
"_batches_created",
46-
]
47-
48-
def __init__(
49-
self,
50-
data: Iterable[Union[numpy.ndarray, List[numpy.ndarray]]],
51-
batch_size: int,
52-
iterations: int,
53-
):
54-
self._data = data
55-
self._was_wrapped_originally = type(self._data[0]) is list
56-
if not self._was_wrapped_originally:
57-
self._data = [self._data]
58-
self._batch_size = batch_size
59-
self._iterations = iterations
60-
if batch_size <= 0 or iterations <= 0:
61-
raise ValueError(
62-
f"Both batch size and number of iterations should be positive, "
63-
f"supplied values (batch_size, iterations):{(batch_size, iterations)}"
64-
)
65-
66-
self._batch_buffer = []
67-
self._batch_template = self._init_batch_template()
68-
self._batches_created = 0
69-
70-
def __iter__(self) -> Generator[List[numpy.ndarray], None, None]:
71-
yield from self._multi_input_batch_generator()
72-
73-
@property
74-
def _buffer_is_full(self) -> bool:
75-
return len(self._batch_buffer) == self._batch_size
76-
77-
@property
78-
def _all_batches_loaded(self) -> bool:
79-
return self._batches_created >= self._iterations
80-
81-
def _multi_input_batch_generator(
82-
self,
83-
) -> Generator[List[numpy.ndarray], None, None]:
84-
# A generator with each element of the form
85-
# [[(batch_size, features_a), (batch_size, features_b), ...]]
86-
while not self._all_batches_loaded:
87-
yield from self._batch_generator(source=self._data)
88-
89-
def _batch_generator(self, source) -> Generator[List[numpy.ndarray], None, None]:
90-
# batches from source
91-
for sample in source:
92-
self._batch_buffer.append(sample)
93-
if self._buffer_is_full:
94-
_batch = self._make_batch()
95-
yield _batch
96-
self._batch_buffer = []
97-
self._batches_created += 1
98-
if self._all_batches_loaded:
99-
break
100-
101-
def _init_batch_template(
102-
self,
103-
) -> Iterable[Union[List[numpy.ndarray], numpy.ndarray]]:
104-
# A placeholder for batches
105-
return [
106-
numpy.ascontiguousarray(
107-
numpy.zeros((self._batch_size, *_input.shape), dtype=_input.dtype)
108-
)
109-
for _input in self._data[0]
110-
]
111-
112-
def _make_batch(self) -> Iterable[Union[numpy.ndarray, List[numpy.ndarray]]]:
113-
# Copy contents of buffer to batch placeholder
114-
# and return A list of numpy array(s) representing the batch
115-
116-
batch = [
117-
numpy.stack([sample[idx] for sample in self._batch_buffer], out=template)
118-
for idx, template in enumerate(self._batch_template)
119-
]
120-
121-
if not self._was_wrapped_originally:
122-
# unwrap outer list
123-
batch = batch[0]
124-
return batch
125-
126-
12734
class Dataset(Iterable):
12835
"""
12936
A numpy dataset implementation
@@ -168,22 +75,6 @@ def data(self) -> List[Union[numpy.ndarray, Dict[str, numpy.ndarray]]]:
16875
"""
16976
return self._data
17077

171-
def iter_batches(
172-
self, batch_size: int, iterations: int
173-
) -> Generator[List[numpy.ndarray], None, None]:
174-
"""
175-
A function to iterate over data in batches
176-
177-
:param batch_size: positive integer representing the size of each batch
178-
:param iterations: positive integer representing
179-
the number of batches to return
180-
:returns: A generator for batches, each batch is enclosed in a list
181-
Each batch is of the form [(batch_size, *feature_shape)]
182-
"""
183-
return _BatchLoader(
184-
data=self.data, batch_size=batch_size, iterations=iterations
185-
)
186-
18778

18879
class RandomDataset(Dataset):
18980
"""
@@ -257,21 +148,21 @@ def __init__(
257148
iter_steps: int = 0,
258149
batch_as_list: bool = False,
259150
):
260-
self._datasets = OrderedDict([(dataset.name, dataset) for dataset in datasets])
261-
self._batch_size = batch_size
262-
self._iter_steps = iter_steps
263-
self._batch_as_list = batch_as_list
264-
self._num_items = -1
265-
266151
if len(datasets) < 1:
267152
raise ValueError("len(datasets) must be > 0")
268153

269-
if self._batch_size < 1:
154+
if batch_size < 1:
270155
raise ValueError("batch_size must be > 0")
271156

272-
if self._iter_steps < -1:
157+
if iter_steps < -1:
273158
raise ValueError("iter_steps must be >= -1")
274159

160+
self._datasets = OrderedDict([(dataset.name, dataset) for dataset in datasets])
161+
self._batch_size = batch_size
162+
self._iter_steps = iter_steps
163+
self._batch_as_list = batch_as_list
164+
self._num_items = -1
165+
275166
for dataset in datasets:
276167
num_dataset_items = len(dataset)
277168

0 commit comments

Comments
 (0)