1
1
from __future__ import annotations
2
- from pydantic import BaseModel
2
+ from pydantic import BaseModel , PositiveInt
3
3
from typing import (
4
4
Tuple ,
5
5
Dict ,
@@ -56,6 +56,9 @@ def __init__(
56
56
dataset : Dataset [T ],
57
57
sampler : torch .utils .data .Sampler = None
58
58
):
59
+ if len (dataset ) == 0 :
60
+ raise ValueError ('Cannot create datastream from empty dataset' )
61
+
59
62
super ().__init__ (
60
63
dataset = dataset ,
61
64
sampler = (
@@ -67,6 +70,9 @@ def __init__(
67
70
68
71
def __len__ (self ):
69
72
return len (self .sampler )
73
+
74
+ def __iter__ (self ):
75
+ return map (self .dataset .__getitem__ , iter (self .sampler ))
70
76
71
77
@staticmethod
72
78
def merge (datastreams_and_ns : Tuple [Union [
@@ -151,6 +157,10 @@ def data_loader(
151
157
'''
152
158
Get ``torch.utils.data.DataLoader`` for use in pytorch pipeline.
153
159
160
+ The argument ``n_batches_per_epoch`` overrides the underlying length
161
+ of the dataset. If the epoch ends before the full dataset has been
162
+ processed then it will continue from the same point the next epoch.
163
+
154
164
>>> data_loader = (
155
165
... Datastream(Dataset.from_subscriptable([5, 5, 5]))
156
166
... .data_loader(batch_size=5, n_batches_per_epoch=10)
@@ -218,13 +228,15 @@ def sample_proportion(
218
228
219
229
def take (
220
230
self : Datastream [T ],
221
- n_samples : int ,
231
+ n_samples : PositiveInt ,
222
232
) -> Datastream [T ]:
223
233
'''
224
234
Like :func:`Datastream.sample_proportion` but specify the number of
225
235
samples instead of a proportion.
226
236
'''
227
- return self .sample_proportion (min (1 , n_samples / len (self )))
237
+ if n_samples < 1 :
238
+ raise ValueError ('n_samples must be greater than or equal to 1' )
239
+ return self .sample_proportion (n_samples / len (self ))
228
240
229
241
def state_dict (self ) -> Dict :
230
242
'''Get state of datastream. Useful for checkpointing sample weights.'''
@@ -278,6 +290,28 @@ def cache(
278
290
)
279
291
280
292
293
+ def test_infinite ():
294
+
295
+ datastream = Datastream (Dataset .from_subscriptable (list ('abc' )))
296
+ it = iter (datastream .data_loader (batch_size = 8 , n_batches_per_epoch = 10 ))
297
+ for _ in range (10 ):
298
+ batch = next (it )
299
+
300
+
301
+ def test_iter ():
302
+
303
+ datastream = Datastream (Dataset .from_subscriptable (list ('abc' )))
304
+ assert len (list (datastream )) == 3
305
+
306
+
307
+ def test_empty ():
308
+
309
+ import pytest
310
+
311
+ with pytest .raises (ValueError ):
312
+ Datastream (Dataset .from_subscriptable (list ()))
313
+
314
+
281
315
def test_datastream_merge ():
282
316
283
317
datastream = Datastream .merge ([
@@ -289,10 +323,16 @@ def test_datastream_merge():
289
323
for _ in range (2 ):
290
324
index = next (it )
291
325
292
- it = iter (datastream .data_loader (batch_size = 8 ))
326
+ it = iter (datastream .data_loader (batch_size = 8 , n_batches_per_epoch = 10 ))
293
327
for _ in range (10 ):
294
328
batch = next (it )
295
329
330
+ assert (
331
+ len (list (
332
+ datastream .data_loader (batch_size = 1 )
333
+ )) == len (datastream )
334
+ )
335
+
296
336
297
337
def test_datastream_zip ():
298
338
@@ -314,6 +354,12 @@ def test_datastream_zip():
314
354
assert batch [1 ][0 ] == 3 and batch [1 ][1 ] == 4 and batch [1 ][2 ] == 5
315
355
assert batch [2 ][0 ] == 6 and batch [2 ][1 ] == 7 and batch [2 ][2 ] == 6
316
356
357
+ assert (
358
+ len (list (
359
+ zipped_datastream .data_loader (batch_size = 1 )
360
+ )) == len (zipped_datastream )
361
+ )
362
+
317
363
318
364
def test_datastream_merge_zip_merge ():
319
365
'''
@@ -442,3 +488,20 @@ def test_multi_sample():
442
488
zero_indices = set ([index for _ , index in output [:2 ]])
443
489
for number , index in output2 :
444
490
assert index not in zero_indices
491
+
492
+
493
+ def test_take ():
494
+
495
+ import pytest
496
+
497
+ datastream = Datastream (Dataset .from_subscriptable (list ('abc' ))).take (2 )
498
+ assert len (list (datastream .data_loader (batch_size = 1 ))) == 2
499
+
500
+ with pytest .raises (ValueError ):
501
+ Datastream (Dataset .from_subscriptable (list ('abc' ))).take (0 )
502
+
503
+ datastream = Datastream .merge ([
504
+ Datastream (Dataset .from_subscriptable (list ('abc' ))),
505
+ Datastream (Dataset .from_subscriptable (list ('d' ))),
506
+ ])
507
+ assert len (list (datastream .take (2 ).data_loader (batch_size = 1 ))) == 2
0 commit comments