diff --git a/integration_tests/hf_test.py b/integration_tests/hf_test.py index 0d658d2..85f87e9 100644 --- a/integration_tests/hf_test.py +++ b/integration_tests/hf_test.py @@ -1,6 +1,9 @@ import jax_dataloader as jdl import numpy as np import datasets as hfds +import jax.numpy as jnp +import jax.tree_util as jtu +import jax.random as jrand def test_hf(): @@ -11,3 +14,30 @@ def test_hf(): x, y = batch['feats'], batch['labels'] z = x + y assert isinstance(z, np.ndarray) + +def test_generator(): + ds = hfds.Dataset.from_dict({"feats": np.ones((10, 3)), "labels": np.ones((10, 3))}) + + g1 = jdl.Generator() + g2 = jrand.PRNGKey(jdl.get_config().global_seed) + + # Create two different dataloaders with different generators + dl = jdl.DataLoader(ds, 'jax', batch_size=2, generator=g1, shuffle=True) + batch = next(iter(dl)) + + dl = jdl.DataLoader(ds, 'jax', batch_size=2, generator=g2, shuffle=True) + new_batch = next(iter(dl)) + + # Check that batches are equal using tree_map + def are_equal(a, b): + return jnp.all(a == b) + + # Map the equality function over the entire pytree structure + equal_elements = jtu.tree_map(are_equal, batch, new_batch) + + # Check all elements are True + all_equal = all(jtu.tree_leaves(equal_elements)) + assert all_equal + + # Also verify the tree structures match + assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch) diff --git a/integration_tests/jax_test.py b/integration_tests/jax_test.py index 7382671..30a0b3d 100644 --- a/integration_tests/jax_test.py +++ b/integration_tests/jax_test.py @@ -1,6 +1,8 @@ import jax_dataloader as jdl import jax.numpy as jnp import pytest +import jax.random as jrand +import jax.tree_util as jtu def test_jax(): @@ -27,3 +29,29 @@ def test_tf(): batch = next(iter(dl)) for x, y in dl: z = x + y +def test_generator(): + ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3))) + + g1 = jdl.Generator() + g2 = jrand.PRNGKey(jdl.get_config().global_seed) + + # Create two different dataloaders with different generators + dl = jdl.DataLoader(ds, 'jax', batch_size=2, generator=g1, shuffle=True) + batch = next(iter(dl)) + + dl = jdl.DataLoader(ds, 'jax', batch_size=2, generator=g2, shuffle=True) + new_batch = next(iter(dl)) + + # Check that batches are equal using tree_map + def are_equal(a, b): + return jnp.all(a == b) + + # Map the equality function over the entire pytree structure + equal_elements = jtu.tree_map(are_equal, batch, new_batch) + + # Check all elements are True + all_equal = all(jtu.tree_leaves(equal_elements)) + assert all_equal + + # Also verify the tree structures match + assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch) diff --git a/integration_tests/tf_test.py b/integration_tests/tf_test.py index a865aed..bda762b 100644 --- a/integration_tests/tf_test.py +++ b/integration_tests/tf_test.py @@ -2,7 +2,8 @@ import numpy as np import tensorflow_datasets as tfds import tensorflow as tf - +import jax.random as jrand +import jax.tree_util as jtu def test_jax(): ds = jdl.ArrayDataset(np.ones((10, 3)), np.ones((10, 3))) @@ -21,3 +22,25 @@ def test_tf(): z = x + y assert isinstance(z, np.ndarray) +def test_generator(): + ds = jdl.ArrayDataset(np.ones((10, 3)), np.ones((10, 3))) + + g1 = jdl.Generator().manual_seed(123) + g2 = jrand.PRNGKey(jdl.get_config().global_seed) + + # Create two different dataloaders with different generators + dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2, generator=g1, shuffle=True) + batch = next(iter(dl)) + + dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2, generator=g2, shuffle=True) + new_batch = next(iter(dl)) + # Check that batches are equal using tree_map + def are_equal(a, b): + return np.all(a == b) + # Map the equality function over the entire pytree structure + equal_elements = jtu.tree_map(are_equal, batch, new_batch) + # Check all elements are True + all_equal = all(jtu.tree_leaves(equal_elements)) + assert all_equal + # Also verify the tree structures match + assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch) diff --git a/integration_tests/torch_test.py b/integration_tests/torch_test.py index 0d5f76a..f01ce67 100644 --- a/integration_tests/torch_test.py +++ b/integration_tests/torch_test.py @@ -3,6 +3,8 @@ import numpy as np import jax.numpy as jnp from torch.utils.data import TensorDataset +import jax.random as jrand +import jax.tree_util as jtu def test_jax_ds(): @@ -22,3 +24,30 @@ def test_torch(): z = x + y assert isinstance(z, np.ndarray) + +def test_generator(): + ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3))) + + g1 = jdl.Generator() + g2 = torch.Generator().manual_seed(jdl.get_config().global_seed) + + # Create two different dataloaders with different generators + dl = jdl.DataLoader(ds, 'pytorch', batch_size=2, generator=g1, shuffle=True) + batch = next(iter(dl)) + + dl = jdl.DataLoader(ds, 'pytorch', batch_size=2, generator=g2, shuffle=True) + new_batch = next(iter(dl)) + + # Check that batches are equal using tree_map + def are_equal(a, b): + return jnp.all(a == b) + + # Map the equality function over the entire pytree structure + equal_elements = jtu.tree_map(are_equal, batch, new_batch) + + # Check all elements are True + all_equal = all(jtu.tree_leaves(equal_elements)) + assert all_equal + + # Also verify the tree structures match + assert jtu.tree_structure(batch) == jtu.tree_structure(new_batch) diff --git a/jax_dataloader/_modidx.py b/jax_dataloader/_modidx.py index 3ee55eb..a8bac08 100644 --- a/jax_dataloader/_modidx.py +++ b/jax_dataloader/_modidx.py @@ -90,6 +90,8 @@ 'jax_dataloader/loaders/tensorflow.py'), 'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow.__next__': ( 'loader.tf.html#dataloadertensorflow.__next__', 'jax_dataloader/loaders/tensorflow.py'), + 'jax_dataloader.loaders.tensorflow.get_seed': ( 'loader.tf.html#get_seed', + 'jax_dataloader/loaders/tensorflow.py'), 'jax_dataloader.loaders.tensorflow.to_tf_dataset': ( 'loader.tf.html#to_tf_dataset', 'jax_dataloader/loaders/tensorflow.py')}, 'jax_dataloader.loaders.torch': { 'jax_dataloader.loaders.torch.DataLoaderPytorch': ( 'loader.torch.html#dataloaderpytorch', @@ -116,6 +118,7 @@ 'jax_dataloader/tests.py'), 'jax_dataloader.tests.test_shuffle_reproducible': ( 'tests.html#test_shuffle_reproducible', 'jax_dataloader/tests.py')}, + 'jax_dataloader.types': {}, 'jax_dataloader.utils': { 'jax_dataloader.utils.Config': ('utils.html#config', 'jax_dataloader/utils.py'), 'jax_dataloader.utils.Config.default': ('utils.html#config.default', 'jax_dataloader/utils.py'), 'jax_dataloader.utils.Generator': ('utils.html#generator', 'jax_dataloader/utils.py'), diff --git a/jax_dataloader/core.py b/jax_dataloader/core.py index 95ffd25..daeafd1 100644 --- a/jax_dataloader/core.py +++ b/jax_dataloader/core.py @@ -8,6 +8,7 @@ from .utils import * from .datasets import * from .loaders import * +from .types import * # %% auto 0 __all__ = ['SUPPORTED_DATASETS', 'DataloaderBackends', 'get_backend_compatibilities', 'DataLoader'] @@ -96,6 +97,7 @@ def __init__( batch_size: int = 1, # How many samples per batch to load shuffle: bool = False, # If true, dataloader reshuffles every epoch drop_last: bool = False, # If true, drop the last incomplete batch + generator: Optional[GeneratorType] = None, # Random seed generator **kwargs ): dl_cls = _dispatch_dataloader(backend) @@ -104,6 +106,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + generator=generator, **kwargs ) diff --git a/jax_dataloader/loaders/base.py b/jax_dataloader/loaders/base.py index e007903..1168c8e 100644 --- a/jax_dataloader/loaders/base.py +++ b/jax_dataloader/loaders/base.py @@ -3,6 +3,8 @@ # %% ../../nbs/loader.base.ipynb 3 from __future__ import print_function, division, annotations from ..imports import * +from ..utils import Generator +from ..types import GeneratorType # %% auto 0 __all__ = ['BaseDataLoader'] @@ -18,6 +20,7 @@ def __init__( shuffle: bool = False, # if true, dataloader shuffles before sampling each batch num_workers: int = 0, # how many subprocesses to use for data loading. drop_last: bool = False, + generator: Optional[GeneratorType] = None, **kwargs ): pass diff --git a/jax_dataloader/loaders/jax.py b/jax_dataloader/loaders/jax.py index f1a277f..62af712 100644 --- a/jax_dataloader/loaders/jax.py +++ b/jax_dataloader/loaders/jax.py @@ -5,7 +5,8 @@ from ..imports import * from ..datasets import ArrayDataset, JAXDataset from . import BaseDataLoader -from ..utils import get_config, asnumpy +from ..utils import get_config, asnumpy, Generator +from ..types import GeneratorType from ..tests import * import jax_dataloader as jdl from threading import Thread, Event @@ -45,16 +46,25 @@ def __init__( batch_size: int = 1, # batch size shuffle: bool = False, # if true, dataloader shuffles before sampling each batch num_workers: int = 0, # how many subprocesses to use for data loading. Ignored. - drop_last: bool = False, + drop_last: bool = False, # if true, drop the last incomplete batch + generator: Optional[GeneratorType] = None, # random seed generator **kwargs ): - self.key = jrand.PRNGKey(get_config().global_seed) self.dataset = to_jax_dataset(dataset) self.indices = np.arange(len(dataset)) self.batch_size = batch_size self.shuffle = shuffle self.drop_last = drop_last + + # init rng key via generator + if generator is None: + # explicitly set the manual seed of the generator + generator = Generator().manual_seed(get_config().global_seed) + if not isinstance(generator, Generator): + generator = Generator(generator=generator) + + self.key = generator.jax_generator() def __iter__(self): # shuffle (permutation) indices every epoch @@ -71,3 +81,5 @@ def next_key(self): def __len__(self): complete_batches, remainder = divmod(len(self.indices), self.batch_size) return complete_batches if self.drop_last else complete_batches + bool(remainder) + +# %% diff --git a/jax_dataloader/loaders/tensorflow.py b/jax_dataloader/loaders/tensorflow.py index 7d618fa..73e5bc9 100644 --- a/jax_dataloader/loaders/tensorflow.py +++ b/jax_dataloader/loaders/tensorflow.py @@ -5,12 +5,14 @@ from ..imports import * from . import BaseDataLoader from ..datasets import Dataset, ArrayDataset, JAXDataset -from ..utils import check_tf_installed, get_config +from ..utils import check_tf_installed, get_config, Generator +from ..types import GeneratorType from ..tests import * from jax.tree_util import tree_map +import warnings # %% auto 0 -__all__ = ['to_tf_dataset', 'DataLoaderTensorflow'] +__all__ = ['to_tf_dataset', 'get_seed', 'DataLoaderTensorflow'] # %% ../../nbs/loader.tf.ipynb 4 @dispatch @@ -26,6 +28,18 @@ def to_tf_dataset(dataset: HFDataset) -> tf.data.Dataset: return dataset.to_tf_dataset() # %% ../../nbs/loader.tf.ipynb 5 +def get_seed(generator: Optional[Generator | jax.Array | torch.Generator] = None) -> int: + if generator is None: + generator = Generator() + + if not isinstance(generator, Generator): + generator = Generator(generator=generator) + + seed = generator.seed() + if seed is None: + warnings.warn("No random seed provided. Using default seed which may not guarantee reproducible results.") + return seed + class DataLoaderTensorflow(BaseDataLoader): """Tensorflow Dataloader""" @@ -36,13 +50,17 @@ def __init__( batch_size: int = 1, # Batch size shuffle: bool = False, # If true, dataloader shuffles before sampling each batch drop_last: bool = False, # Drop last batch or not + generator: Optional[GeneratorType] = None, # Random seed generator **kwargs ): super().__init__(dataset, batch_size, shuffle, drop_last) check_tf_installed() + # get random seed from generator + seed = get_seed(generator) + # Convert to tf dataset ds = to_tf_dataset(dataset) - ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed) if shuffle else ds + ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds ds = ds.batch(batch_size, drop_remainder=drop_last) ds = ds.prefetch(tf.data.AUTOTUNE) self.dataloader = ds diff --git a/jax_dataloader/loaders/torch.py b/jax_dataloader/loaders/torch.py index 92b4c20..f270c7b 100644 --- a/jax_dataloader/loaders/torch.py +++ b/jax_dataloader/loaders/torch.py @@ -5,7 +5,8 @@ from ..imports import * from . import BaseDataLoader from ..datasets import Dataset, ArrayDataset, JAXDataset -from ..utils import check_pytorch_installed, get_config +from ..utils import check_pytorch_installed, get_config, Generator +from ..types import GeneratorType from ..tests import * from jax.tree_util import tree_map import warnings @@ -48,6 +49,7 @@ def __init__( batch_size: int = 1, # Batch size shuffle: bool = False, # If true, dataloader shuffles before sampling each batch drop_last: bool = False, # Drop last batch or not + generator: Optional[GeneratorType] = None, **kwargs ): super().__init__(dataset, batch_size, shuffle, drop_last) @@ -61,8 +63,15 @@ def __init__( # convert to torch dataset dataset = to_torch_dataset(dataset) + # init generator + if generator is None: + # explicitly set the manual seed of the generator + generator = Generator().manual_seed(get_config().global_seed) + if not isinstance(generator, Generator): + generator = Generator(generator=generator) + + generator = generator.torch_generator() # init batch sampler - generator = torch.Generator().manual_seed(get_config().global_seed) if shuffle: sampler = RandomSampler(dataset, generator=generator) else: diff --git a/jax_dataloader/types.py b/jax_dataloader/types.py new file mode 100644 index 0000000..80e7186 --- /dev/null +++ b/jax_dataloader/types.py @@ -0,0 +1,4 @@ +from .imports import * +from .utils import Generator + +GeneratorType = Union[Generator, jax.Array, 'torch.Generator'] \ No newline at end of file diff --git a/jax_dataloader/utils.py b/jax_dataloader/utils.py index c620354..9261cf1 100644 --- a/jax_dataloader/utils.py +++ b/jax_dataloader/utils.py @@ -67,10 +67,12 @@ def check_tf_installed(): # %% ../nbs/utils.ipynb 21 class Generator: + """A wrapper around JAX and PyTorch generators. This is used to generate random numbers in a reproducible way.""" + def __init__( self, *, - generator: jrand.Array | torch.Generator = None, + generator: jax.Array | torch.Generator = None, # Optional generator ): self._seed = None self._jax_generator = None @@ -78,20 +80,19 @@ def __init__( if generator is None: self._seed = get_config().global_seed + elif (torch is not None) and isinstance(generator, torch.Generator): + self._torch_generator = generator elif isinstance(generator, jax.Array): self._jax_generator = generator - elif isinstance(generator, torch.Generator): - self._torch_generator = generator else: raise ValueError(f"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.") if self._seed is None and self._torch_generator is not None: self._seed = self._torch_generator.initial_seed() - def seed(self) -> int: + def seed(self) -> Optional[int]: """The initial seed of the generator""" - if self._seed is None: - raise ValueError("The seed is not specified. Please set the seed using `manual_seed()` or pass a generator.") + # TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey` return self._seed def manual_seed(self, seed: int) -> Generator: diff --git a/nbs/core.ipynb b/nbs/core.ipynb index 67ea166..d844cf3 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -45,7 +45,8 @@ "from jax_dataloader.imports import *\n", "from jax_dataloader.utils import *\n", "from jax_dataloader.datasets import *\n", - "from jax_dataloader.loaders import *" + "from jax_dataloader.loaders import *\n", + "from jax_dataloader.types import *" ] }, { @@ -173,6 +174,7 @@ " batch_size: int = 1, # How many samples per batch to load\n", " shuffle: bool = False, # If true, dataloader reshuffles every epoch\n", " drop_last: bool = False, # If true, drop the last incomplete batch\n", + " generator: Optional[GeneratorType] = None, # Random seed generator\n", " **kwargs\n", " ):\n", " dl_cls = _dispatch_dataloader(backend)\n", @@ -181,6 +183,7 @@ " batch_size=batch_size, \n", " shuffle=shuffle, \n", " drop_last=drop_last,\n", + " generator=generator,\n", " **kwargs\n", " )\n", "\n", diff --git a/nbs/loader.base.ipynb b/nbs/loader.base.ipynb index ea713eb..63901cd 100644 --- a/nbs/loader.base.ipynb +++ b/nbs/loader.base.ipynb @@ -38,7 +38,9 @@ "source": [ "#| export\n", "from __future__ import print_function, division, annotations\n", - "from jax_dataloader.imports import *" + "from jax_dataloader.imports import *\n", + "from jax_dataloader.utils import Generator\n", + "from jax_dataloader.types import GeneratorType" ] }, { @@ -58,6 +60,7 @@ " shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n", " num_workers: int = 0, # how many subprocesses to use for data loading.\n", " drop_last: bool = False,\n", + " generator: Optional[GeneratorType] = None,\n", " **kwargs\n", " ):\n", " pass\n", diff --git a/nbs/loader.jax.ipynb b/nbs/loader.jax.ipynb index 155c777..5b718b4 100644 --- a/nbs/loader.jax.ipynb +++ b/nbs/loader.jax.ipynb @@ -41,7 +41,8 @@ "from jax_dataloader.imports import *\n", "from jax_dataloader.datasets import ArrayDataset, JAXDataset\n", "from jax_dataloader.loaders import BaseDataLoader\n", - "from jax_dataloader.utils import get_config, asnumpy\n", + "from jax_dataloader.utils import get_config, asnumpy, Generator\n", + "from jax_dataloader.types import GeneratorType\n", "from jax_dataloader.tests import *\n", "import jax_dataloader as jdl\n", "from threading import Thread, Event\n", @@ -99,16 +100,25 @@ " batch_size: int = 1, # batch size\n", " shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n", " num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.\n", - " drop_last: bool = False,\n", + " drop_last: bool = False, # if true, drop the last incomplete batch\n", + " generator: Optional[GeneratorType] = None, # random seed generator\n", " **kwargs\n", " ):\n", - " self.key = jrand.PRNGKey(get_config().global_seed)\n", " self.dataset = to_jax_dataset(dataset)\n", " \n", " self.indices = np.arange(len(dataset))\n", " self.batch_size = batch_size\n", " self.shuffle = shuffle\n", " self.drop_last = drop_last\n", + "\n", + " # init rng key via generator\n", + " if generator is None:\n", + " # explicitly set the manual seed of the generator \n", + " generator = Generator().manual_seed(get_config().global_seed)\n", + " if not isinstance(generator, Generator):\n", + " generator = Generator(generator=generator)\n", + " \n", + " self.key = generator.jax_generator()\n", " \n", " def __iter__(self):\n", " # shuffle (permutation) indices every epoch \n", @@ -124,7 +134,9 @@ " \n", " def __len__(self):\n", " complete_batches, remainder = divmod(len(self.indices), self.batch_size)\n", - " return complete_batches if self.drop_last else complete_batches + bool(remainder)" + " return complete_batches if self.drop_last else complete_batches + bool(remainder)\n", + "\n", + "# %%" ] }, { diff --git a/nbs/loader.tf.ipynb b/nbs/loader.tf.ipynb index dd6172c..ec6db30 100644 --- a/nbs/loader.tf.ipynb +++ b/nbs/loader.tf.ipynb @@ -34,9 +34,11 @@ "from jax_dataloader.imports import *\n", "from jax_dataloader.loaders import BaseDataLoader\n", "from jax_dataloader.datasets import Dataset, ArrayDataset, JAXDataset\n", - "from jax_dataloader.utils import check_tf_installed, get_config\n", + "from jax_dataloader.utils import check_tf_installed, get_config, Generator\n", + "from jax_dataloader.types import GeneratorType\n", "from jax_dataloader.tests import *\n", - "from jax.tree_util import tree_map" + "from jax.tree_util import tree_map\n", + "import warnings" ] }, { @@ -73,6 +75,18 @@ "outputs": [], "source": [ "#| export\n", + "def get_seed(generator: Optional[Generator | jax.Array | torch.Generator] = None) -> int:\n", + " if generator is None:\n", + " generator = Generator()\n", + " \n", + " if not isinstance(generator, Generator):\n", + " generator = Generator(generator=generator)\n", + " \n", + " seed = generator.seed()\n", + " if seed is None:\n", + " warnings.warn(\"No random seed provided. Using default seed which may not guarantee reproducible results.\")\n", + " return seed\n", + "\n", "class DataLoaderTensorflow(BaseDataLoader):\n", " \"\"\"Tensorflow Dataloader\"\"\"\n", " \n", @@ -83,13 +97,17 @@ " batch_size: int = 1, # Batch size\n", " shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n", " drop_last: bool = False, # Drop last batch or not\n", + " generator: Optional[GeneratorType] = None, # Random seed generator\n", " **kwargs\n", " ):\n", " super().__init__(dataset, batch_size, shuffle, drop_last)\n", " check_tf_installed()\n", + " # get random seed from generator\n", + " seed = get_seed(generator)\n", + "\n", " # Convert to tf dataset\n", " ds = to_tf_dataset(dataset)\n", - " ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed) if shuffle else ds\n", + " ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds\n", " ds = ds.batch(batch_size, drop_remainder=drop_last)\n", " ds = ds.prefetch(tf.data.AUTOTUNE)\n", " self.dataloader = ds\n", diff --git a/nbs/loader.torch.ipynb b/nbs/loader.torch.ipynb index 6e92de2..1b9a997 100644 --- a/nbs/loader.torch.ipynb +++ b/nbs/loader.torch.ipynb @@ -34,7 +34,8 @@ "from jax_dataloader.imports import *\n", "from jax_dataloader.loaders import BaseDataLoader\n", "from jax_dataloader.datasets import Dataset, ArrayDataset, JAXDataset\n", - "from jax_dataloader.utils import check_pytorch_installed, get_config\n", + "from jax_dataloader.utils import check_pytorch_installed, get_config, Generator\n", + "from jax_dataloader.types import GeneratorType\n", "from jax_dataloader.tests import *\n", "from jax.tree_util import tree_map\n", "import warnings\n" @@ -113,6 +114,7 @@ " batch_size: int = 1, # Batch size\n", " shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n", " drop_last: bool = False, # Drop last batch or not\n", + " generator: Optional[GeneratorType] = None,\n", " **kwargs\n", " ):\n", " super().__init__(dataset, batch_size, shuffle, drop_last)\n", @@ -126,8 +128,15 @@ "\n", " # convert to torch dataset\n", " dataset = to_torch_dataset(dataset)\n", + " # init generator\n", + " if generator is None:\n", + " # explicitly set the manual seed of the generator\n", + " generator = Generator().manual_seed(get_config().global_seed)\n", + " if not isinstance(generator, Generator):\n", + " generator = Generator(generator=generator)\n", + " \n", + " generator = generator.torch_generator()\n", " # init batch sampler\n", - " generator = torch.Generator().manual_seed(get_config().global_seed)\n", " if shuffle: \n", " sampler = RandomSampler(dataset, generator=generator)\n", " else: \n", diff --git a/nbs/utils.ipynb b/nbs/utils.ipynb index a769bac..1a3ac82 100644 --- a/nbs/utils.ipynb +++ b/nbs/utils.ipynb @@ -242,10 +242,12 @@ "source": [ "#| export\n", "class Generator:\n", + " \"\"\"A wrapper around JAX and PyTorch generators. This is used to generate random numbers in a reproducible way.\"\"\"\n", + "\n", " def __init__(\n", " self, \n", " *, \n", - " generator: jrand.Array | torch.Generator = None,\n", + " generator: jax.Array | torch.Generator = None, # Optional generator\n", " ):\n", " self._seed = None\n", " self._jax_generator = None\n", @@ -253,20 +255,19 @@ "\n", " if generator is None:\n", " self._seed = get_config().global_seed\n", + " elif (torch is not None) and isinstance(generator, torch.Generator):\n", + " self._torch_generator = generator\n", " elif isinstance(generator, jax.Array):\n", " self._jax_generator = generator\n", - " elif isinstance(generator, torch.Generator):\n", - " self._torch_generator = generator\n", " else:\n", " raise ValueError(f\"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.\")\n", " \n", " if self._seed is None and self._torch_generator is not None:\n", " self._seed = self._torch_generator.initial_seed()\n", "\n", - " def seed(self) -> int:\n", + " def seed(self) -> Optional[int]:\n", " \"\"\"The initial seed of the generator\"\"\"\n", - " if self._seed is None:\n", - " raise ValueError(\"The seed is not specified. Please set the seed using `manual_seed()` or pass a generator.\")\n", + " # TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey`\n", " return self._seed\n", " \n", " def manual_seed(self, seed: int) -> Generator:\n", @@ -310,6 +311,7 @@ "# Examples of using the generator when passing a `jax.random.PRNGKey` or `torch.Generator`\n", "g_jax = Generator(generator=jax.random.PRNGKey(123))\n", "assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(123))\n", + "assert g_jax.seed() is None\n", "\n", "g_torch = Generator(generator=torch.Generator().manual_seed(123))\n", "assert g_torch.torch_generator().initial_seed() == 123\n", @@ -324,7 +326,6 @@ "outputs": [], "source": [ "#| hide\n", - "test_fail(g_jax.seed, contains='The seed is not specified')\n", "test_fail(g_jax.torch_generator, contains='Neither pytorch generator or seed is specified')" ] }, diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5b6286e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,11 @@ +[build-system] +requires = ["setuptools>=64.0"] +build-backend = "setuptools.build_meta" + +[project] +name="jax-dataloader" +requires-python=">=3.8" +dynamic = [ "keywords", "description", "version", "dependencies", "optional-dependencies", "readme", "license", "authors", "classifiers", "entry-points", "scripts", "urls"] + +[tool.uv] +cache-keys = [{ file = "pyproject.toml" }, { file = "settings.ini" }, { file = "setup.py" }]