Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 5ca44f8

Browse files
rahul-tuli and bfineran authored
Add support for batched Iteration (#102)
* Add:Support for batched iteration * Refactor:Data to be list of numpy arrays * Add:test for tuple of numpy arrays * Add:batched iteration support * iter_batches function in Dataset class returns a BatchLoader object * BatchLoader class added * Moved utils.py * Renamed utils.py * Created test_data.py * Cleanup * Fix Typo * Update src/sparsezoo/utils/data.py Co-authored-by: Benjamin Fineran <[email protected]> * Fix:Single-Input cases Address:PR review comments * Update:Rename tests/utils.py to tests/helpers.py Fix:Unwrapping Single Input Errors * Update:Rename tests/utils.py to tests/helpers.py Fix:Unwrapping Single Input Errors Update:tests_data.py * Update:fixes from PR comments Co-authored-by: Benjamin Fineran <[email protected]>
1 parent 7538b0d commit 5ca44f8

File tree

13 files changed

+346
-11
lines changed

13 files changed

+346
-11
lines changed

src/sparsezoo/utils/data.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import math
2121
from collections import OrderedDict
22-
from typing import Dict, Iterable, Iterator, List, Tuple, Union
22+
from typing import Dict, Generator, Iterable, Iterator, List, Tuple, Union
2323

2424
import numpy
2525

@@ -28,10 +28,102 @@
2828

2929
__all__ = ["Dataset", "RandomDataset", "DataLoader"]
3030

31-
3231
_LOGGER = logging.getLogger(__name__)
3332

3433

34+
# A utility class to load data in batches for fixed number of iterations
35+
36+
37+
class _BatchLoader:
38+
__slots__ = [
39+
"_data",
40+
"_batch_size",
41+
"_was_wrapped_originally",
42+
"_iterations",
43+
"_batch_buffer",
44+
"_batch_template",
45+
"_batches_created",
46+
]
47+
48+
def __init__(
49+
self,
50+
data: Iterable[Union[numpy.ndarray, List[numpy.ndarray]]],
51+
batch_size: int,
52+
iterations: int,
53+
):
54+
self._data = data
55+
self._was_wrapped_originally = type(self._data[0]) is list
56+
if not self._was_wrapped_originally:
57+
self._data = [self._data]
58+
self._batch_size = batch_size
59+
self._iterations = iterations
60+
if batch_size <= 0 or iterations <= 0:
61+
raise ValueError(
62+
f"Both batch size and number of iterations should be positive, "
63+
f"supplied values (batch_size, iterations):{(batch_size, iterations)}"
64+
)
65+
66+
self._batch_buffer = []
67+
self._batch_template = self._init_batch_template()
68+
self._batches_created = 0
69+
70+
def __iter__(self) -> Generator[List[numpy.ndarray], None, None]:
71+
yield from self._multi_input_batch_generator()
72+
73+
@property
74+
def _buffer_is_full(self) -> bool:
75+
return len(self._batch_buffer) == self._batch_size
76+
77+
@property
78+
def _all_batches_loaded(self) -> bool:
79+
return self._batches_created >= self._iterations
80+
81+
def _multi_input_batch_generator(
82+
self,
83+
) -> Generator[List[numpy.ndarray], None, None]:
84+
# A generator for with each element of the form
85+
# [[(batch_size, features_a), (batch_size, features_b), ...]]
86+
while not self._all_batches_loaded:
87+
yield from self._batch_generator(source=self._data)
88+
89+
def _batch_generator(self, source) -> Generator[List[numpy.ndarray], None, None]:
90+
# batches from source
91+
for sample in source:
92+
self._batch_buffer.append(sample)
93+
if self._buffer_is_full:
94+
_batch = self._make_batch()
95+
yield _batch
96+
self._batch_buffer = []
97+
self._batches_created += 1
98+
if self._all_batches_loaded:
99+
break
100+
101+
def _init_batch_template(
102+
self,
103+
) -> Iterable[Union[List[numpy.ndarray], numpy.ndarray]]:
104+
# A placeholder for batches
105+
return [
106+
numpy.ascontiguousarray(
107+
numpy.zeros((self._batch_size, *_input.shape), dtype=_input.dtype)
108+
)
109+
for _input in self._data[0]
110+
]
111+
112+
def _make_batch(self) -> Iterable[Union[numpy.ndarray, List[numpy.ndarray]]]:
113+
# Copy contents of buffer to batch placeholder
114+
# and return A list of numpy array(s) representing the batch
115+
116+
batch = [
117+
numpy.stack([sample[idx] for sample in self._batch_buffer], out=template)
118+
for idx, template in enumerate(self._batch_template)
119+
]
120+
121+
if not self._was_wrapped_originally:
122+
# unwrap outer list
123+
batch = batch[0]
124+
return batch
125+
126+
35127
class Dataset(Iterable):
36128
"""
37129
A numpy dataset implementation
@@ -76,6 +168,22 @@ def data(self) -> List[Union[numpy.ndarray, Dict[str, numpy.ndarray]]]:
76168
"""
77169
return self._data
78170

171+
def iter_batches(
    self, batch_size: int, iterations: int
) -> Iterable[Union[numpy.ndarray, List[numpy.ndarray]]]:
    """
    Iterate over the data in batches

    :param batch_size: positive integer representing the size of each batch
    :param iterations: positive integer representing
        the number of batches to return
    :returns: An iterable over batches. Each batched array has shape
        (batch_size, *feature_shape); a batch is a list of such arrays when
        the samples are lists of arrays, and a single array otherwise
    :raises ValueError: if batch_size or iterations is not positive
    """
    # The loader validates its arguments on construction, so invalid values
    # fail here rather than on first iteration
    return _BatchLoader(
        data=self.data, batch_size=batch_size, iterations=iterations
    )
186+
79187

80188
class RandomDataset(Dataset):
81189
"""
File renamed without changes.

tests/sparsezoo/models/classification/test_efficientnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from sparsezoo.models.classification import efficientnet_b0, efficientnet_b4
18-
from tests.sparsezoo.utils import model_constructor
18+
from tests.sparsezoo.helpers import model_constructor
1919

2020

2121
@pytest.mark.parametrize(

tests/sparsezoo/models/classification/test_inception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from sparsezoo.models.classification import inception_v3
18-
from tests.sparsezoo.utils import model_constructor
18+
from tests.sparsezoo.helpers import model_constructor
1919

2020

2121
@pytest.mark.parametrize(

tests/sparsezoo/models/classification/test_mobilenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import pytest
1515

1616
from sparsezoo.models.classification import mobilenet_v1, mobilenet_v2
17-
from tests.sparsezoo.utils import model_constructor
17+
from tests.sparsezoo.helpers import model_constructor
1818

1919

2020
@pytest.mark.parametrize(

tests/sparsezoo/models/classification/test_resnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
resnet_101_2x,
2424
resnet_152,
2525
)
26-
from tests.sparsezoo.utils import model_constructor
26+
from tests.sparsezoo.helpers import model_constructor
2727

2828

2929
@pytest.mark.parametrize(

tests/sparsezoo/models/classification/test_vgg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
vgg_19,
2525
vgg_19bn,
2626
)
27-
from tests.sparsezoo.utils import model_constructor
27+
from tests.sparsezoo.helpers import model_constructor
2828

2929

3030
@pytest.mark.parametrize(

tests/sparsezoo/models/detection/test_ssd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from sparsezoo.models.detection import ssd_resnet50_300
18-
from tests.sparsezoo.utils import model_constructor
18+
from tests.sparsezoo.helpers import model_constructor
1919

2020

2121
@pytest.mark.parametrize(

tests/sparsezoo/models/detection/test_yolo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from sparsezoo.models.detection import yolo_v3
18-
from tests.sparsezoo.utils import model_constructor
18+
from tests.sparsezoo.helpers import model_constructor
1919

2020

2121
@pytest.mark.parametrize(

tests/sparsezoo/models/test_zoo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from sparsezoo import Zoo
2121
from sparsezoo.utils import CACHE_DIR
22-
from tests.sparsezoo.utils import validate_downloaded_model
22+
from tests.sparsezoo.helpers import validate_downloaded_model
2323

2424

2525
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)