Can not set per-dataloader seed generator #50

@chnyutao

Hi,

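The documentation says that you can set a per-dataloader seed generator using JAX keys. But with jax_dataloader=0.1.3 (which seems to be the latest) I cannot actually do this. Please see the example below: every key passed as generator yields the same first batch, while jdl.manual_seed does change the result.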

>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import jax_dataloader as jdl
>>> xs = jnp.arange(100.)
>>> ds = jdl.ArrayDataset(xs, asnumpy=False)

>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True, generator=jr.key(0))))
(array([83.], dtype=float32),)

>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True, generator=jr.key(1))))
(array([83.], dtype=float32),)

>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True, generator=jr.key(2))))
(array([83.], dtype=float32),)

>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True, generator=jr.key(42))))
(array([83.], dtype=float32),)

>>> jdl.manual_seed(0)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True)))
(array([98.], dtype=float32),)

>>> jdl.manual_seed(1)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True)))
(array([94.], dtype=float32),)

>>> jdl.manual_seed(2)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True)))
(array([74.], dtype=float32),)

>>> jdl.manual_seed(42)
>>> next(iter(jdl.DataLoader(ds, backend='jax', shuffle=True)))
(array([83.], dtype=float32),)

Looking into the source code, it seems that at least the DataLoaderJAX backend simply ignores the generator argument passed in **kwargs. I would be happy to submit a PR if you intend to fix this :)
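
To illustrate the kind of behaviour I suspect, here is a minimal sketch of how an argument swallowed by **kwargs can be silently dropped, and what honoring it might look like. The class names IgnoringLoader and HonoringLoader are hypothetical, and Python's stdlib random.Random stands in for a JAX key so the snippet runs without any dependencies. This is not jax_dataloader's actual code; with real JAX keys the shuffle would presumably go through something like jax.random.permutation(key, n).

```python
import random


class IgnoringLoader:
    """Hypothetical loader: `generator` ends up in **kwargs and is never read."""

    def __init__(self, data, shuffle=False, **kwargs):
        self.data = list(data)
        self.shuffle = shuffle
        # Assumption for illustration: a fixed fallback seed is used, so every
        # loader shuffles the same way no matter what `generator` was passed.
        self._rng = random.Random(0)

    def __iter__(self):
        order = list(range(len(self.data)))
        if self.shuffle:
            self._rng.shuffle(order)
        return (self.data[i] for i in order)


class HonoringLoader(IgnoringLoader):
    """Hypothetical fix: name `generator` explicitly and seed the shuffle with it."""

    def __init__(self, data, shuffle=False, generator=None, **kwargs):
        super().__init__(data, shuffle=shuffle, **kwargs)
        if generator is not None:
            self._rng = random.Random(generator)


xs = range(100)

# The generator is dropped: different "keys" give the same order.
ignored_a = list(IgnoringLoader(xs, shuffle=True, generator=1))
ignored_b = list(IgnoringLoader(xs, shuffle=True, generator=2))
print(ignored_a == ignored_b)

# The generator is used: different "keys" give different orders,
# and the same "key" reproduces the same order.
honored_a = list(HonoringLoader(xs, shuffle=True, generator=1))
honored_b = list(HonoringLoader(xs, shuffle=True, generator=2))
honored_a_again = list(HonoringLoader(xs, shuffle=True, generator=1))
print(honored_a == honored_b, honored_a == honored_a_again)
```

The only difference between the two classes is whether generator is a named parameter that reaches the shuffling code, which is the change I would expect a fix to make in the JAX backend.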
