Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions jax_dataloader/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@
'jax_dataloader/loaders/tensorflow.py'),
'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow.__next__': ( 'loader.tf.html#dataloadertensorflow.__next__',
'jax_dataloader/loaders/tensorflow.py'),
'jax_dataloader.loaders.tensorflow.get_seed': ( 'loader.tf.html#get_seed',
'jax_dataloader/loaders/tensorflow.py'),
'jax_dataloader.loaders.tensorflow.to_tf_dataset': ( 'loader.tf.html#to_tf_dataset',
'jax_dataloader/loaders/tensorflow.py')},
'jax_dataloader.loaders.torch': { 'jax_dataloader.loaders.torch.DataLoaderPytorch': ( 'loader.torch.html#dataloaderpytorch',
Expand Down
2 changes: 2 additions & 0 deletions jax_dataloader/loaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# %% ../../nbs/loader.base.ipynb 3
from __future__ import print_function, division, annotations
from ..imports import *
from ..utils import Generator

# %% auto 0
__all__ = ['BaseDataLoader']
Expand All @@ -18,6 +19,7 @@ def __init__(
shuffle: bool = False, # if true, dataloader shuffles before sampling each batch
num_workers: int = 0, # how many subprocesses to use for data loading.
drop_last: bool = False,
generator: Optional[Generator | jax.Array | torch.Generator] = None,
**kwargs
):
pass
Expand Down
17 changes: 14 additions & 3 deletions jax_dataloader/loaders/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..imports import *
from ..datasets import ArrayDataset, JAXDataset
from . import BaseDataLoader
from ..utils import get_config, asnumpy
from ..utils import get_config, asnumpy, Generator
from ..tests import *
import jax_dataloader as jdl
from threading import Thread, Event
Expand Down Expand Up @@ -45,16 +45,25 @@ def __init__(
batch_size: int = 1, # batch size
shuffle: bool = False, # if true, dataloader shuffles before sampling each batch
num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.
drop_last: bool = False,
drop_last: bool = False, # if true, drop the last incomplete batch
generator: Optional[Generator | jax.Array | torch.Generator] = None, # random seed generator
**kwargs
):
self.key = jrand.PRNGKey(get_config().global_seed)
self.dataset = to_jax_dataset(dataset)

self.indices = np.arange(len(dataset))
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last

# init rng key via generator
if generator is None:
# explicitly set the manual seed of the generator
generator = Generator().manual_seed(get_config().global_seed)
if not isinstance(generator, Generator):
generator = Generator(generator=generator)

self.key = generator.jax_generator()

def __iter__(self):
# shuffle (permutation) indices every epoch
Expand All @@ -71,3 +80,5 @@ def next_key(self):
def __len__(self):
complete_batches, remainder = divmod(len(self.indices), self.batch_size)
return complete_batches if self.drop_last else complete_batches + bool(remainder)

# %%
23 changes: 20 additions & 3 deletions jax_dataloader/loaders/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from ..imports import *
from . import BaseDataLoader
from ..datasets import Dataset, ArrayDataset, JAXDataset
from ..utils import check_tf_installed, get_config
from ..utils import check_tf_installed, get_config, Generator
from ..tests import *
from jax.tree_util import tree_map
import warnings

# %% auto 0
__all__ = ['to_tf_dataset', 'DataLoaderTensorflow']
__all__ = ['to_tf_dataset', 'get_seed', 'DataLoaderTensorflow']

# %% ../../nbs/loader.tf.ipynb 4
@dispatch
Expand All @@ -26,6 +27,18 @@ def to_tf_dataset(dataset: HFDataset) -> tf.data.Dataset:
return dataset.to_tf_dataset()

# %% ../../nbs/loader.tf.ipynb 5
def get_seed(generator: Optional[Generator | jax.Array | torch.Generator] = None) -> Optional[int]:
    """Extract an integer random seed from `generator`.

    Accepts the project's unified `Generator`, a raw `jax.Array` PRNG key,
    a `torch.Generator`, or `None` (which falls back to a default-constructed
    `Generator`). Returns the generator's initial seed, or `None` when no
    integer seed can be recovered (e.g. from a bare JAX PRNG key), in which
    case a warning is emitted. Note the return type is `Optional[int]`, not
    `int`, precisely because of that `None` path.
    """
    if generator is None:
        generator = Generator()

    # Wrap raw jax/torch generator objects in the unified Generator wrapper.
    if not isinstance(generator, Generator):
        generator = Generator(generator=generator)

    seed = generator.seed()
    if seed is None:
        # A jax PRNG key carries no recoverable integer seed, so reproducibility
        # of the downstream tf.data shuffle cannot be guaranteed.
        warnings.warn("No random seed provided. Using default seed which may not guarantee reproducible results.")
    return seed

class DataLoaderTensorflow(BaseDataLoader):
"""Tensorflow Dataloader"""

Expand All @@ -36,13 +49,17 @@ def __init__(
batch_size: int = 1, # Batch size
shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
drop_last: bool = False, # Drop last batch or not
generator: Optional[Generator | jax.Array | torch.Generator] = None, # Random seed generator
**kwargs
):
super().__init__(dataset, batch_size, shuffle, drop_last)
check_tf_installed()
# get random seed from generator
seed = get_seed(generator)

# Convert to tf dataset
ds = to_tf_dataset(dataset)
ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed) if shuffle else ds
ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds
ds = ds.batch(batch_size, drop_remainder=drop_last)
ds = ds.prefetch(tf.data.AUTOTUNE)
self.dataloader = ds
Expand Down
12 changes: 10 additions & 2 deletions jax_dataloader/loaders/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..imports import *
from . import BaseDataLoader
from ..datasets import Dataset, ArrayDataset, JAXDataset
from ..utils import check_pytorch_installed, get_config
from ..utils import check_pytorch_installed, get_config, Generator
from ..tests import *
from jax.tree_util import tree_map
import warnings
Expand Down Expand Up @@ -48,6 +48,7 @@ def __init__(
batch_size: int = 1, # Batch size
shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
drop_last: bool = False, # Drop last batch or not
generator: Optional[Generator | jax.Array | torch.Generator] = None,
**kwargs
):
super().__init__(dataset, batch_size, shuffle, drop_last)
Expand All @@ -61,8 +62,15 @@ def __init__(

# convert to torch dataset
dataset = to_torch_dataset(dataset)
# init generator
if generator is None:
# explicitly set the manual seed of the generator
generator = Generator().manual_seed(get_config().global_seed)
if not isinstance(generator, Generator):
generator = Generator(generator=generator)

generator = generator.torch_generator()
# init batch sampler
generator = torch.Generator().manual_seed(get_config().global_seed)
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
else:
Expand Down
9 changes: 4 additions & 5 deletions jax_dataloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,19 @@ def __init__(

if generator is None:
self._seed = get_config().global_seed
elif isinstance(generator, jax.Array):
self._jax_generator = generator
elif isinstance(generator, torch.Generator):
self._torch_generator = generator
elif isinstance(generator, jax.Array):
self._jax_generator = generator
else:
raise ValueError(f"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.")

if self._seed is None and self._torch_generator is not None:
self._seed = self._torch_generator.initial_seed()

def seed(self) -> int:
def seed(self) -> Optional[int]:
"""The initial seed of the generator"""
if self._seed is None:
raise ValueError("The seed is not specified. Please set the seed using `manual_seed()` or pass a generator.")
# TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey`
Copy link

Copilot AI Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The word 'initizalized' should be corrected to 'initialized' for clarity.

Suggested change
# TODO: the seed might not be initizalized if the generator is a `jax.random.PRNGKey`
# TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey`

Copilot uses AI. Check for mistakes.
return self._seed

def manual_seed(self, seed: int) -> Generator:
Expand Down
4 changes: 3 additions & 1 deletion nbs/loader.base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
"source": [
"#| export\n",
"from __future__ import print_function, division, annotations\n",
"from jax_dataloader.imports import *"
"from jax_dataloader.imports import *\n",
"from jax_dataloader.utils import Generator"
]
},
{
Expand All @@ -58,6 +59,7 @@
" shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n",
" num_workers: int = 0, # how many subprocesses to use for data loading.\n",
" drop_last: bool = False,\n",
" generator: Optional[Generator | jax.Array | torch.Generator] = None,\n",
" **kwargs\n",
" ):\n",
" pass\n",
Expand Down
19 changes: 15 additions & 4 deletions nbs/loader.jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"from jax_dataloader.imports import *\n",
"from jax_dataloader.datasets import ArrayDataset, JAXDataset\n",
"from jax_dataloader.loaders import BaseDataLoader\n",
"from jax_dataloader.utils import get_config, asnumpy\n",
"from jax_dataloader.utils import get_config, asnumpy, Generator\n",
"from jax_dataloader.tests import *\n",
"import jax_dataloader as jdl\n",
"from threading import Thread, Event\n",
Expand Down Expand Up @@ -99,16 +99,25 @@
" batch_size: int = 1, # batch size\n",
" shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n",
" num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.\n",
" drop_last: bool = False,\n",
" drop_last: bool = False, # if true, drop the last incomplete batch\n",
" generator: Optional[Generator | jax.Array | torch.Generator] = None, # random seed generator\n",
" **kwargs\n",
" ):\n",
" self.key = jrand.PRNGKey(get_config().global_seed)\n",
" self.dataset = to_jax_dataset(dataset)\n",
" \n",
" self.indices = np.arange(len(dataset))\n",
" self.batch_size = batch_size\n",
" self.shuffle = shuffle\n",
" self.drop_last = drop_last\n",
"\n",
" # init rng key via generator\n",
" if generator is None:\n",
" # explicitly set the manual seed of the generator \n",
" generator = Generator().manual_seed(get_config().global_seed)\n",
" if not isinstance(generator, Generator):\n",
" generator = Generator(generator=generator)\n",
" \n",
" self.key = generator.jax_generator()\n",
" \n",
" def __iter__(self):\n",
" # shuffle (permutation) indices every epoch \n",
Expand All @@ -124,7 +133,9 @@
" \n",
" def __len__(self):\n",
" complete_batches, remainder = divmod(len(self.indices), self.batch_size)\n",
" return complete_batches if self.drop_last else complete_batches + bool(remainder)"
" return complete_batches if self.drop_last else complete_batches + bool(remainder)\n",
"\n",
"# %%"
]
},
{
Expand Down
23 changes: 20 additions & 3 deletions nbs/loader.tf.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
"from jax_dataloader.imports import *\n",
"from jax_dataloader.loaders import BaseDataLoader\n",
"from jax_dataloader.datasets import Dataset, ArrayDataset, JAXDataset\n",
"from jax_dataloader.utils import check_tf_installed, get_config\n",
"from jax_dataloader.utils import check_tf_installed, get_config, Generator\n",
"from jax_dataloader.tests import *\n",
"from jax.tree_util import tree_map"
"from jax.tree_util import tree_map\n",
"import warnings"
]
},
{
Expand Down Expand Up @@ -73,6 +74,18 @@
"outputs": [],
"source": [
"#| export\n",
"def get_seed(generator: Optional[Generator | jax.Array | torch.Generator] = None) -> int:\n",
" if generator is None:\n",
" generator = Generator()\n",
" \n",
" if not isinstance(generator, Generator):\n",
" generator = Generator(generator=generator)\n",
" \n",
" seed = generator.seed()\n",
" if seed is None:\n",
" warnings.warn(\"No random seed provided. Using default seed which may not guarantee reproducible results.\")\n",
" return seed\n",
"\n",
"class DataLoaderTensorflow(BaseDataLoader):\n",
" \"\"\"Tensorflow Dataloader\"\"\"\n",
" \n",
Expand All @@ -83,13 +96,17 @@
" batch_size: int = 1, # Batch size\n",
" shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n",
" drop_last: bool = False, # Drop last batch or not\n",
" generator: Optional[Generator | jax.Array | torch.Generator] = None, # Random seed generator\n",
" **kwargs\n",
" ):\n",
" super().__init__(dataset, batch_size, shuffle, drop_last)\n",
" check_tf_installed()\n",
" # get random seed from generator\n",
" seed = get_seed(generator)\n",
"\n",
" # Convert to tf dataset\n",
" ds = to_tf_dataset(dataset)\n",
" ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed) if shuffle else ds\n",
" ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds\n",
" ds = ds.batch(batch_size, drop_remainder=drop_last)\n",
" ds = ds.prefetch(tf.data.AUTOTUNE)\n",
" self.dataloader = ds\n",
Expand Down
12 changes: 10 additions & 2 deletions nbs/loader.torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"from jax_dataloader.imports import *\n",
"from jax_dataloader.loaders import BaseDataLoader\n",
"from jax_dataloader.datasets import Dataset, ArrayDataset, JAXDataset\n",
"from jax_dataloader.utils import check_pytorch_installed, get_config\n",
"from jax_dataloader.utils import check_pytorch_installed, get_config, Generator\n",
"from jax_dataloader.tests import *\n",
"from jax.tree_util import tree_map\n",
"import warnings\n"
Expand Down Expand Up @@ -113,6 +113,7 @@
" batch_size: int = 1, # Batch size\n",
" shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n",
" drop_last: bool = False, # Drop last batch or not\n",
" generator: Optional[Generator | jax.Array | torch.Generator] = None,\n",
" **kwargs\n",
" ):\n",
" super().__init__(dataset, batch_size, shuffle, drop_last)\n",
Expand All @@ -126,8 +127,15 @@
"\n",
" # convert to torch dataset\n",
" dataset = to_torch_dataset(dataset)\n",
" # init generator\n",
" if generator is None:\n",
" # explicitly set the manual seed of the generator\n",
" generator = Generator().manual_seed(get_config().global_seed)\n",
" if not isinstance(generator, Generator):\n",
" generator = Generator(generator=generator)\n",
" \n",
" generator = generator.torch_generator()\n",
" # init batch sampler\n",
" generator = torch.Generator().manual_seed(get_config().global_seed)\n",
" if shuffle: \n",
" sampler = RandomSampler(dataset, generator=generator)\n",
" else: \n",
Expand Down
11 changes: 5 additions & 6 deletions nbs/utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -253,20 +253,19 @@
"\n",
" if generator is None:\n",
" self._seed = get_config().global_seed\n",
" elif isinstance(generator, jax.Array):\n",
" self._jax_generator = generator\n",
" elif isinstance(generator, torch.Generator):\n",
" self._torch_generator = generator\n",
" elif isinstance(generator, jax.Array):\n",
" self._jax_generator = generator\n",
" else:\n",
" raise ValueError(f\"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.\")\n",
" \n",
" if self._seed is None and self._torch_generator is not None:\n",
" self._seed = self._torch_generator.initial_seed()\n",
"\n",
" def seed(self) -> int:\n",
" def seed(self) -> Optional[int]:\n",
" \"\"\"The initial seed of the generator\"\"\"\n",
" if self._seed is None:\n",
" raise ValueError(\"The seed is not specified. Please set the seed using `manual_seed()` or pass a generator.\")\n",
" # TODO: the seed might not be initizalized if the generator is a `jax.random.PRNGKey`\n",
Copy link

Copilot AI Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The word 'initizalized' appears to be misspelled. Consider updating it to 'initialized'.

Suggested change
" # TODO: the seed might not be initizalized if the generator is a `jax.random.PRNGKey`\n",
" # TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey`\n",

Copilot uses AI. Check for mistakes.
" return self._seed\n",
" \n",
" def manual_seed(self, seed: int) -> Generator:\n",
Expand Down Expand Up @@ -310,6 +309,7 @@
"# Examples of using the generator when passing a `jax.random.PRNGKey` or `torch.Generator`\n",
"g_jax = Generator(generator=jax.random.PRNGKey(123))\n",
"assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(123))\n",
"assert g_jax.seed() is None\n",
"\n",
"g_torch = Generator(generator=torch.Generator().manual_seed(123))\n",
"assert g_torch.torch_generator().initial_seed() == 123\n",
Expand All @@ -324,7 +324,6 @@
"outputs": [],
"source": [
"#| hide\n",
"test_fail(g_jax.seed, contains='The seed is not specified')\n",
"test_fail(g_jax.torch_generator, contains='Neither pytorch generator or seed is specified')"
]
},
Expand Down
Loading