diff --git a/nbs/examples/random.ipynb b/nbs/examples/random.ipynb new file mode 100644 index 0000000..d43188a --- /dev/null +++ b/nbs/examples/random.ipynb @@ -0,0 +1,280 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6c46413e", + "metadata": {}, + "source": [ + "# Controlling Randomness\n", + "\n", + "`jax-dataloader` provides flexible mechanisms to manage the pseudo-random number generation used during data loading, which is essential for reproducibility, especially when shuffling data. This tutorial outlines the two primary ways to control randomness: \n", + "\n", + "* Setting a global seed \n", + "* Assigning specific seed generators to individual dataloaders.\n" + ] + }, + { + "cell_type": "markdown", + "id": "585e8833", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "Let's set up the necessary imports and a simple dataset for our examples:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12f9b11a", + "metadata": {}, + "outputs": [], + "source": [ + "import jax_dataloader as jdl\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import torch\n", + "\n", + "# Sample dataset\n", + "data = jnp.arange(20).reshape(10, 2)\n", + "labels = jnp.arange(10)\n", + "ds = jdl.ArrayDataset(data, labels)" + ] + }, + { + "cell_type": "markdown", + "id": "13e38853", + "metadata": {}, + "source": [ + "## Method 1: Setting the Global Seed\n", + "\n", + "The simplest way to control randomness across all `jax-dataloader` instances is by setting a global seed. This affects all dataloaders created after the seed is set, unless they have their own specific generator specified.\n", + "\n", + "Use the `jax_dataloader.manual_seed()` function:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29258bdd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataLoader 1 first batch: (array([[2, 3],\n", + " [4, 5]], dtype=int32), array([1, 2], dtype=int32))\n", + "DataLoader 2 first batch: (array([[2, 3],\n", + " [4, 5]], dtype=int32), array([1, 2], dtype=int32))\n" + ] + } + ], + "source": [ + "# Set the global seed for all subsequent dataloaders\n", + "jdl.manual_seed(1234)\n", + "\n", + "# Both dataloaders below will use the same underlying seed sequence\n", + "# resulting in identical shuffling order if other parameters are the same.\n", + "dl_1 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)\n", + "dl_2 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)\n", + "\n", + "# Iterate through dl_1 and dl_2 to observe the same order\n", + "print(\"DataLoader 1 first batch:\", next(iter(dl_1)))\n", + "print(\"DataLoader 2 first batch:\", next(iter(dl_2)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1818de3b", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "for (x1, y1), (x2, y2) in zip(dl_1, dl_2):\n", + " assert jnp.array_equal(x1, x2)\n", + " assert jnp.array_equal(y1, y2)" + ] + }, + { + "cell_type": "markdown", + "id": "b6ff73f7", + "metadata": {}, + "source": [ + "## Method 2: Setting Per-Dataloader Seed Generators\n", + "\n", + "For more fine-grained control, assign a specific seed generator to individual DataLoader instances using the generator argument. This overrides any global seed for that specific dataloader.\n", + "\n", + "jax-dataloader supports generators from `jax-dataloader`, `jax.random.PRNGKey`, and `torch.Generator`.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "a641cf61", + "metadata": {}, + "source": [ + "### 1. Using `jdl.Generator`\n", + "\n", + "Create and seed a `jdl.Generator` object and pass it to the `jdl.DataLoader`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a44862d1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataLoader with jdl.Generator first batch: (array([[ 6, 7],\n", + " [10, 11]], dtype=int32), array([3, 5], dtype=int32))\n" + ] + } + ], + "source": [ + "# Create a specific generator with its own seed\n", + "g1 = jdl.Generator().manual_seed(4321)\n", + "\n", + "# This dataloader will use g1, overriding any global seed\n", + "dl_jdl_gen = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=g1)\n", + "\n", + "print(\"DataLoader with jdl.Generator first batch:\", next(iter(dl_jdl_gen)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0e5849e", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "for (x1, y1), (x2, y2) in zip(dl_1, dl_jdl_gen):\n", + " assert not jnp.array_equal(x1, x2)\n", + " assert not jnp.array_equal(y1, y2)" + ] + }, + { + "cell_type": "markdown", + "id": "e1e101fc", + "metadata": {}, + "source": [ + "### 2. Using `jax.random.PRNGKey`\n", + "\n", + "Directly use a `jax.random.PRNGKey` as the generator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9989302", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataLoader with JAX PRNGKey first batch: (array([[ 6, 7],\n", + " [10, 11]], dtype=int32), array([3, 5], dtype=int32))\n" + ] + } + ], + "source": [ + "# Create a JAX PRNGKey\n", + "key = jax.random.PRNGKey(4321)\n", + "\n", + "# This dataloader will use the JAX key, overriding any global seed\n", + "# jax-dataloader handles the key internally for reproducible iteration.\n", + "dl_jax_key = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=key)\n", + "\n", + "print(\"DataLoader with JAX PRNGKey first batch:\", next(iter(dl_jax_key)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9f11077", + "metadata": {}, + "outputs": [], + "source": [ + "# | hide\n", + "for (x1, y1), (x2, y2) in zip(dl_1, dl_jax_key):\n", + " assert not jnp.array_equal(x1, x2)\n", + " assert not jnp.array_equal(y1, y2)" + ] + }, + { + "cell_type": "markdown", + "id": "c8727c96", + "metadata": {}, + "source": [ + "### 3. Using torch.Generator\n", + "\n", + "When using the `'torch'` backend, you can use a `torch.Generator`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14ce1fe2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataLoader with torch.Generator first batch: [array([[ 0, 1],\n", + " [14, 15]], dtype=int32), array([0, 7], dtype=int32)]\n" + ] + } + ], + "source": [ + "# Create a PyTorch generator\n", + "g3 = torch.Generator().manual_seed(5678)\n", + "\n", + "# This dataloader uses the 'torch' backend and the PyTorch generator\n", + "dl_torch_gen = jdl.DataLoader(ds, backend='pytorch', batch_size=2, shuffle=True, generator=g3)\n", + "\n", + "print(\"DataLoader with torch.Generator first batch:\", next(iter(dl_torch_gen)))" + ] + }, + { + "cell_type": "markdown", + "id": "ca21eac0", + "metadata": {}, + "source": [ + "## Trade-offs: Global Seed vs. Per-Dataloader Generators\n", + "\n", + "\n", + "Consider these trade-offs when deciding how to manage randomness.\n", + "\n", + "### Global Seed (`jdl.manual_seed()`)\n", + "\n", + "* Simplicity: Very easy to implement with one line for basic reproducibility.\n", + "* Implicit Consistency: Automatically ensures dataloaders created subsequently (without their own generator) share the same base randomness, useful for simple synchronization.\n", + "\n", + "\n", + "### Per-Dataloader Generator (`generator=...`)\n", + "\n", + "* Fine-grained Control: Allows independent and precise randomness management for each dataloader.\n", + "* Isolation: Prevents randomness in one dataloader from affecting others.\n", + "* Integration: Works naturally with JAX keys or PyTorch generators.\n", + "* Modularity: Better suited for complex applications or libraries where components need self-contained randomness.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml index 5dd98b3..0e29b47 100644 --- a/nbs/sidebar.yml +++ b/nbs/sidebar.yml @@ -2,6 +2,7 @@ website: sidebar: contents: - index.ipynb + - examples/random.ipynb - section: Examples contents: - examples/vit.torch.ipynb