Skip to content

Commit 4a514ab

Browse files
authored
Merge pull request #40 from BirkhoffG/generator
Implement Generator for better controlling the randomness
2 parents fedb395 + a043bdb commit 4a514ab

File tree

6 files changed

+208
-15
lines changed

6 files changed

+208
-15
lines changed

.github/workflows/deploy.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
- uses: actions/checkout@v3
1212
- uses: actions/setup-python@v4
1313
with:
14-
python-version: '3.9'
14+
python-version: '3.10'
1515
cache: "pip"
1616
cache-dependency-path: settings.ini
1717
- name: Install Dependencies

.github/workflows/nbdev.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ jobs:
101101
runs-on: ubuntu-latest
102102
strategy:
103103
matrix:
104-
py: ['3.9', '3.10', '3.11', '3.12']
104+
py: ['3.10', '3.11', '3.12']
105105
steps:
106106
- name: Checkout Code
107107
uses: actions/checkout@v4
@@ -117,7 +117,7 @@ jobs:
117117
run: |
118118
pip install --upgrade pip
119119
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
120-
pip install -e .[dev]
120+
pip install -e .[dev] -U
121121
122122
- name: Run Tests
123123
run: nbdev_test

jax_dataloader/_modidx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,16 @@
118118
'jax_dataloader/tests.py')},
119119
'jax_dataloader.utils': { 'jax_dataloader.utils.Config': ('utils.html#config', 'jax_dataloader/utils.py'),
120120
'jax_dataloader.utils.Config.default': ('utils.html#config.default', 'jax_dataloader/utils.py'),
121+
'jax_dataloader.utils.Generator': ('utils.html#generator', 'jax_dataloader/utils.py'),
122+
'jax_dataloader.utils.Generator.__init__': ( 'utils.html#generator.__init__',
123+
'jax_dataloader/utils.py'),
124+
'jax_dataloader.utils.Generator.jax_generator': ( 'utils.html#generator.jax_generator',
125+
'jax_dataloader/utils.py'),
126+
'jax_dataloader.utils.Generator.manual_seed': ( 'utils.html#generator.manual_seed',
127+
'jax_dataloader/utils.py'),
128+
'jax_dataloader.utils.Generator.seed': ('utils.html#generator.seed', 'jax_dataloader/utils.py'),
129+
'jax_dataloader.utils.Generator.torch_generator': ( 'utils.html#generator.torch_generator',
130+
'jax_dataloader/utils.py'),
121131
'jax_dataloader.utils.asnumpy': ('utils.html#asnumpy', 'jax_dataloader/utils.py'),
122132
'jax_dataloader.utils.check_hf_installed': ( 'utils.html#check_hf_installed',
123133
'jax_dataloader/utils.py'),

jax_dataloader/utils.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
# %% auto 0
1010
__all__ = ['Config', 'get_config', 'manual_seed', 'check_pytorch_installed', 'has_pytorch_tensor', 'check_hf_installed',
11-
'check_tf_installed', 'asnumpy']
11+
'check_tf_installed', 'Generator', 'asnumpy']
1212

13-
# %% ../nbs/utils.ipynb 6
13+
# %% ../nbs/utils.ipynb 7
1414
@dataclass
1515
class Config:
1616
"""Global configuration for the library"""
@@ -21,27 +21,27 @@ class Config:
2121
def default(cls) -> Config:
2222
return cls(rng_reserve_size=1, global_seed=42)
2323

24-
# %% ../nbs/utils.ipynb 7
24+
# %% ../nbs/utils.ipynb 8
2525
main_config = Config.default()
2626

27-
# %% ../nbs/utils.ipynb 8
27+
# %% ../nbs/utils.ipynb 9
2828
def get_config() -> Config:
2929
return main_config
3030

31-
# %% ../nbs/utils.ipynb 9
31+
# %% ../nbs/utils.ipynb 10
3232
def manual_seed(seed: int):
3333
"""Set the seed for the library"""
3434
main_config.global_seed = seed
3535

36-
# %% ../nbs/utils.ipynb 12
36+
# %% ../nbs/utils.ipynb 13
3737
def check_pytorch_installed():
3838
if torch_data is None:
3939
raise ModuleNotFoundError("`pytorch` library needs to be installed. "
4040
"Try `pip install torch`. Please refer to pytorch documentation for details: "
4141
"https://pytorch.org/get-started/.")
4242

4343

44-
# %% ../nbs/utils.ipynb 14
44+
# %% ../nbs/utils.ipynb 15
4545
def has_pytorch_tensor(batch) -> bool:
4646
if isinstance(batch[0], torch.Tensor):
4747
return True
@@ -51,21 +51,75 @@ def has_pytorch_tensor(batch) -> bool:
5151
else:
5252
return False
5353

54-
# %% ../nbs/utils.ipynb 15
54+
# %% ../nbs/utils.ipynb 16
5555
def check_hf_installed():
5656
if hf_datasets is None:
5757
raise ModuleNotFoundError("`datasets` library needs to be installed. "
5858
"Try `pip install datasets`. Please refer to huggingface documentation for details: "
5959
"https://huggingface.co/docs/datasets/installation.html.")
6060

61-
# %% ../nbs/utils.ipynb 17
61+
# %% ../nbs/utils.ipynb 18
6262
def check_tf_installed():
6363
if tf is None:
6464
raise ModuleNotFoundError("`tensorflow` library needs to be installed. "
6565
"Try `pip install tensorflow`. Please refer to tensorflow documentation for details: "
6666
"https://www.tensorflow.org/install/pip.")
6767

68-
# %% ../nbs/utils.ipynb 20
68+
# %% ../nbs/utils.ipynb 21
69+
class Generator:
70+
def __init__(
71+
self,
72+
*,
73+
generator: jrand.Array | torch.Generator = None,
74+
):
75+
self._seed = None
76+
self._jax_generator = None
77+
self._torch_generator = None
78+
79+
if generator is None:
80+
self._seed = get_config().global_seed
81+
elif isinstance(generator, jax.Array):
82+
self._jax_generator = generator
83+
elif isinstance(generator, torch.Generator):
84+
self._torch_generator = generator
85+
else:
86+
raise ValueError(f"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.")
87+
88+
if self._seed is None and self._torch_generator is not None:
89+
self._seed = self._torch_generator.initial_seed()
90+
91+
def seed(self) -> int:
92+
"""The initial seed of the generator"""
93+
if self._seed is None:
94+
raise ValueError("The seed is not specified. Please set the seed using `manual_seed()` or pass a generator.")
95+
return self._seed
96+
97+
def manual_seed(self, seed: int) -> Generator:
98+
"""Set the seed for the generator. This will override the initial seed and the generator."""
99+
100+
if self._jax_generator is not None:
101+
self._jax_generator = jrand.PRNGKey(seed)
102+
if self._torch_generator is not None:
103+
self._torch_generator = torch.Generator().manual_seed(seed)
104+
self._seed = seed
105+
return self
106+
107+
def jax_generator(self) -> jax.Array:
108+
"""The JAX generator"""
109+
if self._jax_generator is None:
110+
self._jax_generator = jrand.PRNGKey(self._seed)
111+
return self._jax_generator
112+
113+
def torch_generator(self) -> torch.Generator:
114+
"""The PyTorch generator"""
115+
check_pytorch_installed()
116+
if self._torch_generator is None and self._seed is not None:
117+
self._torch_generator = torch.Generator().manual_seed(self._seed)
118+
if self._torch_generator is None:
119+
raise ValueError("Neither pytorch generator or seed is specified.")
120+
return self._torch_generator
121+
122+
# %% ../nbs/utils.ipynb 26
69123
def asnumpy(x) -> np.ndarray:
70124
if isinstance(x, np.ndarray):
71125
return x

nbs/utils.ipynb

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@
5353
"import collections"
5454
]
5555
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"#| hide\n",
63+
"from fastcore.test import test_fail"
64+
]
65+
},
5666
{
5767
"cell_type": "markdown",
5868
"metadata": {},
@@ -217,6 +227,125 @@
217227
"check_tf_installed()"
218228
]
219229
},
230+
{
231+
"cell_type": "markdown",
232+
"metadata": {},
233+
"source": [
234+
"## Seed Generator"
235+
]
236+
},
237+
{
238+
"cell_type": "code",
239+
"execution_count": null,
240+
"metadata": {},
241+
"outputs": [],
242+
"source": [
243+
"#| export\n",
244+
"class Generator:\n",
245+
" def __init__(\n",
246+
" self, \n",
247+
" *, \n",
248+
" generator: jrand.Array | torch.Generator = None,\n",
249+
" ):\n",
250+
" self._seed = None\n",
251+
" self._jax_generator = None\n",
252+
" self._torch_generator = None\n",
253+
"\n",
254+
" if generator is None:\n",
255+
" self._seed = get_config().global_seed\n",
256+
" elif isinstance(generator, jax.Array):\n",
257+
" self._jax_generator = generator\n",
258+
" elif isinstance(generator, torch.Generator):\n",
259+
" self._torch_generator = generator\n",
260+
" else:\n",
261+
" raise ValueError(f\"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.\")\n",
262+
" \n",
263+
" if self._seed is None and self._torch_generator is not None:\n",
264+
" self._seed = self._torch_generator.initial_seed()\n",
265+
"\n",
266+
" def seed(self) -> int:\n",
267+
" \"\"\"The initial seed of the generator\"\"\"\n",
268+
" if self._seed is None:\n",
269+
" raise ValueError(\"The seed is not specified. Please set the seed using `manual_seed()` or pass a generator.\")\n",
270+
" return self._seed\n",
271+
" \n",
272+
" def manual_seed(self, seed: int) -> Generator:\n",
273+
" \"\"\"Set the seed for the generator. This will override the initial seed and the generator.\"\"\"\n",
274+
" \n",
275+
" if self._jax_generator is not None:\n",
276+
" self._jax_generator = jrand.PRNGKey(seed)\n",
277+
" if self._torch_generator is not None:\n",
278+
" self._torch_generator = torch.Generator().manual_seed(seed)\n",
279+
" self._seed = seed\n",
280+
" return self\n",
281+
" \n",
282+
" def jax_generator(self) -> jax.Array:\n",
283+
" \"\"\"The JAX generator\"\"\"\n",
284+
" if self._jax_generator is None:\n",
285+
" self._jax_generator = jrand.PRNGKey(self._seed)\n",
286+
" return self._jax_generator\n",
287+
" \n",
288+
" def torch_generator(self) -> torch.Generator:\n",
289+
" \"\"\"The PyTorch generator\"\"\"\n",
290+
" check_pytorch_installed()\n",
291+
" if self._torch_generator is None and self._seed is not None:\n",
292+
" self._torch_generator = torch.Generator().manual_seed(self._seed)\n",
293+
" if self._torch_generator is None:\n",
294+
" raise ValueError(\"Neither pytorch generator or seed is specified.\")\n",
295+
" return self._torch_generator"
296+
]
297+
},
298+
{
299+
"cell_type": "code",
300+
"execution_count": null,
301+
"metadata": {},
302+
"outputs": [],
303+
"source": [
304+
"# Example of using the generator\n",
305+
"g = Generator()\n",
306+
"assert g.seed() == get_config().global_seed\n",
307+
"assert jnp.array_equal(g.jax_generator(), jax.random.PRNGKey(get_config().global_seed)) \n",
308+
"assert g.torch_generator().initial_seed() == get_config().global_seed\n",
309+
"\n",
310+
"# Examples of using the generator when passing a `jax.random.PRNGKey` or `torch.Generator`\n",
311+
"g_jax = Generator(generator=jax.random.PRNGKey(123))\n",
312+
"assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(123))\n",
313+
"\n",
314+
"g_torch = Generator(generator=torch.Generator().manual_seed(123))\n",
315+
"assert g_torch.torch_generator().initial_seed() == 123\n",
316+
"assert g_torch.seed() == 123\n",
317+
"assert jnp.array_equal(g_torch.jax_generator(), jax.random.PRNGKey(123))"
318+
]
319+
},
320+
{
321+
"cell_type": "code",
322+
"execution_count": null,
323+
"metadata": {},
324+
"outputs": [],
325+
"source": [
326+
"#| hide\n",
327+
"test_fail(g_jax.seed, contains='The seed is not specified')\n",
328+
"test_fail(g_jax.torch_generator, contains='Neither pytorch generator or seed is specified')"
329+
]
330+
},
331+
{
332+
"cell_type": "code",
333+
"execution_count": null,
334+
"metadata": {},
335+
"outputs": [],
336+
"source": [
337+
"# Example of using `manual_seed` to set the seed\n",
338+
"g_jax.manual_seed(456)\n",
339+
"assert g_jax.seed() == 456\n",
340+
"assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(456))\n",
341+
"assert g_jax.torch_generator().initial_seed() == 456\n",
342+
"\n",
343+
"g_torch.manual_seed(789)\n",
344+
"assert g_torch.seed() == 789\n",
345+
"assert g_torch.torch_generator().initial_seed() == 789\n",
346+
"assert jnp.array_equal(g_torch.jax_generator(), jax.random.PRNGKey(789))"
347+
]
348+
},
220349
{
221350
"cell_type": "markdown",
222351
"metadata": {},

settings.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ keywords = python jax dataloader pytorch tensorflow datasets huggingface
2525
language = English
2626
status = 3
2727
user = birkhoffg
28-
requirements = jax[cpu] plum-dispatch
29-
dev_requirements = scikit-learn pandas nbdev jupyter dm-haiku optax nbdev-mkdocs flax
28+
requirements = jax plum-dispatch
29+
dev_requirements = scikit-learn pandas nbdev jupyter dm-haiku optax flax
3030
torch_requirements = torch torchvision
3131
tensorflow_requirements = tensorflow tensorflow-datasets
3232
huggingface_requirements = datasets numpy<2.0.0

0 commit comments

Comments
 (0)