Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 280 additions & 0 deletions nbs/examples/random.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6c46413e",
"metadata": {},
"source": [
"# Controlling Randomness\n",
"\n",
"`jax-dataloader` provides flexible mechanisms to manage the pseudo-random number generation used during data loading, which is essential for reproducibility, especially when shuffling data. This tutorial outlines the two primary ways to control randomness: \n",
"\n",
"* Setting a global seed \n",
"* Assigning specific seed generators to individual dataloaders.\n"
]
},
{
"cell_type": "markdown",
"id": "585e8833",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"Let's set up the necessary imports and a simple dataset for our examples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12f9b11a",
"metadata": {},
"outputs": [],
"source": [
"import jax_dataloader as jdl\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import torch\n",
"\n",
"# Sample dataset\n",
"data = jnp.arange(20).reshape(10, 2)\n",
"labels = jnp.arange(10)\n",
"ds = jdl.ArrayDataset(data, labels)"
]
},
{
"cell_type": "markdown",
"id": "13e38853",
"metadata": {},
"source": [
"## Method 1: Setting the Global Seed\n",
"\n",
"The simplest way to control randomness across all `jax-dataloader` instances is by setting a global seed. This affects all dataloaders created after the seed is set, unless they have their own specific generator specified.\n",
"\n",
"Use the `jax_dataloader.manual_seed()` function:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29258bdd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataLoader 1 first batch: (array([[2, 3],\n",
" [4, 5]], dtype=int32), array([1, 2], dtype=int32))\n",
"DataLoader 2 first batch: (array([[2, 3],\n",
" [4, 5]], dtype=int32), array([1, 2], dtype=int32))\n"
]
}
],
"source": [
"# Set the global seed for all subsequent dataloaders\n",
"jdl.manual_seed(1234)\n",
"\n",
"# Both dataloaders below will use the same underlying seed sequence\n",
"# resulting in identical shuffling order if other parameters are the same.\n",
"dl_1 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)\n",
"dl_2 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)\n",
"\n",
"# Iterate through dl_1 and dl_2 to observe the same order\n",
"print(\"DataLoader 1 first batch:\", next(iter(dl_1)))\n",
"print(\"DataLoader 2 first batch:\", next(iter(dl_2)))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1818de3b",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"for (x1, y1), (x2, y2) in zip(dl_1, dl_2):\n",
" assert jnp.array_equal(x1, x2)\n",
" assert jnp.array_equal(y1, y2)"
]
},
{
"cell_type": "markdown",
"id": "b6ff73f7",
"metadata": {},
"source": [
"## Method 2: Setting Per-Dataloader Seed Generators\n",
"\n",
"For more fine-grained control, assign a specific seed generator to individual DataLoader instances using the generator argument. This overrides any global seed for that specific dataloader.\n",
"\n",
"jax-dataloader supports generators from `jax-dataloader`, `jax.random.PRNGKey`, and `torch.Generator`.\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "a641cf61",
"metadata": {},
"source": [
"### 1. Using `jdl.Generator`\n",
"\n",
"Create and seed a `jdl.Generator` object and pass it to the `jdl.DataLoader`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a44862d1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataLoader with jdl.Generator first batch: (array([[ 6, 7],\n",
" [10, 11]], dtype=int32), array([3, 5], dtype=int32))\n"
]
}
],
"source": [
"# Create a specific generator with its own seed\n",
"g1 = jdl.Generator().manual_seed(4321)\n",
"\n",
"# This dataloader will use g1, overriding any global seed\n",
"dl_jdl_gen = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=g1)\n",
"\n",
"print(\"DataLoader with jdl.Generator first batch:\", next(iter(dl_jdl_gen)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0e5849e",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"for (x1, y1), (x2, y2) in zip(dl_1, dl_jdl_gen):\n",
" assert not jnp.array_equal(x1, x2)\n",
" assert not jnp.array_equal(y1, y2)"
]
},
{
"cell_type": "markdown",
"id": "e1e101fc",
"metadata": {},
"source": [
"### 2. Using `jax.random.PRNGKey`\n",
"\n",
"Directly use a `jax.random.PRNGKey` as the generator."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9989302",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataLoader with JAX PRNGKey first batch: (array([[ 6, 7],\n",
" [10, 11]], dtype=int32), array([3, 5], dtype=int32))\n"
]
}
],
"source": [
"# Create a JAX PRNGKey\n",
"key = jax.random.PRNGKey(4321)\n",
"\n",
"# This dataloader will use the JAX key, overriding any global seed\n",
"# jax-dataloader handles the key internally for reproducible iteration.\n",
"dl_jax_key = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=key)\n",
"\n",
"print(\"DataLoader with JAX PRNGKey first batch:\", next(iter(dl_jax_key)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9f11077",
"metadata": {},
"outputs": [],
"source": [
"# | hide\n",
"for (x1, y1), (x2, y2) in zip(dl_1, dl_jax_key):\n",
" assert not jnp.array_equal(x1, x2)\n",
" assert not jnp.array_equal(y1, y2)"
]
},
{
"cell_type": "markdown",
"id": "c8727c96",
"metadata": {},
"source": [
"### 3. Using torch.Generator\n",
"\n",
"When using the `'torch'` backend, you can use a `torch.Generator`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14ce1fe2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataLoader with torch.Generator first batch: [array([[ 0, 1],\n",
" [14, 15]], dtype=int32), array([0, 7], dtype=int32)]\n"
]
}
],
"source": [
"# Create a PyTorch generator\n",
"g3 = torch.Generator().manual_seed(5678)\n",
"\n",
"# This dataloader uses the 'torch' backend and the PyTorch generator\n",
"dl_torch_gen = jdl.DataLoader(ds, backend='pytorch', batch_size=2, shuffle=True, generator=g3)\n",
"\n",
"print(\"DataLoader with torch.Generator first batch:\", next(iter(dl_torch_gen)))"
]
},
{
"cell_type": "markdown",
"id": "ca21eac0",
"metadata": {},
"source": [
"## Trade-offs: Global Seed vs. Per-Dataloader Generators\n",
"\n",
"\n",
"Consider these trade-offs when deciding how to manage randomness.\n",
"\n",
"### Global Seed (`jdl.manual_seed()`)\n",
"\n",
"* Simplicity: Very easy to implement with one line for basic reproducibility.\n",
"* Implicit Consistency: Automatically ensures dataloaders created subsequently (without their own generator) share the same base randomness, useful for simple synchronization.\n",
"\n",
"\n",
"### Per-Dataloader Generator (`generator=...`)\n",
"\n",
"* Fine-grained Control: Allows independent and precise randomness management for each dataloader.\n",
"* Isolation: Prevents randomness in one dataloader from affecting others.\n",
"* Integration: Works naturally with JAX keys or PyTorch generators.\n",
"* Modularity: Better suited for complex applications or libraries where components need self-contained randomness.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions nbs/sidebar.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ website:
sidebar:
contents:
- index.ipynb
- examples/random.ipynb
- section: Examples
contents:
- examples/vit.torch.ipynb
Expand Down