Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions integration_tests/hf_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import jax_dataloader as jdl
import numpy as np
import datasets as hfds
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.random as jrand


def test_hf():
Expand All @@ -11,3 +14,30 @@ def test_hf():
x, y = batch['feats'], batch['labels']
z = x + y
assert isinstance(z, np.ndarray)

def test_generator():
    """A default `jdl.Generator` and a raw PRNGKey seeded with the global seed
    must drive identical shuffles on a HuggingFace dataset."""
    ds = hfds.Dataset.from_dict({"feats": np.ones((10, 3)), "labels": np.ones((10, 3))})

    # Two generator flavors that share the same underlying (global) seed.
    wrapper_gen = jdl.Generator()
    raw_key = jrand.PRNGKey(jdl.get_config().global_seed)

    # One dataloader per generator flavor; take the first batch of each.
    first = next(iter(jdl.DataLoader(ds, 'jax', batch_size=2, generator=wrapper_gen, shuffle=True)))
    second = next(iter(jdl.DataLoader(ds, 'jax', batch_size=2, generator=raw_key, shuffle=True)))

    # Every leaf of the two batch pytrees must be element-wise equal...
    matches = jtu.tree_map(lambda a, b: jnp.all(a == b), first, second)
    assert all(jtu.tree_leaves(matches))

    # ...and the pytree structures themselves must agree.
    assert jtu.tree_structure(first) == jtu.tree_structure(second)
28 changes: 28 additions & 0 deletions integration_tests/jax_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import jax_dataloader as jdl
import jax.numpy as jnp
import pytest
import jax.random as jrand
import jax.tree_util as jtu


def test_jax():
Expand All @@ -27,3 +29,29 @@ def test_tf():
batch = next(iter(dl))
for x, y in dl: z = x + y

def test_generator():
    """Equivalent generators (wrapper vs. raw PRNGKey, both from the global
    seed) must reproduce the same batches from the jax backend."""
    ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3)))

    wrapper_gen = jdl.Generator()
    prng_key = jrand.PRNGKey(jdl.get_config().global_seed)

    # Build one dataloader per generator flavor.
    loader_a = jdl.DataLoader(ds, 'jax', batch_size=2, generator=wrapper_gen, shuffle=True)
    loader_b = jdl.DataLoader(ds, 'jax', batch_size=2, generator=prng_key, shuffle=True)
    batch_a = next(iter(loader_a))
    batch_b = next(iter(loader_b))

    def leaves_match(a, b):
        # Element-wise comparison of a single pair of pytree leaves.
        return jnp.all(a == b)

    # All corresponding leaves must be equal.
    comparison = jtu.tree_map(leaves_match, batch_a, batch_b)
    assert all(jtu.tree_leaves(comparison))

    # The batches must also share a pytree structure.
    assert jtu.tree_structure(batch_a) == jtu.tree_structure(batch_b)
25 changes: 24 additions & 1 deletion integration_tests/tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf

import jax.random as jrand
import jax.tree_util as jtu

def test_jax():
ds = jdl.ArrayDataset(np.ones((10, 3)), np.ones((10, 3)))
Expand All @@ -21,3 +22,25 @@ def test_tf():
z = x + y
assert isinstance(z, np.ndarray)

def test_generator():
    """Batches driven by a `jdl.Generator` and by a raw PRNGKey should agree
    when both derive from the global seed (mirrors the jax/torch sibling tests)."""
    ds = jdl.ArrayDataset(np.ones((10, 3)), np.ones((10, 3)))

    # Use the default generator (seeded with the global seed) so both loaders
    # share one seed. Previously g1 was `manual_seed(123)` while g2 used the
    # global seed — "different generators" that only compared equal because the
    # dataset rows are indistinguishable (see NOTE below).
    g1 = jdl.Generator()
    g2 = jrand.PRNGKey(jdl.get_config().global_seed)

    # Create one dataloader per generator flavor
    dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2, generator=g1, shuffle=True)
    batch = next(iter(dl))

    dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2, generator=g2, shuffle=True)
    new_batch = next(iter(dl))

    # Check that batches are equal using tree_map
    def are_equal(a, b):
        return np.all(a == b)
    # NOTE(review): the dataset is all-ones, so this equality holds regardless
    # of shuffle order; distinguishable rows would make the check meaningful.
    equal_elements = jtu.tree_map(are_equal, batch, new_batch)
    # Check all elements are True
    all_equal = all(jtu.tree_leaves(equal_elements))
    assert all_equal
    # Also verify the tree structures match
    assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch)
29 changes: 29 additions & 0 deletions integration_tests/torch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import jax.numpy as jnp
from torch.utils.data import TensorDataset
import jax.random as jrand
import jax.tree_util as jtu


def test_jax_ds():
Expand All @@ -22,3 +24,30 @@ def test_torch():
z = x + y
assert isinstance(z, np.ndarray)


def test_generator():
    """A default `jdl.Generator` and a `torch.Generator` seeded with the
    global seed must produce identical batches from the pytorch backend."""
    ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3)))

    # Two generator flavors sharing the same underlying seed.
    default_gen = jdl.Generator()
    torch_gen = torch.Generator().manual_seed(jdl.get_config().global_seed)

    # First batch from each dataloader.
    batch = next(iter(jdl.DataLoader(ds, 'pytorch', batch_size=2, generator=default_gen, shuffle=True)))
    new_batch = next(iter(jdl.DataLoader(ds, 'pytorch', batch_size=2, generator=torch_gen, shuffle=True)))

    # All corresponding leaves must be element-wise equal...
    pairwise_eq = jtu.tree_map(lambda a, b: jnp.all(a == b), batch, new_batch)
    assert all(jtu.tree_leaves(pairwise_eq))

    # ...and the pytree structures must match as well.
    assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch)
3 changes: 3 additions & 0 deletions jax_dataloader/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@
'jax_dataloader/loaders/tensorflow.py'),
'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow.__next__': ( 'loader.tf.html#dataloadertensorflow.__next__',
'jax_dataloader/loaders/tensorflow.py'),
'jax_dataloader.loaders.tensorflow.get_seed': ( 'loader.tf.html#get_seed',
'jax_dataloader/loaders/tensorflow.py'),
'jax_dataloader.loaders.tensorflow.to_tf_dataset': ( 'loader.tf.html#to_tf_dataset',
'jax_dataloader/loaders/tensorflow.py')},
'jax_dataloader.loaders.torch': { 'jax_dataloader.loaders.torch.DataLoaderPytorch': ( 'loader.torch.html#dataloaderpytorch',
Expand All @@ -116,6 +118,7 @@
'jax_dataloader/tests.py'),
'jax_dataloader.tests.test_shuffle_reproducible': ( 'tests.html#test_shuffle_reproducible',
'jax_dataloader/tests.py')},
'jax_dataloader.types': {},
'jax_dataloader.utils': { 'jax_dataloader.utils.Config': ('utils.html#config', 'jax_dataloader/utils.py'),
'jax_dataloader.utils.Config.default': ('utils.html#config.default', 'jax_dataloader/utils.py'),
'jax_dataloader.utils.Generator': ('utils.html#generator', 'jax_dataloader/utils.py'),
Expand Down
3 changes: 3 additions & 0 deletions jax_dataloader/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .utils import *
from .datasets import *
from .loaders import *
from .types import *

# %% auto 0
__all__ = ['SUPPORTED_DATASETS', 'DataloaderBackends', 'get_backend_compatibilities', 'DataLoader']
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(
batch_size: int = 1, # How many samples per batch to load
shuffle: bool = False, # If true, dataloader reshuffles every epoch
drop_last: bool = False, # If true, drop the last incomplete batch
generator: Optional[GeneratorType] = None, # Random seed generator
**kwargs
):
dl_cls = _dispatch_dataloader(backend)
Expand All @@ -104,6 +106,7 @@ def __init__(
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
generator=generator,
**kwargs
)

Expand Down
3 changes: 3 additions & 0 deletions jax_dataloader/loaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# %% ../../nbs/loader.base.ipynb 3
from __future__ import print_function, division, annotations
from ..imports import *
from ..utils import Generator
from ..types import GeneratorType

# %% auto 0
__all__ = ['BaseDataLoader']
Expand All @@ -18,6 +20,7 @@ def __init__(
shuffle: bool = False, # if true, dataloader shuffles before sampling each batch
num_workers: int = 0, # how many subprocesses to use for data loading.
drop_last: bool = False,
generator: Optional[GeneratorType] = None,
**kwargs
):
pass
Expand Down
18 changes: 15 additions & 3 deletions jax_dataloader/loaders/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from ..imports import *
from ..datasets import ArrayDataset, JAXDataset
from . import BaseDataLoader
from ..utils import get_config, asnumpy
from ..utils import get_config, asnumpy, Generator
from ..types import GeneratorType
from ..tests import *
import jax_dataloader as jdl
from threading import Thread, Event
Expand Down Expand Up @@ -45,16 +46,25 @@ def __init__(
batch_size: int = 1, # batch size
shuffle: bool = False, # if true, dataloader shuffles before sampling each batch
num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.
drop_last: bool = False,
drop_last: bool = False, # if true, drop the last incomplete batch
generator: Optional[GeneratorType] = None, # random seed generator
**kwargs
):
self.key = jrand.PRNGKey(get_config().global_seed)
self.dataset = to_jax_dataset(dataset)

self.indices = np.arange(len(dataset))
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last

# init rng key via generator
if generator is None:
# explicitly set the manual seed of the generator
generator = Generator().manual_seed(get_config().global_seed)
if not isinstance(generator, Generator):
generator = Generator(generator=generator)

self.key = generator.jax_generator()

def __iter__(self):
# shuffle (permutation) indices every epoch
Expand All @@ -71,3 +81,5 @@ def next_key(self):
def __len__(self):
    """Number of batches yielded per epoch (last partial batch counted
    unless `drop_last` is set)."""
    n_full, leftover = divmod(len(self.indices), self.batch_size)
    if self.drop_last:
        return n_full
    return n_full + (1 if leftover else 0)

# %%
24 changes: 21 additions & 3 deletions jax_dataloader/loaders/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from ..imports import *
from . import BaseDataLoader
from ..datasets import Dataset, ArrayDataset, JAXDataset
from ..utils import check_tf_installed, get_config
from ..utils import check_tf_installed, get_config, Generator
from ..types import GeneratorType
from ..tests import *
from jax.tree_util import tree_map
import warnings

# %% auto 0
__all__ = ['to_tf_dataset', 'DataLoaderTensorflow']
__all__ = ['to_tf_dataset', 'get_seed', 'DataLoaderTensorflow']

# %% ../../nbs/loader.tf.ipynb 4
@dispatch
Expand All @@ -26,6 +28,18 @@ def to_tf_dataset(dataset: HFDataset) -> tf.data.Dataset:
return dataset.to_tf_dataset()

# %% ../../nbs/loader.tf.ipynb 5
def get_seed(generator: Optional[GeneratorType] = None) -> Optional[int]:
    """Extract an integer shuffle seed from `generator`.

    Falls back to a default `Generator` (seeded from the global config seed)
    when `generator` is None; raw jax keys / torch generators are wrapped in
    the common `Generator` wrapper first. Returns None — after emitting a
    warning — when the wrapped generator carries no retrievable seed (e.g. a
    raw `jax.random.PRNGKey`), in which case the tf.data shuffle will not be
    reproducible.
    """
    if generator is None:
        generator = Generator()

    if not isinstance(generator, Generator):
        # Normalize raw backend generators into the library wrapper.
        generator = Generator(generator=generator)

    seed = generator.seed()
    if seed is None:
        warnings.warn("No random seed provided. Using default seed which may not guarantee reproducible results.")
    return seed

class DataLoaderTensorflow(BaseDataLoader):
"""Tensorflow Dataloader"""

Expand All @@ -36,13 +50,17 @@ def __init__(
batch_size: int = 1, # Batch size
shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
drop_last: bool = False, # Drop last batch or not
generator: Optional[GeneratorType] = None, # Random seed generator
**kwargs
):
super().__init__(dataset, batch_size, shuffle, drop_last)
check_tf_installed()
# get random seed from generator
seed = get_seed(generator)

# Convert to tf dataset
ds = to_tf_dataset(dataset)
ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed) if shuffle else ds
ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds
ds = ds.batch(batch_size, drop_remainder=drop_last)
ds = ds.prefetch(tf.data.AUTOTUNE)
self.dataloader = ds
Expand Down
13 changes: 11 additions & 2 deletions jax_dataloader/loaders/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from ..imports import *
from . import BaseDataLoader
from ..datasets import Dataset, ArrayDataset, JAXDataset
from ..utils import check_pytorch_installed, get_config
from ..utils import check_pytorch_installed, get_config, Generator
from ..types import GeneratorType
from ..tests import *
from jax.tree_util import tree_map
import warnings
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
batch_size: int = 1, # Batch size
shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
drop_last: bool = False, # Drop last batch or not
generator: Optional[GeneratorType] = None,
**kwargs
):
super().__init__(dataset, batch_size, shuffle, drop_last)
Expand All @@ -61,8 +63,15 @@ def __init__(

# convert to torch dataset
dataset = to_torch_dataset(dataset)
# init generator
if generator is None:
# explicitly set the manual seed of the generator
generator = Generator().manual_seed(get_config().global_seed)
if not isinstance(generator, Generator):
generator = Generator(generator=generator)

generator = generator.torch_generator()
# init batch sampler
generator = torch.Generator().manual_seed(get_config().global_seed)
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
else:
Expand Down
4 changes: 4 additions & 0 deletions jax_dataloader/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .imports import *
from .utils import Generator

# Any value accepted as a dataloader `generator` argument: the library's own
# `Generator` wrapper, a raw JAX PRNG key, or a `torch.Generator`. The torch
# type is forward-referenced as a string — presumably because torch is an
# optional dependency (utils guards with `torch is not None`); confirm.
GeneratorType = Union[Generator, jax.Array, 'torch.Generator']
13 changes: 7 additions & 6 deletions jax_dataloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,32 @@ def check_tf_installed():

# %% ../nbs/utils.ipynb 21
class Generator:
"""A wrapper around JAX and PyTorch generators. This is used to generate random numbers in a reproducible way."""

def __init__(
    self,
    *,
    generator: jax.Array | torch.Generator = None,  # Optional generator
):
    """Wrap an optional backend generator.

    With no argument, the wrapper is seeded from the global config seed.
    Otherwise a raw `jax.random.PRNGKey` (a `jax.Array`) or a
    `torch.Generator` may be supplied.
    """
    self._seed = None
    self._jax_generator = None
    self._torch_generator = None

    if generator is None:
        self._seed = get_config().global_seed
    # Guarded with `torch is not None` so this works when torch is an
    # optional, absent dependency. (The previous unguarded duplicate
    # `isinstance(generator, torch.Generator)` branch was dead code and
    # would raise AttributeError when torch is missing.)
    elif (torch is not None) and isinstance(generator, torch.Generator):
        self._torch_generator = generator
    elif isinstance(generator, jax.Array):
        self._jax_generator = generator
    else:
        raise ValueError(f"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.")

    # Torch generators expose their seed; mirror it so `seed()` can report it.
    if self._seed is None and self._torch_generator is not None:
        self._seed = self._torch_generator.initial_seed()

def seed(self) -> Optional[int]:
    """The initial seed of the generator, or None when it is unknown.

    Callers (e.g. the tensorflow loader's `get_seed`) check for None and
    warn, so this must not raise on a missing seed.
    """
    # TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey`
    return self._seed

def manual_seed(self, seed: int) -> Generator:
Expand Down
Loading