Closed
Description
Hi,
The documentation says that you can set a per-dataloader seed generator using JAX keys, but with jax_dataloader==0.1.3 (which seems to be the latest version) I cannot actually do this; please see the example below.
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import jax_dataloader as jdl
>>> xs = jnp.arange(100.)
>>> ds = jdl.ArrayDataset(xs, asnumpy=False)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True, generator=jr.key(0))))
(array([83.], dtype=float32),)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True, generator=jr.key(1))))
(array([83.], dtype=float32),)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True, generator=jr.key(2))))
(array([83.], dtype=float32),)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True, generator=jr.key(42))))
(array([83.], dtype=float32),)
>>> jdl.manual_seed(0)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True)))
(array([98.], dtype=float32),)
>>> jdl.manual_seed(1)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True)))
(array([94.], dtype=float32),)
>>> jdl.manual_seed(2)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True)))
(array([74.], dtype=float32),)
>>> jdl.manual_seed(42)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True)))
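(array([83.], dtype=float32),)
Looking into the source code, it seems that at least the DataLoaderJAX backend simply ignores the generator argument passed through **kwargs. Every key above yields the same first element (83.), which is also what jdl.manual_seed(42) produces. I'd be happy to submit a PR if you intend to fix this :)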
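For reference, here is a minimal sketch of the behaviour I would expect when a key is passed as generator. KeyedShuffleLoader is a hypothetical class written only for illustration; it is not part of jax_dataloader and is not how DataLoaderJAX is actually implemented. It assumes the key is used to draw a permutation with jax.random.permutation and is split on each epoch so that successive epochs are shuffled differently.

```python
import jax.numpy as jnp
import jax.random as jr


class KeyedShuffleLoader:
    """Hypothetical minimal loader (NOT jax_dataloader's DataLoaderJAX).

    Illustrates honouring a per-loader `generator` key instead of
    silently discarding it from **kwargs.
    """

    def __init__(self, xs, batch_size=1, shuffle=True, generator=None, **kwargs):
        self.xs = xs
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Assumption: fall back to a fixed key when no generator is given;
        # a real library would presumably use its global seed here instead.
        self.key = generator if generator is not None else jr.key(0)

    def __iter__(self):
        n = self.xs.shape[0]
        if self.shuffle:
            # Split the stored key so each new epoch draws a fresh permutation
            # while the whole sequence stays reproducible for a given key.
            self.key, subkey = jr.split(self.key)
            indices = jr.permutation(subkey, n)
        else:
            indices = jnp.arange(n)
        for start in range(0, n, self.batch_size):
            # Yield tuples to mirror the (array,) batches in the REPL output.
            yield (self.xs[indices[start:start + self.batch_size]],)


xs = jnp.arange(100.)
# Different keys should now (almost surely) give different first batches.
print(next(iter(KeyedShuffleLoader(xs, generator=jr.key(0)))))
print(next(iter(KeyedShuffleLoader(xs, generator=jr.key(1)))))
```

Splitting the key inside __iter__ keeps the loader reproducible for a fixed key while avoiding the same order on every epoch.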