Skip to content

Commit 68dce91

Browse files
authored
Merge pull request #45 from BirkhoffG/generator-backends
Use `Generator` to control the randomness for each backend dataloader
2 parents a100a1d + ee1214f commit 68dce91

File tree

19 files changed

+253
-33
lines changed

19 files changed

+253
-33
lines changed

integration_tests/hf_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import jax_dataloader as jdl
22
import numpy as np
33
import datasets as hfds
4+
import jax.numpy as jnp
5+
import jax.tree_util as jtu
6+
import jax.random as jrand
47

58

69
def test_hf():
@@ -11,3 +14,30 @@ def test_hf():
1114
x, y = batch['feats'], batch['labels']
1215
z = x + y
1316
assert isinstance(z, np.ndarray)
17+
18+
def test_generator():
19+
ds = hfds.Dataset.from_dict({"feats": np.ones((10, 3)), "labels": np.ones((10, 3))})
20+
21+
g1 = jdl.Generator()
22+
g2 = jrand.PRNGKey(jdl.get_config().global_seed)
23+
24+
# Create two different dataloaders with different generators
25+
dl = jdl.DataLoader(ds, 'jax', batch_size=2, generator=g1, shuffle=True)
26+
batch = next(iter(dl))
27+
28+
dl = jdl.DataLoader(ds, 'jax', batch_size=2, generator=g2, shuffle=True)
29+
new_batch = next(iter(dl))
30+
31+
# Check that batches are equal using tree_map
32+
def are_equal(a, b):
33+
return jnp.all(a == b)
34+
35+
# Map the equality function over the entire pytree structure
36+
equal_elements = jtu.tree_map(are_equal, batch, new_batch)
37+
38+
# Check all elements are True
39+
all_equal = all(jtu.tree_leaves(equal_elements))
40+
assert all_equal
41+
42+
# Also verify the tree structures match
43+
assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch)

integration_tests/jax_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import jax_dataloader as jdl
22
import jax.numpy as jnp
33
import pytest
4+
import jax.random as jrand
5+
import jax.tree_util as jtu
46

57

68
def test_jax():
@@ -27,3 +29,29 @@ def test_tf():
2729
batch = next(iter(dl))
2830
for x, y in dl: z = x + y
2931

32+
def test_generator():
33+
ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3)))
34+
35+
g1 = jdl.Generator()
36+
g2 = jrand.PRNGKey(jdl.get_config().global_seed)
37+
38+
# Create two different dataloaders with different generators
39+
dl = jdl.DataLoader(ds, 'jax', batch_size=2, generator=g1, shuffle=True)
40+
batch = next(iter(dl))
41+
42+
dl = jdl.DataLoader(ds, 'jax', batch_size=2, generator=g2, shuffle=True)
43+
new_batch = next(iter(dl))
44+
45+
# Check that batches are equal using tree_map
46+
def are_equal(a, b):
47+
return jnp.all(a == b)
48+
49+
# Map the equality function over the entire pytree structure
50+
equal_elements = jtu.tree_map(are_equal, batch, new_batch)
51+
52+
# Check all elements are True
53+
all_equal = all(jtu.tree_leaves(equal_elements))
54+
assert all_equal
55+
56+
# Also verify the tree structures match
57+
assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch)

integration_tests/tf_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import numpy as np
33
import tensorflow_datasets as tfds
44
import tensorflow as tf
5-
5+
import jax.random as jrand
6+
import jax.tree_util as jtu
67

78
def test_jax():
89
ds = jdl.ArrayDataset(np.ones((10, 3)), np.ones((10, 3)))
@@ -21,3 +22,25 @@ def test_tf():
2122
z = x + y
2223
assert isinstance(z, np.ndarray)
2324

25+
def test_generator():
26+
ds = jdl.ArrayDataset(np.ones((10, 3)), np.ones((10, 3)))
27+
28+
g1 = jdl.Generator().manual_seed(123)
29+
g2 = jrand.PRNGKey(jdl.get_config().global_seed)
30+
31+
# Create two different dataloaders with different generators
32+
dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2, generator=g1, shuffle=True)
33+
batch = next(iter(dl))
34+
35+
dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2, generator=g2, shuffle=True)
36+
new_batch = next(iter(dl))
37+
# Check that batches are equal using tree_map
38+
def are_equal(a, b):
39+
return np.all(a == b)
40+
# Map the equality function over the entire pytree structure
41+
equal_elements = jtu.tree_map(are_equal, batch, new_batch)
42+
# Check all elements are True
43+
all_equal = all(jtu.tree_leaves(equal_elements))
44+
assert all_equal
45+
# Also verify the tree structures match
46+
assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch)

integration_tests/torch_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import numpy as np
44
import jax.numpy as jnp
55
from torch.utils.data import TensorDataset
6+
import jax.random as jrand
7+
import jax.tree_util as jtu
68

79

810
def test_jax_ds():
@@ -22,3 +24,30 @@ def test_torch():
2224
z = x + y
2325
assert isinstance(z, np.ndarray)
2426

27+
28+
def test_generator():
29+
ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3)))
30+
31+
g1 = jdl.Generator()
32+
g2 = torch.Generator().manual_seed(jdl.get_config().global_seed)
33+
34+
# Create two different dataloaders with different generators
35+
dl = jdl.DataLoader(ds, 'pytorch', batch_size=2, generator=g1, shuffle=True)
36+
batch = next(iter(dl))
37+
38+
dl = jdl.DataLoader(ds, 'pytorch', batch_size=2, generator=g2, shuffle=True)
39+
new_batch = next(iter(dl))
40+
41+
# Check that batches are equal using tree_map
42+
def are_equal(a, b):
43+
return jnp.all(a == b)
44+
45+
# Map the equality function over the entire pytree structure
46+
equal_elements = jtu.tree_map(are_equal, batch, new_batch)
47+
48+
# Check all elements are True
49+
all_equal = all(jtu.tree_leaves(equal_elements))
50+
assert all_equal
51+
52+
# Also verify the tree structures match
53+
assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch)

jax_dataloader/_modidx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@
9090
'jax_dataloader/loaders/tensorflow.py'),
9191
'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow.__next__': ( 'loader.tf.html#dataloadertensorflow.__next__',
9292
'jax_dataloader/loaders/tensorflow.py'),
93+
'jax_dataloader.loaders.tensorflow.get_seed': ( 'loader.tf.html#get_seed',
94+
'jax_dataloader/loaders/tensorflow.py'),
9395
'jax_dataloader.loaders.tensorflow.to_tf_dataset': ( 'loader.tf.html#to_tf_dataset',
9496
'jax_dataloader/loaders/tensorflow.py')},
9597
'jax_dataloader.loaders.torch': { 'jax_dataloader.loaders.torch.DataLoaderPytorch': ( 'loader.torch.html#dataloaderpytorch',
@@ -116,6 +118,7 @@
116118
'jax_dataloader/tests.py'),
117119
'jax_dataloader.tests.test_shuffle_reproducible': ( 'tests.html#test_shuffle_reproducible',
118120
'jax_dataloader/tests.py')},
121+
'jax_dataloader.types': {},
119122
'jax_dataloader.utils': { 'jax_dataloader.utils.Config': ('utils.html#config', 'jax_dataloader/utils.py'),
120123
'jax_dataloader.utils.Config.default': ('utils.html#config.default', 'jax_dataloader/utils.py'),
121124
'jax_dataloader.utils.Generator': ('utils.html#generator', 'jax_dataloader/utils.py'),

jax_dataloader/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .utils import *
99
from .datasets import *
1010
from .loaders import *
11+
from .types import *
1112

1213
# %% auto 0
1314
__all__ = ['SUPPORTED_DATASETS', 'DataloaderBackends', 'get_backend_compatibilities', 'DataLoader']
@@ -96,6 +97,7 @@ def __init__(
9697
batch_size: int = 1, # How many samples per batch to load
9798
shuffle: bool = False, # If true, dataloader reshuffles every epoch
9899
drop_last: bool = False, # If true, drop the last incomplete batch
100+
generator: Optional[GeneratorType] = None, # Random seed generator
99101
**kwargs
100102
):
101103
dl_cls = _dispatch_dataloader(backend)
@@ -104,6 +106,7 @@ def __init__(
104106
batch_size=batch_size,
105107
shuffle=shuffle,
106108
drop_last=drop_last,
109+
generator=generator,
107110
**kwargs
108111
)
109112

jax_dataloader/loaders/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# %% ../../nbs/loader.base.ipynb 3
44
from __future__ import print_function, division, annotations
55
from ..imports import *
6+
from ..utils import Generator
7+
from ..types import GeneratorType
68

79
# %% auto 0
810
__all__ = ['BaseDataLoader']
@@ -18,6 +20,7 @@ def __init__(
1820
shuffle: bool = False, # if true, dataloader shuffles before sampling each batch
1921
num_workers: int = 0, # how many subprocesses to use for data loading.
2022
drop_last: bool = False,
23+
generator: Optional[GeneratorType] = None,
2124
**kwargs
2225
):
2326
pass

jax_dataloader/loaders/jax.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from ..imports import *
66
from ..datasets import ArrayDataset, JAXDataset
77
from . import BaseDataLoader
8-
from ..utils import get_config, asnumpy
8+
from ..utils import get_config, asnumpy, Generator
9+
from ..types import GeneratorType
910
from ..tests import *
1011
import jax_dataloader as jdl
1112
from threading import Thread, Event
@@ -45,16 +46,25 @@ def __init__(
4546
batch_size: int = 1, # batch size
4647
shuffle: bool = False, # if true, dataloader shuffles before sampling each batch
4748
num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.
48-
drop_last: bool = False,
49+
drop_last: bool = False, # if true, drop the last incomplete batch
50+
generator: Optional[GeneratorType] = None, # random seed generator
4951
**kwargs
5052
):
51-
self.key = jrand.PRNGKey(get_config().global_seed)
5253
self.dataset = to_jax_dataset(dataset)
5354

5455
self.indices = np.arange(len(dataset))
5556
self.batch_size = batch_size
5657
self.shuffle = shuffle
5758
self.drop_last = drop_last
59+
60+
# init rng key via generator
61+
if generator is None:
62+
# explicitly set the manual seed of the generator
63+
generator = Generator().manual_seed(get_config().global_seed)
64+
if not isinstance(generator, Generator):
65+
generator = Generator(generator=generator)
66+
67+
self.key = generator.jax_generator()
5868

5969
def __iter__(self):
6070
# shuffle (permutation) indices every epoch
@@ -71,3 +81,5 @@ def next_key(self):
7181
def __len__(self):
7282
complete_batches, remainder = divmod(len(self.indices), self.batch_size)
7383
return complete_batches if self.drop_last else complete_batches + bool(remainder)
84+
85+
# %%

jax_dataloader/loaders/tensorflow.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
from ..imports import *
66
from . import BaseDataLoader
77
from ..datasets import Dataset, ArrayDataset, JAXDataset
8-
from ..utils import check_tf_installed, get_config
8+
from ..utils import check_tf_installed, get_config, Generator
9+
from ..types import GeneratorType
910
from ..tests import *
1011
from jax.tree_util import tree_map
12+
import warnings
1113

1214
# %% auto 0
13-
__all__ = ['to_tf_dataset', 'DataLoaderTensorflow']
15+
__all__ = ['to_tf_dataset', 'get_seed', 'DataLoaderTensorflow']
1416

1517
# %% ../../nbs/loader.tf.ipynb 4
1618
@dispatch
@@ -26,6 +28,18 @@ def to_tf_dataset(dataset: HFDataset) -> tf.data.Dataset:
2628
return dataset.to_tf_dataset()
2729

2830
# %% ../../nbs/loader.tf.ipynb 5
31+
def get_seed(generator: Optional[Generator | jax.Array | torch.Generator] = None) -> int:
32+
if generator is None:
33+
generator = Generator()
34+
35+
if not isinstance(generator, Generator):
36+
generator = Generator(generator=generator)
37+
38+
seed = generator.seed()
39+
if seed is None:
40+
warnings.warn("No random seed provided. Using default seed which may not guarantee reproducible results.")
41+
return seed
42+
2943
class DataLoaderTensorflow(BaseDataLoader):
3044
"""Tensorflow Dataloader"""
3145

@@ -36,13 +50,17 @@ def __init__(
3650
batch_size: int = 1, # Batch size
3751
shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
3852
drop_last: bool = False, # Drop last batch or not
53+
generator: Optional[GeneratorType] = None, # Random seed generator
3954
**kwargs
4055
):
4156
super().__init__(dataset, batch_size, shuffle, drop_last)
4257
check_tf_installed()
58+
# get random seed from generator
59+
seed = get_seed(generator)
60+
4361
# Convert to tf dataset
4462
ds = to_tf_dataset(dataset)
45-
ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed) if shuffle else ds
63+
ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds
4664
ds = ds.batch(batch_size, drop_remainder=drop_last)
4765
ds = ds.prefetch(tf.data.AUTOTUNE)
4866
self.dataloader = ds

jax_dataloader/loaders/torch.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from ..imports import *
66
from . import BaseDataLoader
77
from ..datasets import Dataset, ArrayDataset, JAXDataset
8-
from ..utils import check_pytorch_installed, get_config
8+
from ..utils import check_pytorch_installed, get_config, Generator
9+
from ..types import GeneratorType
910
from ..tests import *
1011
from jax.tree_util import tree_map
1112
import warnings
@@ -48,6 +49,7 @@ def __init__(
4849
batch_size: int = 1, # Batch size
4950
shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
5051
drop_last: bool = False, # Drop last batch or not
52+
generator: Optional[GeneratorType] = None,
5153
**kwargs
5254
):
5355
super().__init__(dataset, batch_size, shuffle, drop_last)
@@ -61,8 +63,15 @@ def __init__(
6163

6264
# convert to torch dataset
6365
dataset = to_torch_dataset(dataset)
66+
# init generator
67+
if generator is None:
68+
# explicitly set the manual seed of the generator
69+
generator = Generator().manual_seed(get_config().global_seed)
70+
if not isinstance(generator, Generator):
71+
generator = Generator(generator=generator)
72+
73+
generator = generator.torch_generator()
6474
# init batch sampler
65-
generator = torch.Generator().manual_seed(get_config().global_seed)
6675
if shuffle:
6776
sampler = RandomSampler(dataset, generator=generator)
6877
else:

0 commit comments

Comments
 (0)