Use Generator to control the randomness for each backend dataloader#45
Pull Request Overview
This PR updates the randomness handling in each backend dataloader by introducing and standardizing the use of a Generator class. Key changes include:
- Updating the handling and initialization of random generators in notebooks and loader modules.
- Changing the return type of the seed() method to Optional[int] to account for generators that do not set a seed.
- Adding Generator import and handling in multiple files to streamline random seed control.
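
The `Optional[int]` return type mentioned above can be pictured with a minimal sketch. `SketchGenerator` is a hypothetical stand-in written for this illustration, not the `Generator` class from `jax_dataloader`. Only the method names `manual_seed()` and `seed()` appear in the PR; the constructor and the idea of wrapping an opaque key object are assumptions.

```python
from typing import Any, Optional


class SketchGenerator:
    """Hypothetical stand-in, not the jax_dataloader Generator class."""

    def __init__(self, key: Optional[Any] = None) -> None:
        # Assumption: the wrapper may be built from an existing key object
        # (for example a JAX PRNG key), in which case no integer seed is known.
        self._key = key
        self._seed: Optional[int] = None

    def manual_seed(self, seed: int) -> "SketchGenerator":
        # Record the integer seed so that seed() can report it later.
        self._seed = int(seed)
        return self

    def seed(self) -> Optional[int]:
        # Return None instead of raising when no integer seed was ever set.
        return self._seed


seeded = SketchGenerator().manual_seed(42)
from_key = SketchGenerator(key=object())  # placeholder for a backend key
```

Callers that need an integer can then check for `None` explicitly rather than catching an exception.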
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| nbs/utils.ipynb | Updated generator type checking and modified seed() return type. |
| nbs/loader.torch.ipynb | Adjusted generator initialization and ensured proper generator import. |
| nbs/loader.tf.ipynb | Introduced get_seed function and updated generator usage in Tf loader. |
| nbs/loader.jax.ipynb | Updated generator handling for JAX loader. |
| nbs/loader.base.ipynb | Added generator parameter import and handling. |
| jax_dataloader/utils.py | Updated seed() method return type and corrected generator logic. |
| jax_dataloader/loaders/torch.py | Modified generator initialization to rely on the Generator class. |
| jax_dataloader/loaders/tensorflow.py | Updated generator usage and added get_seed. |
| jax_dataloader/loaders/jax.py | Updated generator usage to set the JAX PRNGKey from the Generator. |
| jax_dataloader/loaders/base.py | Extended the signature to include the generator parameter. |
| jax_dataloader/_modidx.py | Documented the new get_seed function in the module index. |
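
One way to picture the per-backend seed handling summarized in the table: each loader asks the generator for an integer seed and falls back to a freshly drawn one when none is available. The helper below is only a sketch under that assumption. The names `get_seed_sketch` and `shuffled_indices` and the fallback behavior are hypothetical and are not taken from the PR's `get_seed` implementation.

```python
import random
from typing import Callable, List, Optional


def get_seed_sketch(seed_fn: Optional[Callable[[], Optional[int]]] = None) -> int:
    """Hypothetical helper: resolve an integer seed for a backend loader."""
    seed = seed_fn() if seed_fn is not None else None
    if seed is None:
        # Assumption: draw a fresh 32-bit seed when the generator has none.
        seed = random.SystemRandom().randrange(2**32)
    return seed


def shuffled_indices(n: int, seed: int) -> List[int]:
    # A local Random instance keeps the shuffle independent of global state.
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)
    return indices


seed = get_seed_sketch(lambda: 123)
first = shuffled_indices(10, seed)
second = shuffled_indices(10, seed)
```

With the same seed, both calls yield the same order, which is the property a shared generator is meant to give each backend.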
nbs/utils.ipynb
Outdated
    " \"\"\"The initial seed of the generator\"\"\"\n",
    " if self._seed is None:\n",
    " raise ValueError(\"The seed is not specified. Please set the seed using `manual_seed()` or pass a generator.\")\n",
    " # TODO: the seed might not be initizalized if the generator is a `jax.random.PRNGKey`\n",

The word 'initizalized' appears to be misspelled. Consider updating it to 'initialized'.
Suggested change:

    - " # TODO: the seed might not be initizalized if the generator is a `jax.random.PRNGKey`\n",
    + " # TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey`\n",

jax_dataloader/utils.py
Outdated
    """The initial seed of the generator"""
    if self._seed is None:
        raise ValueError("The seed is not specified. Please set the seed using `manual_seed()` or pass a generator.")
    # TODO: the seed might not be initizalized if the generator is a `jax.random.PRNGKey`

The word 'initizalized' should be corrected to 'initialized' for clarity.
Suggested change:

    - # TODO: the seed might not be initizalized if the generator is a `jax.random.PRNGKey`
    + # TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey`

Maybe consider tf generator as well.
The |
This fixes #43
This fixes #44