Commit 3db230e

fix: sampler length works as intended and check for empty datastream
Also added an __iter__ method, plus tests that empty datastreams are rejected and that the length of a datastream is implemented correctly. Changed the length of a merged datastream.
1 parent 920a1e2 commit 3db230e
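
Note: a minimal sketch of the intended behavior after this commit, using only calls that appear in the diffs below and assuming ``Dataset`` and ``Datastream`` are importable from the ``datastream`` package:

    from datastream import Dataset, Datastream

    datastream = Datastream(Dataset.from_subscriptable(list('abc')))

    # The sampler length fix: a data loader with one sample per batch
    # now yields exactly len(datastream) batches.
    assert len(list(datastream.data_loader(batch_size=1))) == len(datastream)

    # The new empty-dataset check: construction fails fast instead of
    # producing a datastream that can never yield anything.
    try:
        Datastream(Dataset.from_subscriptable([]))
    except ValueError as error:
        print(error)  # Cannot create datastream from empty dataset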

File tree

2 files changed: +72 -14 lines changed

datastream/datastream.py
datastream/samplers.py

datastream/datastream.py

Lines changed: 67 additions & 4 deletions
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from pydantic import BaseModel
+from pydantic import BaseModel, PositiveInt
 from typing import (
     Tuple,
     Dict,
@@ -56,6 +56,9 @@ def __init__(
         dataset: Dataset[T],
         sampler: torch.utils.data.Sampler = None
     ):
+        if len(dataset) == 0:
+            raise ValueError('Cannot create datastream from empty dataset')
+
         super().__init__(
             dataset=dataset,
             sampler=(
@@ -67,6 +70,9 @@ def __init__(

     def __len__(self):
         return len(self.sampler)
+
+    def __iter__(self):
+        return map(self.dataset.__getitem__, iter(self.sampler))

     @staticmethod
     def merge(datastreams_and_ns: Tuple[Union[
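
The added ``__iter__`` draws indices from the sampler and maps each one through ``dataset.__getitem__``, so a datastream can now be consumed directly without constructing a data loader. A small usage sketch:

    datastream = Datastream(Dataset.from_subscriptable(list('abc')))

    # Items are produced lazily, one dataset lookup per sampled index.
    for item in datastream:
        print(item)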
@@ -151,6 +157,10 @@ def data_loader(
         '''
         Get ``torch.utils.data.DataLoader`` for use in pytorch pipeline.

+        The argument ``n_batches_per_epoch`` overrides the underlying length
+        of the dataset. If the epoch ends before the full dataset has been
+        processed then it will continue from the same point the next epoch.
+
         >>> data_loader = (
         ...     Datastream(Dataset.from_subscriptable([5, 5, 5]))
         ...     .data_loader(batch_size=5, n_batches_per_epoch=10)
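
The new docstring paragraph pins down the epoch semantics: ``n_batches_per_epoch`` decouples the epoch length from the dataset length, and a following epoch resumes from wherever the previous one stopped. A sketch of what that means for a three-item dataset:

    datastream = Datastream(Dataset.from_subscriptable(list('abc')))
    data_loader = datastream.data_loader(batch_size=2, n_batches_per_epoch=10)

    # One epoch is exactly 10 batches even though the dataset has 3 items;
    # per the docstring, the next epoch continues from the same point
    # rather than restarting at the beginning of the dataset.
    assert len(list(data_loader)) == 10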
@@ -218,13 +228,15 @@ def sample_proportion(

     def take(
         self: Datastream[T],
-        n_samples: int,
+        n_samples: PositiveInt,
     ) -> Datastream[T]:
         '''
         Like :func:`Datastream.sample_proportion` but specify the number of
         samples instead of a proportion.
         '''
-        return self.sample_proportion(min(1, n_samples / len(self)))
+        if n_samples < 1:
+            raise ValueError('n_samples must be greater than or equal to 1')
+        return self.sample_proportion(n_samples / len(self))

     def state_dict(self) -> Dict:
         '''Get state of datastream. Useful for checkpointing sample weights.'''
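
``take`` now rejects a non-positive ``n_samples`` up front (mirrored by the ``PositiveInt`` annotation); the previous ``min(1, ...)`` clamp is gone, and ``n_samples / len(self)`` is passed through directly. A short usage sketch:

    datastream = Datastream(Dataset.from_subscriptable(list('abc')))

    # Requesting two samples yields exactly two batches of size one.
    assert len(list(datastream.take(2).data_loader(batch_size=1))) == 2

    try:
        datastream.take(0)
    except ValueError as error:
        print(error)  # n_samples must be greater than or equal to 1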
@@ -278,6 +290,28 @@ def cache(
         )


+def test_infinite():
+
+    datastream = Datastream(Dataset.from_subscriptable(list('abc')))
+    it = iter(datastream.data_loader(batch_size=8, n_batches_per_epoch=10))
+    for _ in range(10):
+        batch = next(it)
+
+
+def test_iter():
+
+    datastream = Datastream(Dataset.from_subscriptable(list('abc')))
+    assert len(list(datastream)) == 3
+
+
+def test_empty():
+
+    import pytest
+
+    with pytest.raises(ValueError):
+        Datastream(Dataset.from_subscriptable(list()))
+
+
 def test_datastream_merge():

     datastream = Datastream.merge([
@@ -289,10 +323,16 @@ def test_datastream_merge():
     for _ in range(2):
         index = next(it)

-    it = iter(datastream.data_loader(batch_size=8))
+    it = iter(datastream.data_loader(batch_size=8, n_batches_per_epoch=10))
     for _ in range(10):
         batch = next(it)

+    assert (
+        len(list(
+            datastream.data_loader(batch_size=1)
+        )) == len(datastream)
+    )
+

 def test_datastream_zip():

@@ -314,6 +354,12 @@ def test_datastream_zip():
     assert batch[1][0] == 3 and batch[1][1] == 4 and batch[1][2] == 5
     assert batch[2][0] == 6 and batch[2][1] == 7 and batch[2][2] == 6

+    assert (
+        len(list(
+            zipped_datastream.data_loader(batch_size=1)
+        )) == len(zipped_datastream)
+    )
+

 def test_datastream_merge_zip_merge():
     '''
@@ -442,3 +488,20 @@ def test_multi_sample():
     zero_indices = set([index for _, index in output[:2]])
     for number, index in output2:
         assert index not in zero_indices
+
+
+def test_take():
+
+    import pytest
+
+    datastream = Datastream(Dataset.from_subscriptable(list('abc'))).take(2)
+    assert len(list(datastream.data_loader(batch_size=1))) == 2
+
+    with pytest.raises(ValueError):
+        Datastream(Dataset.from_subscriptable(list('abc'))).take(0)
+
+    datastream = Datastream.merge([
+        Datastream(Dataset.from_subscriptable(list('abc'))),
+        Datastream(Dataset.from_subscriptable(list('d'))),
+    ])
+    assert len(list(datastream.take(2).data_loader(batch_size=1))) == 2
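
A worked check of the new merged length rule (see the ``merged_samplers_length`` change in samplers.py below), which is what makes the last assertion of ``test_take`` hold:

    datastream = Datastream.merge([
        Datastream(Dataset.from_subscriptable(list('abc'))),
        Datastream(Dataset.from_subscriptable(list('d'))),
    ])

    # Merged length is now max(3, 1) == 3, not max(3, 1) * 2 == 6 as
    # before, so take(2) corresponds to the proportion 2 / 3.
    assert len(datastream) == 3
    assert len(list(datastream.take(2).data_loader(batch_size=1))) == 2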

datastream/samplers.py

Lines changed: 5 additions & 10 deletions
@@ -2,7 +2,7 @@
 from pydantic import BaseModel
 from typing import Tuple, Callable, Iterable
 from functools import partial
-from itertools import chain
+from itertools import chain, islice
 import torch
 from datastream.tools import starcompose, repeat_map_chain
 from datastream import Dataset
@@ -96,14 +96,11 @@ def __len__(self):
         return self.length

     def __iter__(self):
-        return iter(self.merged_samplers)
+        return islice(self.merged_samplers, self.length)

     @staticmethod
     def merged_samplers_length(samplers):
-        return (
-            max([len(sampler) for sampler in samplers])
-            * len(samplers)
-        )
+        return max([len(sampler) for sampler in samplers])

     @staticmethod
     def merge_samplers(samplers, datasets, ns):
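
The ``islice`` change is what makes ``__len__`` and ``__iter__`` agree: the merged sampler stream is effectively unbounded (it is built with ``repeat_map_chain``, per the imports), and ``islice`` cuts it off after exactly ``self.length`` items. The pattern in isolation:

    from itertools import count, islice

    endless = count()  # stand-in for the unbounded merged sampler
    length = 3

    # islice stops the endless iterator after `length` items.
    assert list(islice(endless, length)) == [0, 1, 2]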
@@ -182,7 +179,7 @@ def __len__(self):
         return self.length

     def __iter__(self):
-        return iter(self.zipped_samplers)
+        return islice(self.zipped_samplers, self.length)

     @staticmethod
     def zip_samplers(samplers, datasets):
@@ -268,9 +265,7 @@ def __len__(self):
         return self.length

     def __iter__(self):
-        it = self.merged_samplers
-        for _ in range(self.length):
-            yield next(it)
+        return islice(self.merged_samplers, self.length)

     @staticmethod
     def merge_samplers(samplers, ns):
