diff --git a/docs/notebooks/host-offloading.ipynb b/docs/notebooks/host-offloading.ipynb index 756845be3f03..5620b507a281 100644 --- a/docs/notebooks/host-offloading.ipynb +++ b/docs/notebooks/host-offloading.ipynb @@ -1,888 +1,888 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "bQbS50fIdHw1" - }, - "source": [ - "(host-offloading)=\n", - "# JAX Memories and Host Offloading\n", - "\n", - "\n", - "\n", - "This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on:\n", - "\n", - "- Activation offloading\n", - "- Parameter offloading\n", - "- Optimizer state offloading\n", - "\n", - "By applying offloading strategies, developers can better manage memory resources and reduce memory pressure on devices. To implement these strategies effectively, understanding JAX's core mechanisms for data placement and movement is essential.\n", - "\n", - "## Building Blocks for Offloading\n", - "\n", - "JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. The following sections explore:\n", - "\n", - "- How to specify data distribution with sharding\n", - "- How to control memory placement between host and device\n", - "- How to manage data movement in jitted functions\n", - "\n", - "### NamedSharding and Memory Kinds\n", - "\n", - "{class}`~jax.sharding.NamedSharding` defines how data are distributed across devices. It includes:\n", - "\n", - "- Basic data distribution configuration\n", - "- `memory_kind` parameter for specifying memory type (`device` or `pinned_host`)\n", - "- By default, `memory_kind` is set to `device` memory\n", - "- `with_memory_kind` method for creating new sharding with modified memory type" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "bQbS50fIdHw1" + }, + "source": [ + "(host-offloading)=\n", + "# JAX Memories and Host Offloading\n", + "\n", + "\n", + "\n", + "This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on:\n", + "\n", + "- Activation offloading\n", + "- Parameter offloading\n", + "- Optimizer state offloading\n", + "\n", + "By applying offloading strategies, developers can better manage memory resources and reduce memory pressure on devices. To implement these strategies effectively, understanding JAX's core mechanisms for data placement and movement is essential.\n", + "\n", + "## Building Blocks for Offloading\n", + "\n", + "JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. The following sections explore:\n", + "\n", + "- How to specify data distribution with sharding\n", + "- How to control memory placement between host and device\n", + "- How to manage data movement in jitted functions\n", + "\n", + "### NamedSharding and Memory Kinds\n", + "\n", + "{class}`~jax.sharding.NamedSharding` defines how data are distributed across devices. 
It includes:\n", + "\n", + "- Basic data distribution configuration\n", + "- `memory_kind` parameter for specifying memory type (`device` or `pinned_host`)\n", + "- By default, `memory_kind` is set to `device` memory\n", + "- `with_memory_kind` method for creating new sharding with modified memory type" + ] }, - "id": "f-6sxUlqrlBn", - "outputId": "691a3df2-8341-44a9-a4a0-5521c2d891e3" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)\n", - "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", - "import numpy as np\n", - "\n", - "# Create mesh\n", - "# 1x1 mesh represents a single device with two named dimensions (x and y)\n", - "mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y'))\n", - "\n", - "# Device sharding - partitions data along x and y dimensions\n", - "s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind=\"device\")\n", - "\n", - "# Host sharding - same partitioning but in pinned host memory\n", - "s_host = s_dev.with_memory_kind('pinned_host')\n", - "\n", - "print(s_dev) # Shows device memory sharding\n", - "print(s_host) # Shows pinned host memory sharding" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R_pB9465VoMP" - }, - "source": [ - "### Data Placement with device_put\n", - "\n", - "{func}`jax.device_put` is a function that explicitly transfers arrays to a specified memory location according to a sharding specification." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "f-6sxUlqrlBn", + "outputId": "691a3df2-8341-44a9-a4a0-5521c2d891e3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)\n", + "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", + "import numpy as np\n", + "\n", + "# Create mesh\n", + "# 1x1 mesh represents a single device with two named dimensions (x and y)\n", + "mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y'))\n", + "\n", + "# Device sharding - partitions data along x and y dimensions\n", + "s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind=\"device\")\n", + "\n", + "# Host sharding - same partitioning but in pinned host memory\n", + "s_host = s_dev.with_memory_kind('pinned_host')\n", + "\n", + "print(s_dev) # Shows device memory sharding\n", + "print(s_host) # Shows pinned host memory sharding" + ] }, - "id": "OJFnf7FGp6Lj", - "outputId": "c762e1df-2453-4ed9-9d53-0defb6a05ce2" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "pinned_host\n", - "device\n" - ] - } - ], - "source": [ - "# Create a 2x4 array\n", - "arr = jnp.arange(8.0).reshape(2, 4)\n", - "\n", - "# Move arrays to different memory locations based on sharding objects\n", - "arr_host = jax.device_put(arr, s_host) # Places in pinned host memory\n", - "arr_dev = 
jax.device_put(arr, s_dev) # Places in device memory\n", - "\n", - "# Verify memory locations\n", - "print(arr_host.sharding.memory_kind) # Output: pinned_host\n", - "print(arr_dev.sharding.memory_kind) # Output: device" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HHXvBpQKTMCR" - }, - "source": [ - "### Output Sharding Controls\n", - "\n", - "Shardings determine how data is split across devices. JAX provides `out_shardings` to control how output arrays are partitioned when leaving a jitted function.\n", - "\n", - "Key Features:\n", - " - Can differ from input sharding\n", - " - Allows different memory kinds for outputs\n", - "\n", - "Examples:\n", - "\n", - "#### Device Output Sharding" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "R_pB9465VoMP" + }, + "source": [ + "### Data Placement with device_put\n", + "\n", + "{func}`jax.device_put` is a function that explicitly transfers arrays to a specified memory location according to a sharding specification." + ] }, - "id": "ZXNj9NUeaIdX", - "outputId": "399321ef-082a-4a77-c33a-9de3421f429b" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result value of H2D: \n", - " [[0. 1. 2. 3.]\n", - " [4. 5. 6. 7.]]\n" - ] - } - ], - "source": [ - "f = jax.jit(lambda x:x, out_shardings=s_dev)\n", - "out_dev = f(arr_host)\n", - "print(\"Result value of H2D: \\n\", out_dev)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iYXC5ix384XP" - }, - "source": [ - "Moving data from host to device memory when needed for computation is the essence of host offloading. Use {func}`jax.device_put` to perform this transfer in this example to optimize performance." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OJFnf7FGp6Lj", + "outputId": "c762e1df-2453-4ed9-9d53-0defb6a05ce2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pinned_host\n", + "device\n" + ] + } + ], + "source": [ + "# Create a 2x4 array\n", + "arr = jnp.arange(8.0).reshape(2, 4)\n", + "\n", + "# Move arrays to different memory locations based on sharding objects\n", + "arr_host = jax.device_put(arr, s_host) # Places in pinned host memory\n", + "arr_dev = jax.device_put(arr, s_dev) # Places in device memory\n", + "\n", + "# Verify memory locations\n", + "print(arr_host.sharding.memory_kind) # Output: pinned_host\n", + "print(arr_dev.sharding.memory_kind) # Output: device" + ] }, - "id": "cmM6tJTS84XQ", - "outputId": "40c353a1-fb55-44bc-bac9-dffc09852f49" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result value of H2D and add 1 in device memory: \n", - " [[1. 2. 3. 4.]\n", - " [5. 6. 7. 
8.]]\n" - ] - } - ], - "source": [ - "# Instead of the lambda function, add_func can be defined explicitly\n", - "# move data to device before computation\n", - "def add_func(x): # Move data to device and add one\n", - " x = jax.device_put(x, s_dev)\n", - " return x + 1\n", - "\n", - "f = jax.jit(add_func, out_shardings=s_dev)\n", - "out_dev = f(arr_host)\n", - "print(\"Result value of H2D and add 1 in device memory: \\n\", out_dev)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EbE-eBrJTBuS" - }, - "source": [ - "#### Host Output Sharding" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "HHXvBpQKTMCR" + }, + "source": [ + "### Output Sharding Controls\n", + "\n", + "Shardings determine how data is split across devices. JAX provides `out_shardings` to control how output arrays are partitioned when leaving a jitted function.\n", + "\n", + "Key Features:\n", + " - Can differ from input sharding\n", + " - Allows different memory kinds for outputs\n", + "\n", + "Examples:\n", + "\n", + "#### Device Output Sharding" + ] }, - "id": "FjZzkxI8ky4r", - "outputId": "2a1b6e7a-1c29-4347-c020-7b47c27a5cc3" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result value of D2H: \n", - " [[0. 1. 2. 3.]\n", - " [4. 5. 6. 7.]]\n" - ] - } - ], - "source": [ - "f = jax.jit(lambda x: x, out_shardings=s_host)\n", - "out_host = f(arr_dev) # Input arrays in the device memory while output arrays in the host memory\n", - "print(\"Result value of D2H: \\n\", out_host)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UhLVvRO2p6Lj" - }, - "source": [ - "## Activation Offloading\n", - "\n", - "Before diving into activation offloading, let's first take a look at the baseline code.\n", - "\n", - "This code implements a simple neural network with 10 layers, each consisting of two linear transformations. The code demonstrates basic memory usage patterns and provides a foundation for comparing offloading optimization techniques.\n", - "\n", - "Key components:\n", - "- Each layer consists of two sequential linear operations:\n", - " 1. First multiplication: `x @ w1`\n", - " 2. Second multiplication: `y @ w2`\n", - "- 10-layer network using JAX's scan operation\n", - "- Memory usage analysis\n", - "- Gradient computation with JIT compilation\n", - "\n", - "To analyze memory usage in JAX, the {func}`jax.stages.Compiled.memory_analysis` method can be used on a compiled function. This provides detailed statistics about memory consumption during computation. The key metrics include temporary memory size, argument size, output size, and alias size. To calculate the total memory usage, sum the temporary, argument, and output sizes, then subtract the alias size to avoid double-counting the same memory multiple times. This provides a summarized view of how the device memory is utilized across different aspects of the computation." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZXNj9NUeaIdX", + "outputId": "399321ef-082a-4a77-c33a-9de3421f429b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of H2D: \n", + " [[0. 1. 2. 3.]\n", + " [4. 5. 6. 
7.]]\n" + ] + } + ], + "source": [ + "f = jax.jit(lambda x:x, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(\"Result value of H2D: \\n\", out_dev)" + ] }, - "id": "UEt0dtxukkaz", - "outputId": "22bb32b7-8491-4100-f212-e56c50f44cfa" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Temp size: 17.25 MB\n", - "Argument size: 20.25 MB\n", - "Total size: 57.50 MB\n", - "Sample of results: [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]\n" - ] - } - ], - "source": [ - "# Initialize input and weights with small values (0.0001)\n", - "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256\n", - "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices\n", - "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices\n", - "\n", - "def two_layers(x, w):\n", - " # Simple two-layer linear transformation\n", - " w1, w2 = w\n", - " y = x @ w1\n", - " return y @ w2, None\n", - "\n", - "def scanned(w, x):\n", - " # Applies the layer function 10 times using JAX's scan operation\n", - " # Input: w (tuple of weight matrices), x (input matrix)\n", - " # Output: sum of the final layer's output\n", - " result = jax.lax.scan(two_layers, x, w)[0]\n", - " return jnp.sum(result)\n", - "\n", - "# Compile and compute gradients of the scanned function\n", - "f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation\n", - "\n", - "# Analyze memory usage\n", - "compiled_step = f.lower((w1, w2), input).compile()\n", - "compiled_stats = compiled_step.memory_analysis()\n", - "\n", - "if compiled_stats is not None:\n", - " # Calculate total memory usage including temporary storage, arguments, and outputs\n", - " # Subtract alias size to avoid double-counting memory shared between different components\n", - " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", - " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", - " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", - " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", - " print(f\"Total size: {total/(1024**2):.2f} MB\")\n", - "\n", - "# Execute the function and print sample results\n", - "result = f((w1, w2), input) # Execute the function with weights and input\n", - "print(\"Sample of results: \", result[0][0, 0, :5])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DnFyRt2nkkaz" - }, - "source": [ - "The detailed coverage of activation offloading can be found in the {ref}`gradient-checkpointing` tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation.\n", - "\n", - "To implement activation offloading effectively, it is important to understand checkpoint names and policies. Here's how they work in a simple example:\n", - "\n", - "### Checkpoint Names\n", - "\n", - "The {func}`checkpoint_name` function allows labeling activations for memory management during computation. Here's a simple example that a checkpoint name `x` is specified." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "sLO9ceS6p6Lj" - }, - "outputs": [], - "source": [ - "from jax.ad_checkpoint import checkpoint_name\n", - "\n", - "def layer_name(x, w):\n", - " w1, w2 = w\n", - " x = checkpoint_name(x, \"x\")\n", - " y = x @ w1\n", - " return y @ w2, None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-_T92oCOp6Lk" - }, - "source": [ - "The checkpoint name helps the system decide whether to:\n", - "* Keep the activation in device memory or\n", - "* Offload it to host memory during computation\n", - "\n", - "This pattern is common in neural networks, where multiple transformations are applied sequentially to input data.\n", - "\n", - "### Checkpoint Policies\n", - "\n", - "This checkpoint policy implements a memory management strategy that optimizes memory usage during computation. It manages memory by handling intermediate values through three strategies:\n", - "1. Recomputing during backward pass (default behavior)\n", - "2. Storing on device\n", - "3. Offloading to host memory after forward pass and loading back during backward pass" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "W8Usw_wOp6Lk" - }, - "outputs": [], - "source": [ - "from jax import checkpoint_policies as cp\n", - "\n", - "policy = cp.save_and_offload_only_these_names(\n", - " names_which_can_be_saved=[], # No values stored on device\n", - " names_which_can_be_offloaded=[\"x\"], # Offload activations labeled \"x\"\n", - " offload_src=\"device\", # Move from device memory\n", - " offload_dst=\"pinned_host\" # To pinned host memory\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iuDRCXu7ky4r" - }, - "source": [ - "{func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers). It can be integrated with JAX's rematerialization to process sequential data.\n", - "\n", - "Key components:\n", - "* {func}`jax.remat` creates a rematerialized version of the layer function using {func}`jax.remat` and applies the checkpoint policy to the layer function\n", - "* `prevent_cse=False` enables XLA's common subexpression elimination for better performance\n", - "* {func}`jax.lax.scan` iterates the rematerialized layer along an axis" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "iYXC5ix384XP" + }, + "source": [ + "Moving data from host to device memory when needed for computation is the essence of host offloading. Use {func}`jax.device_put` to perform this transfer in this example to optimize performance." 
+ ] }, - "id": "xCrxjTx_p6Lk", - "outputId": "13d46584-9b25-4622-b3c3-f50c1dac02c2" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Temp size: 6.50 MB\n", - "Argument size: 20.25 MB\n", - "Total size: 46.75 MB\n", - "Results match within tolerance: True\n", - "Sample of results: [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]\n" - ] - } - ], - "source": [ - "def scanned(w, x):\n", - " remat_layer = jax.remat(layer_name,\n", - " policy=policy, # Use our offloading policy\n", - " prevent_cse=False) # Allow CSE optimizations\n", - " result = jax.lax.scan(remat_layer, x, w)[0]\n", - " return jnp.sum(result)\n", - "\n", - "# Initialize input and weights with small values (0.0001)\n", - "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256\n", - "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices\n", - "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices\n", - "\n", - "# Compile and compute gradients of the scanned function\n", - "f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation\n", - "\n", - "# Analyze memory usage\n", - "compiled_step = f.lower((w1, w2), input).compile()\n", - "compiled_stats = compiled_step.memory_analysis()\n", - "\n", - "if compiled_stats is not None:\n", - " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", - " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", - " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", - " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", - " print(f\"Total size: {total/(1024**2):.2f} MB\")\n", - "\n", - "result_activation = f((w1, w2), input) # Execute the function with weights and input\n", - "# Verify numerical correctness\n", - "are_close = jnp.allclose(\n", - " result_activation[0], # Result from activation offloading only\n", - " result[0], # Result from both activation and parameter offloading\n", - " rtol=1e-5,\n", - " atol=1e-5\n", - ")\n", - "print(f\"Results match within tolerance: {are_close}\")\n", - "print(\"Sample of results: \", result_activation[0][0, 0, :5])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0tx7aara42pY" - }, - "source": [ - "Activation offloading reduces temporary memory usage from 17.25 MB to 6.5 MB while input and output argument sizes remain the same. Totally 10.75 MB is saved. It is achieved by offloading activation `x` to host memory after the forward pass and loading it back to device memory before the backward pass.\n", - "\n", - "### Summary of Activation Offloading\n", - "\n", - "Activation offloading provides a powerful way to manage memory in large computations by:\n", - "\n", - "* Using checkpoint names to mark specific activations\n", - "* Applying policies to control where and how activations are stored\n", - "* Supporting common JAX patterns like scan operations\n", - "* Moving selected activations to host memory when device memory is under budget\n", - "\n", - "This approach is particularly useful when working with large models that would otherwise exceed device memory capacity.\n", - "\n", - "## Parameter Offloading\n", - "\n", - "Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. 
This is achieved by using {func}`jax.jit` with a sharding strategy that specifies host memory kind.\n", - "\n", - "While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier.\n", - "\n", - "### Parameter Placement for Computation\n", - "\n", - "Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes.\n", - "\n", - "Note that the activation offloading implementation remains unchanged, using the same:\n", - "* Checkpoint name `\"x\"`\n", - "* Checkpoint policy\n", - "* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan`\n", - "\n", - "### Parameter Initialization with Host Offloading\n", - "\n", - "During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`, while keeping the `input` variable on the device." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cmM6tJTS84XQ", + "outputId": "40c353a1-fb55-44bc-bac9-dffc09852f49" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of H2D and add 1 in device memory: \n", + " [[1. 2. 3. 4.]\n", + " [5. 6. 7. 8.]]\n" + ] + } + ], + "source": [ + "# Instead of the lambda function, add_func can be defined explicitly\n", + "# move data to device before computation\n", + "def add_func(x): # Move data to device and add one\n", + " x = jax.device_put(x, s_dev)\n", + " return x + 1\n", + "\n", + "f = jax.jit(add_func, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(\"Result value of H2D and add 1 in device memory: \\n\", out_dev)" + ] }, - "id": "1qGN2hBQdheo", - "outputId": "48c09658-f8b6-4be3-ef0e-02e0e2566e10" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Temp size: 4.75 MB\n", - "Argument size: 0.25 MB\n", - "Total size: 25.00 MB\n", - "Results match within tolerance: True\n" - ] - } - ], - "source": [ - "# Hybrid version: Both activation and parameter offloading\n", - "def hybrid_layer(x, w):\n", - " # Move model parameters w1 and w2 to device memory via device_put\n", - " w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)\n", - " x = checkpoint_name(x, \"x\") # Offload activation x to host memory\n", - " y = x @ w1\n", - " return y @ w2, None\n", - "\n", - "def hybrid_scanned(w, x):\n", - " remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer\n", - " policy=policy, # Use offloading policy\n", - " prevent_cse=False) # Allow CSE optimizations\n", - " result = jax.lax.scan(remat_layer, x, w)[0]\n", - " return jnp.sum(result)\n", - "\n", - "# Move model parameters w1 and w2 to the host via device_put\n", - "# Initialize input and weights with small values (0.0001)\n", - "wh1 = jax.device_put(w1, s_host)\n", - "wh2 = jax.device_put(w2, s_host)\n", - "\n", - "# Compile and compute gradients of the scanned function\n", - "f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation\n", - "\n", - "# Analyze memory usage\n", - "compiled_step = f.lower((wh1, 
wh2), input).compile()\n", - "compiled_stats = compiled_step.memory_analysis()\n", - "\n", - "if compiled_stats is not None:\n", - " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", - " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", - " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", - " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", - " print(f\"Total size: {total / (1024**2):.2f} MB\")\n", - "\n", - "result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading\n", - "\n", - "# Verify numerical correctness\n", - "are_close = jnp.allclose(\n", - " result_activation[0], # Result from activation offloading only\n", - " result_both[0], # Result from both activation and parameter offloading\n", - " rtol=1e-5,\n", - " atol=1e-5\n", - ")\n", - "print(f\"Results match within tolerance: {are_close}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SVpozzwHflQk" - }, - "source": [ - "This implementation demonstrates how offloading model parameters together with activation offloading to host memory can significantly reduce device memory usage.\n", - "\n", - "### Memory Analysis\n", - "\n", - "**Baseline Memory Usage:**\n", - "- Input tensor: 0.25 MB (256 × 256 × 4 bytes)\n", - "- Model parameters (w1, w2): 10 MB each (256 × 1024 × 4 bytes ≈ 1 MB per layer × 10 layers)\n", - "\n", - "**Memory Usage Comparison:**\n", - "- Argument size without parameter offloading: 20.25 MB (0.25 + 10 + 10)\n", - "- Argument size with parameter offloading: 0.25 MB (only input remains)\n", - "- Temporary memory without activation offloading: 17.25 MB\n", - "- Temporary memory with activation offloading: 6.50 MB\n", - "- Temporary memory with activation and parameter offloading: 4.75 MB\n", - "\n", - "#### Key Optimizations\n", - "\n", - "1. **Parameter Offloading**: Moving parameters (w1, w2) to host memory reduces argument size by 20 MB (from 20.25 MB to 0.25 MB).\n", - "\n", - "2. **Activation Offloading**: Moving activations to host memory reduces temporary memory usage by 10.75 MB (from 17.25 to 6.50 MB).\n", - "\n", - "3. **Hybrid Strategy**: The rematerialization of activation offloading helps avoid keeping weights on the device and reduce temporary memory usage by 1.75 MB (from 6.50 MB to 4.75 MB). Without it, JAX would be eager to keep the on-device copies of the weights alive for the backward pass.\n", - "\n", - "#### Results\n", - "\n", - "**Total Memory Savings**: 33.5 MB (20 MB + 10.75 MB + 1.75 MB)\n", - "\n", - "This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness. \n", - "\n", - "### Limitations of Parameter Offloading\n", - "\n", - "{func}`jax.lax.scan` is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. While {func}`jax.lax.scan` allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates a `transpose` operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms.\n", - "\n", - "The offloading performance can vary for different device types. 
It may degrade performance due to memory transfers between host and device, so it's important to consider this trade-off when designing your optimization strategy.\n", - "\n", - "# Optimizer State Offloading\n", - "\n", - "Optimizer state offloading is a memory management technique that stores optimizer states in host memory instead of device memory. This approach is particularly useful when optimizer states are large, as it reduces device memory usage.\n", - "\n", - "A basic JAX implementation using the Adam optimizer can serve as a starting point, where all tensors are stored on the device. This will serve as a reference implementation before introducing optimizer state offloading.\n", - "\n", - "### Basic Implementation\n", - "\n", - "This section, let's implement a simple model with the Adam optimizer. This implementation helps establish the baseline behavior before exploring optimizer state offloading. It is particularly useful for understanding memory patterns in large-scale neural network training.\n", - "\n", - "In the code example below, a neural network training loop is included to use JAX and Optax's Adam optimizer. The network consists of four linear layers with GELU activation functions, processing large matrices of size 7168x7168. The training process involves:\n", - "- Forward pass: The input flows through four layers, each applying a linear transformation followed by GELU activation\n", - "- Loss computation: Calculates mean squared error between output and input, plus L2 regularization\n", - "- Backward pass: Computes gradients using automatic differentiation\n", - "- Optimization step: Updates parameters using Adam optimizer with gradient clipping\n", - "\n", - "The code uses JIT compilation to optimize performance and includes memory usage analysis to monitor the computational resources required during training. The memory analysis provides insights into temporary memory usage, argument sizes, and total memory consumption during the optimization step." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "EbE-eBrJTBuS" + }, + "source": [ + "#### Host Output Sharding" + ] }, - "id": "ujvC0YJ2VOyV", - "outputId": "d237ca0a-89ae-4e14-edd3-36cc38890349" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Temp size: 2.11 GB\n", - "Argument size: 2.49 GB\n", - "Total size: 4.59 GB\n" - ] + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FjZzkxI8ky4r", + "outputId": "2a1b6e7a-1c29-4347-c020-7b47c27a5cc3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of D2H: \n", + " [[0. 1. 2. 3.]\n", + " [4. 5. 6. 7.]]\n" + ] + } + ], + "source": [ + "f = jax.jit(lambda x: x, out_shardings=s_host)\n", + "out_host = f(arr_dev) # Input arrays in the device memory while output arrays in the host memory\n", + "print(\"Result value of D2H: \\n\", out_host)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UhLVvRO2p6Lj" + }, + "source": [ + "## Activation Offloading\n", + "\n", + "Before diving into activation offloading, let's first take a look at the baseline code.\n", + "\n", + "This code implements a simple neural network with 10 layers, each consisting of two linear transformations. 
The code demonstrates basic memory usage patterns and provides a foundation for comparing offloading optimization techniques.\n", + "\n", + "Key components:\n", + "- Each layer consists of two sequential linear operations:\n", + " 1. First multiplication: `x @ w1`\n", + " 2. Second multiplication: `y @ w2`\n", + "- 10-layer network using JAX's scan operation\n", + "- Memory usage analysis\n", + "- Gradient computation with JIT compilation\n", + "\n", + "To analyze memory usage in JAX, the {func}`jax.stages.Compiled.memory_analysis` method can be used on a compiled function. This provides detailed statistics about memory consumption during computation. The key metrics include temporary memory size, argument size, output size, and alias size. To calculate the total memory usage, sum the temporary, argument, and output sizes, then subtract the alias size to avoid double-counting the same memory multiple times. This provides a summarized view of how the device memory is utilized across different aspects of the computation." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UEt0dtxukkaz", + "outputId": "22bb32b7-8491-4100-f212-e56c50f44cfa" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 17.25 MB\n", + "Argument size: 20.25 MB\n", + "Total size: 57.50 MB\n", + "Sample of results: [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]\n" + ] + } + ], + "source": [ + "# Initialize input and weights with small values (0.0001)\n", + "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256\n", + "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices\n", + "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices\n", + "\n", + "def two_layers(x, w):\n", + " # Simple two-layer linear transformation\n", + " w1, w2 = w\n", + " y = x @ w1\n", + " return y @ w2, None\n", + "\n", + "def scanned(w, x):\n", + " # Applies the layer function 10 times using JAX's scan operation\n", + " # Input: w (tuple of weight matrices), x (input matrix)\n", + " # Output: sum of the final layer's output\n", + " result = jax.lax.scan(two_layers, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = f.lower((w1, w2), input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "\n", + "if compiled_stats is not None:\n", + " # Calculate total memory usage including temporary storage, arguments, and outputs\n", + " # Subtract alias size to avoid double-counting memory shared between different components\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Total size: {total/(1024**2):.2f} MB\")\n", + "\n", + "# Execute the function and print sample results\n", + "result = f((w1, w2), input) # Execute the function with weights and input\n", + "print(\"Sample of results: \", result[0][0, 0, :5])" + ] + }, + { + "cell_type": "markdown", + 
"metadata": { + "id": "DnFyRt2nkkaz" + }, + "source": [ + "The detailed coverage of activation offloading can be found in the {ref}`gradient-checkpointing` tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation.\n", + "\n", + "To implement activation offloading effectively, it is important to understand checkpoint names and policies. Here's how they work in a simple example:\n", + "\n", + "### Checkpoint Names\n", + "\n", + "The {func}`checkpoint_name` function allows labeling activations for memory management during computation. Here's a simple example that a checkpoint name `x` is specified." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "sLO9ceS6p6Lj" + }, + "outputs": [], + "source": [ + "from jax.ad_checkpoint import checkpoint_name\n", + "\n", + "def layer_name(x, w):\n", + " w1, w2 = w\n", + " x = checkpoint_name(x, \"x\")\n", + " y = x @ w1\n", + " return y @ w2, None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-_T92oCOp6Lk" + }, + "source": [ + "The checkpoint name helps the system decide whether to:\n", + "* Keep the activation in device memory or\n", + "* Offload it to host memory during computation\n", + "\n", + "This pattern is common in neural networks, where multiple transformations are applied sequentially to input data.\n", + "\n", + "### Checkpoint Policies\n", + "\n", + "This checkpoint policy implements a memory management strategy that optimizes memory usage during computation. It manages memory by handling intermediate values through three strategies:\n", + "1. Recomputing during backward pass (default behavior)\n", + "2. Storing on device\n", + "3. Offloading to host memory after forward pass and loading back during backward pass" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "W8Usw_wOp6Lk" + }, + "outputs": [], + "source": [ + "from jax import checkpoint_policies as cp\n", + "\n", + "policy = cp.save_and_offload_only_these_names(\n", + " names_which_can_be_saved=[], # No values stored on device\n", + " names_which_can_be_offloaded=[\"x\"], # Offload activations labeled \"x\"\n", + " offload_src=\"device\", # Move from device memory\n", + " offload_dst=\"pinned_host\" # To pinned host memory\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iuDRCXu7ky4r" + }, + "source": [ + "{func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers). 
It can be integrated with JAX's rematerialization to process sequential data.\n", + "\n", + "Key components:\n", + "* {func}`jax.remat` creates a rematerialized version of the layer function using {func}`jax.remat` and applies the checkpoint policy to the layer function\n", + "* `prevent_cse=False` enables XLA's common subexpression elimination for better performance\n", + "* {func}`jax.lax.scan` iterates the rematerialized layer along an axis" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xCrxjTx_p6Lk", + "outputId": "13d46584-9b25-4622-b3c3-f50c1dac02c2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 6.50 MB\n", + "Argument size: 20.25 MB\n", + "Total size: 46.75 MB\n", + "Results match within tolerance: True\n", + "Sample of results: [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]\n" + ] + } + ], + "source": [ + "def scanned(w, x):\n", + " remat_layer = jax.remat(layer_name,\n", + " policy=policy, # Use our offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", + "# Initialize input and weights with small values (0.0001)\n", + "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256\n", + "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices\n", + "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = f.lower((w1, w2), input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "\n", + "if compiled_stats is not None:\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Total size: {total/(1024**2):.2f} MB\")\n", + "\n", + "result_activation = f((w1, w2), input) # Execute the function with weights and input\n", + "# Verify numerical correctness\n", + "are_close = jnp.allclose(\n", + " result_activation[0], # Result from activation offloading only\n", + " result[0], # Result from both activation and parameter offloading\n", + " rtol=1e-5,\n", + " atol=1e-5\n", + ")\n", + "print(f\"Results match within tolerance: {are_close}\")\n", + "print(\"Sample of results: \", result_activation[0][0, 0, :5])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0tx7aara42pY" + }, + "source": [ + "Activation offloading reduces temporary memory usage from 17.25 MB to 6.5 MB while input and output argument sizes remain the same. Totally 10.75 MB is saved. 
It is achieved by offloading activation `x` to host memory after the forward pass and loading it back to device memory before the backward pass.\n", +        "\n", +        "### Summary of Activation Offloading\n", +        "\n", +        "Activation offloading provides a powerful way to manage memory in large computations by:\n", +        "\n", +        "* Using checkpoint names to mark specific activations\n", +        "* Applying policies to control where and how activations are stored\n", +        "* Supporting common JAX patterns like scan operations\n", +        "* Moving selected activations to host memory when device memory is under pressure\n", +        "\n", +        "This approach is particularly useful when working with large models that would otherwise exceed device memory capacity.\n", +        "\n", +        "## Parameter Offloading\n", +        "\n", +        "Model parameters (also known as weights) can be offloaded to host memory to optimize device memory usage during initialization. This is achieved by using {func}`jax.jit` with a sharding strategy that specifies host memory kind.\n", +        "\n", +        "While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier.\n", +        "\n", +        "### Parameter Placement for Computation\n", +        "\n", +        "Unlike the earlier `layer_name` function, {func}`jax.device_put` is applied to move the parameters `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes.\n", +        "\n", +        "Note that the activation offloading implementation remains unchanged, using the same:\n", +        "* Checkpoint name `\"x\"`\n", +        "* Checkpoint policy\n", +        "* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan`\n", +        "\n", +        "### Parameter Initialization with Host Offloading\n", +        "\n", +        "During initialization, the parameters `w1` and `w2` are placed in host memory before being passed to the {func}`jax.jit` function `f`, while keeping the `input` variable on the device." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1qGN2hBQdheo", + "outputId": "48c09658-f8b6-4be3-ef0e-02e0e2566e10" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 4.75 MB\n", + "Argument size: 0.25 MB\n", + "Total size: 25.00 MB\n", + "Results match within tolerance: True\n" + ] + } + ], + "source": [ + "# Hybrid version: Both activation and parameter offloading\n", + "def hybrid_layer(x, w):\n", + " # Move model parameters w1 and w2 to device memory via device_put\n", + " w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)\n", + " x = checkpoint_name(x, \"x\") # Offload activation x to host memory\n", + " y = x @ w1\n", + " return y @ w2, None\n", + "\n", + "def hybrid_scanned(w, x):\n", + " remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer\n", + " policy=policy, # Use offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", + "# Move model parameters w1 and w2 to the host via device_put\n", + "# Initialize input and weights with small values (0.0001)\n", + "wh1 = jax.device_put(w1, s_host)\n", + "wh2 = jax.device_put(w2, s_host)\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = f.lower((wh1, wh2), input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "\n", + "if compiled_stats is not None:\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Total size: {total / (1024**2):.2f} MB\")\n", + "\n", + "result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading\n", + "\n", + "# Verify numerical correctness\n", + "are_close = jnp.allclose(\n", + " result_activation[0], # Result from activation offloading only\n", + " result_both[0], # Result from both activation and parameter offloading\n", + " rtol=1e-5,\n", + " atol=1e-5\n", + ")\n", + "print(f\"Results match within tolerance: {are_close}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SVpozzwHflQk" + }, + "source": [ + "This implementation demonstrates how offloading model parameters together with activation offloading to host memory can significantly reduce device memory usage.\n", + "\n", + "### Memory Analysis\n", + "\n", + "**Baseline Memory Usage:**\n", + "- Input tensor: 0.25 MB (256 × 256 × 4 bytes)\n", + "- Model parameters (w1, w2): 10 MB each (256 × 1024 × 4 bytes ≈ 1 MB per layer × 10 layers)\n", + "\n", + "**Memory Usage Comparison:**\n", + "- Argument size without parameter offloading: 20.25 MB (0.25 + 10 + 10)\n", + "- Argument size with parameter offloading: 0.25 MB (only input remains)\n", + "- Temporary memory without activation offloading: 17.25 MB\n", + "- Temporary memory with activation offloading: 6.50 MB\n", + "- Temporary memory with activation and parameter offloading: 4.75 MB\n", + "\n", + "#### Key Optimizations\n", + "\n", + "1. 
**Parameter Offloading**: Moving parameters (w1, w2) to host memory reduces argument size by 20 MB (from 20.25 MB to 0.25 MB).\n", +        "\n", +        "2. **Activation Offloading**: Moving activations to host memory reduces temporary memory usage by 10.75 MB (from 17.25 to 6.50 MB).\n", +        "\n", +        "3. **Hybrid Strategy**: The rematerialization of activation offloading helps avoid keeping weights on the device and reduces temporary memory usage by 1.75 MB (from 6.50 MB to 4.75 MB). Without it, JAX would keep the on-device copies of the weights alive for the backward pass.\n", +        "\n", +        "#### Results\n", +        "\n", +        "**Total Memory Savings**: 32.5 MB (20 MB + 10.75 MB + 1.75 MB)\n", +        "\n", +        "This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness. \n", +        "\n", +        "### Limitations of Parameter Offloading\n", +        "\n", +        "{func}`jax.lax.scan` is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. While {func}`jax.lax.scan` allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates a `transpose` operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms.\n", +        "\n", +        "Offloading performance can also vary across device types. It may degrade performance because of the memory transfers between host and device, so it's important to consider this trade-off when designing your optimization strategy.\n", +        "\n", +        "## Optimizer State Offloading\n", +        "\n", +        "Optimizer state offloading is a memory management technique that stores optimizer states in host memory instead of device memory. This approach is particularly useful when optimizer states are large, as it reduces device memory usage.\n", +        "\n", +        "A basic JAX implementation using the Adam optimizer can serve as a starting point, where all tensors are stored on the device. This will serve as a reference implementation before introducing optimizer state offloading.\n", +        "\n", +        "### Basic Implementation\n", +        "\n", +        "In this section, let's implement a simple model with the Adam optimizer. This implementation helps establish the baseline behavior before exploring optimizer state offloading. It is particularly useful for understanding memory patterns in large-scale neural network training.\n", +        "\n", +        "In the code example below, a neural network training loop uses JAX and Optax's Adam optimizer. The network consists of four linear layers with GELU activation functions, processing large matrices of size 7168x7168. The training process involves:\n", +        "- Forward pass: The input flows through four layers, each applying a linear transformation followed by GELU activation\n", +        "- Loss computation: Calculates mean squared error between output and input, plus L2 regularization\n", +        "- Backward pass: Computes gradients using automatic differentiation\n", +        "- Optimization step: Updates parameters using the Adam optimizer with gradient clipping\n", +        "\n", +        "The code uses JIT compilation to optimize performance and includes memory usage analysis to monitor the computational resources required during training. 
The memory analysis provides insights into temporary memory usage, argument sizes, and total memory consumption during the optimization step." +      ] +    }, +    { +      "cell_type": "code", +      "execution_count": 12, +      "metadata": { +        "colab": { +          "base_uri": "https://localhost:8080/" +        }, +        "id": "ujvC0YJ2VOyV", +        "outputId": "d237ca0a-89ae-4e14-edd3-36cc38890349" +      }, +      "outputs": [ +        { +          "name": "stdout", +          "output_type": "stream", +          "text": [ +            "Temp size: 2.11 GB\n", +            "Argument size: 2.49 GB\n", +            "Total size: 4.59 GB\n" +          ] +        } +      ], +      "source": [ +        "import optax\n", +        "\n", +        "DIM = 7168\n", +        "\n", +        "# Initialize data and parameters w1, w2, w3 and w4\n", +        "input = jnp.ones((DIM, DIM))\n", +        "params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}\n", +        "\n", +        "# Initialize optimizer\n", +        "optimizer = optax.chain(\n", +        "    optax.clip_by_global_norm(1.0),\n", +        "    optax.adam(learning_rate=0.1)\n", +        ")\n", +        "opt_state = optimizer.init(params)\n", +        "\n", +        "def gelu(x):\n", +        "  return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))\n", +        "\n", +        "def single_layer(x, w):\n", +        "  return x @ w\n", +        "\n", +        "def forward(params, x):\n", +        "  for i in range(1, 5):\n", +        "    x = gelu(single_layer(x, params[f'w{i}']))\n", +        "  return x\n", +        "\n", +        "def compute_loss(params, inputs):\n", +        "  outputs = forward(params, inputs)\n", +        "  loss = jnp.mean((outputs - inputs) ** 2)\n", +        "  l2_reg = 0.001 * sum(jnp.sum(w ** 2) for w in jax.tree_util.tree_leaves(params))\n", +        "  return loss + l2_reg\n", +        "\n", +        "def step(params, opt_state, inputs):\n", +        "  grads = jax.grad(lambda p: compute_loss(p, inputs))(params)\n", +        "  updates, new_opt_state = optimizer.update(grads, opt_state, params)\n", +        "  return optax.apply_updates(params, updates), new_opt_state\n", +        "\n", +        "# JIT compile the step function with proper sharding\n", +        "step = jax.jit(step, donate_argnums=(0, 1))\n", +        "\n", +        "# Run an optimization step\n", +        "new_params, new_opt_state = step(params, opt_state, input)\n", +        "\n", +        "# Analyze memory usage\n", +        "compiled_step = step.lower(params, opt_state, input).compile()\n", +        "compiled_stats = compiled_step.memory_analysis()\n", +        "\n", +        "if compiled_stats is not None:\n", +        "  total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", +        "    + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", +        "  print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB\")\n", +        "  print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} GB\")\n", +        "  print(f\"Total size: {total / (1024**3):.2f} GB\")" +      ] +    }, +    { +      "cell_type": "markdown", +      "metadata": { +        "id": "oW4Qm6E5VOyV" +      }, +      "source": [ +        "Optimizer state offloading can be implemented as follows.\n", +        "\n", +        "### Setting Up Sharding and Memory Kinds\n", +        "\n", +        "{class}`jax.sharding.SingleDeviceSharding` is used to simplify the shardings for both device and host memory kinds. During the model state initialization, move the optimizer state to the host using {func}`jax.device_put`.\n", +        "\n", +        "### Model and Training Step Implementation\n", +        "\n", +        "Next, define the model architecture, loss function, and training step. 
The key addition here is moving the optimizer state to device memory via {func}`device_put` at the beginning of each training step, as it's needed for the parameter update on the device.\n", + "\n", + "### Running and Comparing Results\n", + "\n", + "After setting up the sharding, the optimizer state is moved to host memory and the step function is run with {func}`jax.jit`.\n", + "\n", + "The JIT compilation of the step function uses several important parameters:\n", + "- `donate_argnums=(0,)`: Indicates that the first argument (parameters) can be modified in-place, allowing JAX to reuse its memory\n", + "- `out_shardings`: Specifies how output tensors should be sharded across the mesh (devices and hosts)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fEDTasJZVOyW", + "outputId": "b36cedd6-cf30-4d36-f4fd-32b2fdfd7564" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 1.91 GB\n", + "Argument size: 0.96 MB\n", + "Total size: 2.87 GB\n" + ] + } + ], + "source": [ + "# Create sharding specifications for device and host memory\n", + "s_dev = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind=\"device\")\n", + "s_host = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind=\"pinned_host\")\n", + "\n", + "def step(params, opt_state, inputs):\n", + " grads = jax.grad(lambda p: compute_loss(p, inputs))(params)\n", + " opt_state = jax.device_put(opt_state, s_dev)\n", + " updates, new_opt_state = optimizer.update(grads, opt_state, params)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, new_opt_state\n", + "\n", + "params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}\n", + "opt_state = optimizer.init(params)\n", + "\n", + "# Initialize optimizer\n", + "optimizer = optax.chain(\n", + " optax.clip_by_global_norm(1.0),\n", + " optax.adam(learning_rate=0.1)\n", + ")\n", + "\n", + "# Optimizer state is placed on the host during initialization\n", + "opt_state = jax.device_put(opt_state, s_host)\n", + "\n", + "# JIT compile the step function with proper sharding and memory optimization\n", + "step = jax.jit(\n", + " step,\n", + " donate_argnums=(0,),\n", + " out_shardings=(s_dev, s_host)\n", + ")\n", + "\n", + "# Run an optimization step\n", + "new_params, offload_opt_state = step(params, opt_state, input)\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = step.lower(params, opt_state, input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "if compiled_stats is not None:\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} MB\")\n", + " print(f\"Total size: {total / (1024**3):.2f} GB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vKo8qYnQVOyW" + }, + "source": [ + "This implementation demonstrates how to:\n", + "1. Set up sharding specifications for `device` and `pinned_host`\n", + "2. Move optimizer states between host and device memory via {func}`jax.device_put`\n", + "3. Use `out_shardings` to ensure proper memory placement\n", + "4. 
Report the resulting memory usage\n",
+ "\n",
+ "It also shows how offloading the optimizer state to host memory reduces device memory usage through a trade-off between argument size and temporary memory.\n",
+ "\n",
+ "Memory Analysis:\n",
+ "1. Argument Size Reduction:\n",
+ "   - The optimizer states are arguments of the {func}`jax.jit` function\n",
+ "   - By offloading these states to host memory, the argument size on device is reduced\n",
+ "\n",
+ "2. Temporary Memory Impact:\n",
+ "   - Offloading increases temporary memory usage\n",
+ "   - This is because the updated optimizer state needs temporary device buffers before it is copied to the host\n",
+ "   - The live ranges of these temporary buffers are extended by the host-device transfers\n",
+ "\n",
+ "3. Latency Hiding Scheduling:\n",
+ "   - JAX uses XLA's latency hiding scheduling to overlap computation with host-device transfers\n",
+ "   - The overlapping can cause tensors to have larger live ranges, which increases memory pressure on the device\n",
+ "   - This behavior helps maintain stable memory usage while still providing some performance benefits\n",
+ "\n",
+ "4. Memory Trade-off:\n",
+ "   - Total memory size with offloading: 2.87 GB\n",
+ "   - Total memory size without offloading: 4.59 GB\n",
+ "   - Net memory saving: 1.72 GB\n",
+ "\n",
+ "While offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage.\n",
+ "\n",
+ "Note: The optimizer states can be compared for numerical equivalence using `jax.tree_util.tree_map` and `jnp.allclose`, but this verification step is omitted here for brevity.\n",
+ "\n",
+ "## Tools for Host Offloading\n",
+ "\n",
+ "{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to {doc}`../device_memory_profiling`. The profiling tools described in {doc}`../profiling` can help measure memory savings and performance impact from host offloading.\n",
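+ "\n",
+ "As a rough illustration (a sketch, not executed as part of this notebook; the output filename is arbitrary), a device memory profile could be captured right after the offloaded step and then inspected with the tools described in the device memory profiling guide:\n",
+ "\n",
+ "```python\n",
+ "# Sketch: save the current device memory profile for offline inspection (e.g. with pprof)\n",
+ "import jax.profiler\n",
+ "\n",
+ "jax.profiler.save_device_memory_profile(\"memory_offload.prof\")\n",
+ "```"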
+ ] } - ], - "source": [ - "import optax\n", - "\n", - "DIM = 7168\n", - "\n", - "# Initialize data and parameter w1, w2, w3 and w4\n", - "input = jnp.ones((DIM, DIM))\n", - "params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}\n", - "\n", - "# Initialize optimizer\n", - "optimizer = optax.chain(\n", - " optax.clip_by_global_norm(1.0),\n", - " optax.adam(learning_rate=0.1)\n", - ")\n", - "opt_state = optimizer.init(params)\n", - "\n", - "def gelu(x):\n", - " return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))\n", - "\n", - "def single_layer(x, w):\n", - " return x @ w\n", - "\n", - "def forward(params, x):\n", - " for i in range(1, 5):\n", - " x = gelu(single_layer(x, params[f'w{i}']))\n", - " return x\n", - "\n", - "def compute_loss(params, inputs):\n", - " outputs = forward(params, inputs)\n", - " loss = jnp.mean((outputs - inputs) ** 2)\n", - " l2_reg = 0.001 * sum(jnp.sum(w ** 2) for w in jax.tree_util.tree_leaves(params))\n", - " return loss + l2_reg\n", - "\n", - "def step(params, opt_state, inputs):\n", - " grads = jax.grad(lambda p: compute_loss(p, inputs))(params)\n", - " updates, new_opt_state = optimizer.update(grads, opt_state, params)\n", - " return optax.apply_updates(params, updates), new_opt_state\n", - "\n", - "# JIT compile the step function with proper sharding\n", - "step = jax.jit(step, donate_argnums=(0, 1))\n", - "\n", - "# Run a optimization step\n", - "new_params, new_opt_state = step(params, opt_state, input)\n", - "\n", - "# Analyze memory usage\n", - "compiled_step = step.lower(params, opt_state, input).compile()\n", - "compiled_stats = compiled_step.memory_analysis()\n", - "\n", - "if compiled_stats is not None:\n", - " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", - " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", - " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB\")\n", - " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} GB\")\n", - " print(f\"Total size: {total / (1024**3):.2f} GB\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oW4Qm6E5VOyV" - }, - "source": [ - "Optimizer state offloading can be implemented as follows.\n", - "\n", - "### Setting Up Sharding and Memory Kinds\n", - "\n", - "{func}`jax.sharding.SingleDeivceSharding` is adopted to simplify the shardings for both device and host memory kinds. During the model state initialization, move the optimizer state to the host using {func}`device_put`.\n", - "\n", - "### Model and Training Step Implementation\n", - "\n", - "Next, define the model architecture, loss function, and training step. 
The key addition here is moving the optimizer state to device memory via {func}`device_put` at the beginning of each training step, as it's needed for the parameter update on the device.\n", - "\n", - "### Running and Comparing Results\n", - "\n", - "After setting up the sharding, the optimizer state is moved to host memory and the step function is run with {func}`jax.jit`.\n", - "\n", - "The JIT compilation of the step function uses several important parameters:\n", - "- `donate_argnums=(0,)`: Indicates that the first argument (parameters) can be modified in-place, allowing JAX to reuse its memory\n", - "- `out_shardings`: Specifies how output tensors should be sharded across the mesh (devices and hosts)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { + ], + "metadata": { + "accelerator": "GPU", "colab": { - "base_uri": "https://localhost:8080/" + "gpuType": "T4", + "provenance": [], + "toc_visible": true }, - "id": "fEDTasJZVOyW", - "outputId": "b36cedd6-cf30-4d36-f4fd-32b2fdfd7564" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Temp size: 1.91 GB\n", - "Argument size: 0.96 MB\n", - "Total size: 2.87 GB\n" - ] + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" } - ], - "source": [ - "# Create sharding specifications for device and host memory\n", - "s_dev = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind=\"device\")\n", - "s_host = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind=\"pinned_host\")\n", - "\n", - "def step(params, opt_state, inputs):\n", - " grads = jax.grad(lambda p: compute_loss(p, inputs))(params)\n", - " opt_state = jax.device_put(opt_state, s_dev)\n", - " updates, new_opt_state = optimizer.update(grads, opt_state, params)\n", - " new_params = optax.apply_updates(params, updates)\n", - " return new_params, new_opt_state\n", - "\n", - "params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}\n", - "opt_state = optimizer.init(params)\n", - "\n", - "# Initialize optimizer\n", - "optimizer = optax.chain(\n", - " optax.clip_by_global_norm(1.0),\n", - " optax.adam(learning_rate=0.1)\n", - ")\n", - "\n", - "# Optimizer state is placed on the host during initialization\n", - "opt_state = jax.device_put(opt_state, s_host)\n", - "\n", - "# JIT compile the step function with proper sharding and memory optimization\n", - "step = jax.jit(\n", - " step,\n", - " donate_argnums=(0,),\n", - " out_shardings=(s_dev, s_host)\n", - ")\n", - "\n", - "# Run an optimization step\n", - "new_params, offload_opt_state = step(params, opt_state, input)\n", - "\n", - "# Analyze memory usage\n", - "compiled_step = step.lower(params, opt_state, input).compile()\n", - "compiled_stats = compiled_step.memory_analysis()\n", - "if compiled_stats is not None:\n", - " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", - " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", - " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB\")\n", - " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} MB\")\n", - " print(f\"Total size: {total / (1024**3):.2f} GB\")" - ] - }, - { - 
"cell_type": "markdown", - "metadata": { - "id": "vKo8qYnQVOyW" - }, - "source": [ - "This implementation demonstrates how to:\n", - "1. Set up sharding specifications for `device` and `pinned_host`\n", - "2. Move optimizer states between host and device memory via {func}`jax.device_put`\n", - "3. Use `out_shardings` to ensure proper memory placement\n", - "4. Show the memory usage\n", - "\n", - "This implementation demonstrates how offloading optimizer state to host memory can reduce device memory usage through a trade-off between argument size and temporary memory.\n", - "\n", - "Memory Analysis:\n", - "1. Argument Size Reduction:\n", - " - The optimizer states are arguments of the {func}`jax.jit` function\n", - " - By offloading these states to host memory, the argument size on device is reduced\n", - "\n", - "2. Temporary Memory Impact:\n", - " - Offloading increases temporary memory usage\n", - " - This is because outputs of optimizer states need memory buffers before being copied to host\n", - " - The memory live ranges for these temporary buffers are extended due to the host-device transfers\n", - "\n", - "3. Latency Hiding Scheduling:\n", - " - JAX uses XLA's latency hiding scheduling to overlap computation with host-device transfers\n", - " - The overlapping can cause tensors to have larger live ranges, which increases memory pressure on the device\n", - " - This adaptive behavior helps maintain stable memory usage while still providing some performance benefits\n", - "\n", - "4. Memory Trade-off:\n", - " - Total memory size with offloading: 2.87 GB\n", - " - Total memory size without offloading: 4.59 GB\n", - " - Net memory saving: 1.72 GB\n", - "\n", - "while offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage. \n", - "\n", - "Note: The optimizer states can be compared for numerical equivalence using `jax.tree_util.tree_map` and `jnp.allclose`, but this verification step is omitted here for brevity.\n", - "\n", - "## Tools for Host Offloading\n", - "\n", - "{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to :doc:`device_memory_profiling`. The profiling tools described in {ref}`profiling` can help measure memory savings and performance impact from host offloading." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [], - "toc_visible": true - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/docs/notebooks/host-offloading.md b/docs/notebooks/host-offloading.md index f8855ecbb67f..db58c04b9014 100644 --- a/docs/notebooks/host-offloading.md +++ b/docs/notebooks/host-offloading.md @@ -637,4 +637,4 @@ Note: The optimizer states can be compared for numerical equivalence using `jax. ## Tools for Host Offloading -{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to :doc:`device_memory_profiling`. 
The profiling tools described in {ref}`profiling` can help measure memory savings and performance impact from host offloading. +{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to {doc}`../device_memory_profiling`. The profiling tools described in {doc}`../profiling` can help measure memory savings and performance impact from host offloading.