diff --git a/docs/tutorials/performance_debugging.ipynb b/docs/tutorials/performance_debugging.ipynb new file mode 100644 index 000000000..51682d0d1 --- /dev/null +++ b/docs/tutorials/performance_debugging.ipynb @@ -0,0 +1,1529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "wpQU648ycXYj" + }, + "source": [ + "# Set up input data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xd7C3J5YlGwr", + "outputId": "e1cdb916-925f-41e4-8ec7-56044ae09687" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: grain in /usr/local/lib/python3.12/dist-packages (0.2.14)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (4.57.1)\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.12/dist-packages (from grain) (1.4.0)\n", + "Requirement already satisfied: array-record>=0.8.1 in /usr/local/lib/python3.12/dist-packages (from grain) (0.8.2)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.12/dist-packages (from grain) (3.1.2)\n", + "Requirement already satisfied: etils[epath,epy] in /usr/local/lib/python3.12/dist-packages (from grain) (1.13.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from grain) (2.0.2)\n", + "Requirement already satisfied: protobuf>=5.28.3 in /usr/local/lib/python3.12/dist-packages (from grain) (6.33.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from transformers) (3.20.0)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.34.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.36.0)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (25.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.3)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2025.11.3)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers) (2.32.4)\n", + "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.1)\n", + "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.6.2)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers) (4.67.1)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (2025.10.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (4.15.0)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (1.2.0)\n", + "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.12/dist-packages (from etils[epath,epy]->grain) (6.5.2)\n", + "Requirement already satisfied: zipp in /usr/local/lib/python3.12/dist-packages (from etils[epath,epy]->grain) (3.23.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from 
requests->transformers) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (3.11)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (2025.10.5)\n" + ] + } + ], + "source": [ + "# Data transformations.\n", + "!pip install grain transformers" + ] + }, + { + "cell_type": "code", + "source": [ + "# We'll use data from TFDS for simplicity.\n", + "!pip install tensorflow-datasets[tf-nightly]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZTgJW67SyxeD", + "outputId": "e2cf9db9-928d-4b26-b0bf-3bd8b95b4dc3" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: tensorflow-datasets[tf-nightly] in /usr/local/lib/python3.12/dist-packages (4.9.9)\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (1.4.0)\n", + "Requirement already satisfied: array_record>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (0.8.2)\n", + "Requirement already satisfied: dm-tree in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (0.1.9)\n", + "Requirement already satisfied: etils>=1.9.1 in /usr/local/lib/python3.12/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.9.1; python_version >= \"3.11\"->tensorflow-datasets[tf-nightly]) (1.13.0)\n", + "Requirement already satisfied: immutabledict in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (4.2.2)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (2.0.2)\n", + "Requirement already satisfied: promise in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (2.3)\n", + "Requirement already satisfied: protobuf>=3.20 in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (6.33.0)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (5.9.5)\n", + "Requirement already satisfied: pyarrow in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (22.0.0)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (2.32.4)\n", + "Requirement already satisfied: simple_parsing in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (0.1.7)\n", + "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (1.17.2)\n", + "Requirement already satisfied: termcolor in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (3.2.0)\n", + "Requirement already satisfied: toml in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (0.10.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (4.67.1)\n", + "Requirement already satisfied: wrapt in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (2.0.1)\n", + "Requirement already satisfied: tf-nightly 
in /usr/local/lib/python3.12/dist-packages (from tensorflow-datasets[tf-nightly]) (2.21.0.dev20251117)\n", + "Requirement already satisfied: einops in /usr/local/lib/python3.12/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.9.1; python_version >= \"3.11\"->tensorflow-datasets[tf-nightly]) (0.8.1)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.9.1; python_version >= \"3.11\"->tensorflow-datasets[tf-nightly]) (2025.10.0)\n", + "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.12/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.9.1; python_version >= \"3.11\"->tensorflow-datasets[tf-nightly]) (6.5.2)\n", + "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.12/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.9.1; python_version >= \"3.11\"->tensorflow-datasets[tf-nightly]) (4.15.0)\n", + "Requirement already satisfied: zipp in /usr/local/lib/python3.12/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.9.1; python_version >= \"3.11\"->tensorflow-datasets[tf-nightly]) (3.23.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->tensorflow-datasets[tf-nightly]) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->tensorflow-datasets[tf-nightly]) (3.11)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->tensorflow-datasets[tf-nightly]) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->tensorflow-datasets[tf-nightly]) (2025.10.5)\n", + "Requirement already satisfied: attrs>=18.2.0 in /usr/local/lib/python3.12/dist-packages (from dm-tree->tensorflow-datasets[tf-nightly]) (25.4.0)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from promise->tensorflow-datasets[tf-nightly]) (1.17.0)\n", + "Requirement already satisfied: docstring-parser<1.0,>=0.15 in /usr/local/lib/python3.12/dist-packages (from simple_parsing->tensorflow-datasets[tf-nightly]) (0.17.0)\n", + "Requirement already satisfied: googleapis-common-protos<2,>=1.56.4 in /usr/local/lib/python3.12/dist-packages (from tensorflow-metadata->tensorflow-datasets[tf-nightly]) (1.72.0)\n", + "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (1.6.3)\n", + "Requirement already satisfied: flatbuffers>=25.9.23 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (25.9.23)\n", + "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (0.6.0)\n", + "Requirement already satisfied: google_pasta>=0.1.1 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (0.2.0)\n", + "Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (18.1.1)\n", + "Requirement already satisfied: opt_einsum>=2.3.2 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (3.4.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from 
tf-nightly->tensorflow-datasets[tf-nightly]) (25.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (75.2.0)\n", + "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (1.76.0)\n", + "Requirement already satisfied: tb-nightly~=2.20.0.a in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (2.20.0a20250717)\n", + "Requirement already satisfied: keras-nightly>=3.10.0.dev in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (3.12.0.dev2025100703)\n", + "Requirement already satisfied: h5py<3.15.0,>=3.11.0 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (3.14.0)\n", + "Requirement already satisfied: ml_dtypes<1.0.0,>=0.5.1 in /usr/local/lib/python3.12/dist-packages (from tf-nightly->tensorflow-datasets[tf-nightly]) (0.5.3)\n", + "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from astunparse>=1.6.0->tf-nightly->tensorflow-datasets[tf-nightly]) (0.45.1)\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.12/dist-packages (from keras-nightly>=3.10.0.dev->tf-nightly->tensorflow-datasets[tf-nightly]) (14.2.0)\n", + "Requirement already satisfied: namex in /usr/local/lib/python3.12/dist-packages (from keras-nightly>=3.10.0.dev->tf-nightly->tensorflow-datasets[tf-nightly]) (0.1.0)\n", + "Requirement already satisfied: optree in /usr/local/lib/python3.12/dist-packages (from keras-nightly>=3.10.0.dev->tf-nightly->tensorflow-datasets[tf-nightly]) (0.17.0)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/lib/python3/dist-packages (from tb-nightly~=2.20.0.a->tf-nightly->tensorflow-datasets[tf-nightly]) (3.3.6)\n", + "Requirement already satisfied: pillow in /usr/local/lib/python3.12/dist-packages (from tb-nightly~=2.20.0.a->tf-nightly->tensorflow-datasets[tf-nightly]) (12.0.0)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from tb-nightly~=2.20.0.a->tf-nightly->tensorflow-datasets[tf-nightly]) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from tb-nightly~=2.20.0.a->tf-nightly->tensorflow-datasets[tf-nightly]) (3.1.3)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.12/dist-packages (from werkzeug>=1.0.1->tb-nightly~=2.20.0.a->tf-nightly->tensorflow-datasets[tf-nightly]) (3.0.3)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich->keras-nightly>=3.10.0.dev->tf-nightly->tensorflow-datasets[tf-nightly]) (4.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich->keras-nightly>=3.10.0.dev->tf-nightly->tensorflow-datasets[tf-nightly]) (2.19.2)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich->keras-nightly>=3.10.0.dev->tf-nightly->tensorflow-datasets[tf-nightly]) (0.1.2)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Remove dir if exists.\n", + "!rm -rf /tmp/arrayrecord/ag_news_subset" + ], + "metadata": { + "id": "LS6T-Jc5kxq9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "B1wIk8wLazuG", + "outputId": "ff2e96ca-feed-425b-d469-1536a41d2b68" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jax/_src/cloud_tpu_init.py:86: UserWarning: Transparent hugepages are not enabled. TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer. If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c \"echo always > /sys/kernel/mm/transparent_hugepage/enabled\")\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Written 0 examples\n", + "Written 10000 examples\n", + "Written 20000 examples\n", + "Written 30000 examples\n", + "Written 40000 examples\n", + "Written 50000 examples\n", + "Written 60000 examples\n", + "Written 70000 examples\n", + "Written 80000 examples\n", + "Written 90000 examples\n", + "Written 100000 examples\n", + "Written 110000 examples\n", + "Finished\n" + ] + } + ], + "source": [ + "import json\n", + "import os\n", + "from array_record.python.array_record_module import ArrayRecordWriter\n", + "import grain\n", + "import tensorflow_datasets as tfds\n", + "\n", + "# Copy to local dir and convert to JSON-serialized.\n", + "data_path = \"/tmp/arrayrecord/ag_news_subset\"\n", + "os.makedirs(data_path)\n", + "writer = ArrayRecordWriter(\n", + " os.path.join(data_path, \"train.array-record\"), \"group_size:1\"\n", + ")\n", + "\n", + "source = grain.MapDataset.source(\n", + " tfds.data_source(\"ag_news_subset\", split=\"train\")\n", + ")\n", + "\n", + "for idx, e in enumerate(source.to_iter_dataset()):\n", + " if idx % 10000 == 0:\n", + " print(f\"Written {idx} examples\")\n", + " new_e = {}\n", + " for k, v in e.items():\n", + " new_e[k] = v.decode(\"utf-8\") if isinstance(v, bytes) else v\n", + " writer.write(json.dumps(new_e).encode(\"utf-8\"))\n", + "\n", + "writer.close()\n", + "print(f\"Finished\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "W_eYR60JTClI", + "outputId": "d902217f-cd93-4793-88bb-31f4472e4871" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 31M\n", + "-rw-r--r-- 1 root root 31M Nov 17 22:20 train.array-record\n" + ] + } + ], + "source": [ + "!ls -lh /tmp/arrayrecord/ag_news_subset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DsHxUdq1deEa" + }, + "source": [ + "# Create source and examine the data\n", + "\n", + "We have locally stored data serialized as JSON in an ArrayRecord file." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yCIxdwV7ci3c", + "outputId": "114d29f8-3227-4792-9027-4783bbf6698d" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "120000 examples\n" + ] + } + ], + "source": [ + "import grain\n", + "from pprint import pprint\n", + "\n", + "data_path = \"/tmp/arrayrecord/ag_news_subset/train.array-record\"\n", + "source = grain.sources.ArrayRecordDataSource(data_path)\n", + "print(f\"{len(source)} examples\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SWu4a9HTeT2z" + }, + "source": [ + "# Process the data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OfDWIGIveKPf" + }, + "source": [ + "## Parse\n", + "\n", + "Let's read, parse and inspect the data. `MapDataset` object acts as a lazily initialized sequence. It will only process an element at the given index." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "u7S1THl9ehDG", + "outputId": "0d152cbf-12d9-4f33-ad22-71f6928053db" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'description': 'AMD #39;s new dual-core Opteron chip is designed mainly for '\n", + " 'corporate computing applications, including databases, Web '\n", + " 'services, and financial transactions.',\n", + " 'label': 3,\n", + " 'title': 'AMD Debuts Dual-Core Opteron Processor'}\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "\n", + "parsed_ds = grain.MapDataset.source(source).map(json.loads)\n", + "\n", + "pprint(parsed_ds[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0xzoTs2hg-6E" + }, + "source": [ + "## Tokenize\n", + "\n", + "One of the central transformations in text data processing is tokenization -- splitting text into tokens and mapping them into a vocabulary entry indices for ML training-friendly representation. In this particular demo we will use HuggingFace tokenizer APIs. Any other would work as well." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dgULhmHthML4", + "outputId": "c662047b-9cbb-4e56-9d8a-a3983dea5716" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/torch_xla/__init__.py:258: UserWarning: `tensorflow` can conflict with `torch-xla`. Prefer `tensorflow-cpu` when using PyTorch/XLA. To silence this warning, `pip uninstall -y tensorflow && pip install tensorflow-cpu`. 
If you are in a notebook environment such as Colab or Kaggle, restart your notebook runtime afterwards.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'description': array([ 101, 2572, 2094, 1001, 4464, 1025, 1055, 2047, 7037,\n", + " 1011, 4563, 23569, 26534, 9090, 2003, 2881, 3701, 2005,\n", + " 5971, 9798, 5097, 1010, 2164, 17881, 1010, 4773, 2578,\n", + " 1010, 1998, 3361, 11817, 1012, 102])}\n" + ] + } + ], + "source": [ + "from transformers import AutoTokenizer\n", + "import numpy as np\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", + "\n", + "def tokenize(element):\n", + " tokenized = tokenizer(element[\"description\"])\n", + " return {\"description\": np.asarray(tokenized[\"input_ids\"])}\n", + "\n", + "tokenized_ds = parsed_ds.map(tokenize)\n", + "\n", + "pprint(tokenized_ds[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tPL3Rktficwr" + }, + "source": [ + "## Shuffle, repeat & shard\n", + "\n", + "To prevent ordering bias, we globally shuffle all records. We then repeat the dataset multiple times so that the model generalizes better (each epoch is shuffled differently).\n", + "\n", + "For distributed training, we split the dataset into as many shards as there are hosts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Sqma7Dk0iqgF", + "outputId": "db77a408-c7b7-4d2e-85b5-749a2af37e45" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'description': array([ 101, 3688, 3517, 2574, 2008, 2097, 7532, 3081, 1037,\n", + " 2047, 2188, 14048, 12827, 6153, 2011, 2070, 1997, 1996,\n", + " 2088, 1005, 1055, 2922, 7325, 8139, 1998, 3274, 3316,\n", + " 1012, 102])}\n" + ] + } + ], + "source": [ + "# Global shuffle.\n", + "shuffled_ds = tokenized_ds.shuffle(seed=42)\n", + "\n", + "# Repeat the dataset 100 times; each epoch is shuffled differently.\n", + "repeated_ds = shuffled_ds.repeat(num_epochs=100)\n", + "\n", + "# Shard for distributed training.\n", + "shard_index = 1 # this will typically be jax.process_index()\n", + "shard_count = 16 # this will typically be jax.process_count()\n", + "sharded_ds = repeated_ds[shard_index::shard_count]\n", + "\n", + "pprint(sharded_ds[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xUGzlDfRjmyt" + }, + "source": [ + "## Pack\n", + "\n", + "Text and multimedia data naturally vary in size. To enable batched processing of such data in ML training, Grain provides several bin-packing algorithms, which fit variable-length examples into fixed-length sequences while minimizing the necessary padding.\n", + "\n", + "Since packing needs to fetch a varying number of elements to fill the fixed-size bins, it can no longer preserve indexing into the original dataset. 
It therefor requires confersion to a `grain.IterDataset` which is a Python `Iterable` producing a `grain.DatasetIterator` for fetching elements that supports checkpointing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ek79b2CRjuX_", + "outputId": "44229be3-c815-4e57-eb7f-0e70afc59914" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'description': array([ 101, 26665, 1011, 5522, 1005, 1055, 23205, 29501, 3062,\n", + " 1015, 1012, 5764, 3867, 2011, 1032, 22878, 2006, 6928,\n", + " 1010, 8402, 6409, 2046, 1037, 2353, 2154, 2004, 2178,\n", + " 1032, 12058, 1999, 3514, 7597, 20183, 15508, 2055, 1996,\n", + " 3795, 3171, 1032, 4254, 1998, 6573, 2091, 9167, 2545,\n", + " 2107, 2004, 11742, 5013, 13058, 1012, 102, 101, 1037,\n", + " 6816, 4457, 2010, 3954, 1998, 1037, 10563, 2012, 1037,\n", + " 2221, 4189, 1999, 2358, 1012, 14060, 1010, 2021, 4445,\n", + " 4265, 2350, 6441, 1012, 20099, 6610, 8833, 17922, 24598,\n", + " 1010, 3954, 1997, 1996, 6816, 1010, 4265, 26136, 14890,\n", + " 102, 101, 15335, 2176, 7767, 8046, 1996, 12592, 2231,\n", + " 2006, 6928, 2000, 6186, 1996, 6019, 1997, 1037, 9042,\n", + " 1011, 4427, 6543, 7450, 3832, 2000, 2562, 2343, 17127,\n", + " 2474, 6806, 6784, 1999, 2373, 2005, 2178, 2093, 2086,\n", + " 1012, 102]),\n", + " 'description_positions': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,\n", + " 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,\n", + " 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,\n", + " 51, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", + " 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 35, 36, 37, 38, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,\n", + " 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,\n", + " 28, 29, 30, 31, 32, 33, 34, 35, 36], dtype=int32),\n", + " 'description_segment_ids': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", + " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", + " 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], dtype=int32)}\n" + ] + } + ], + "source": [ + "sequence_length = 128\n", + "\n", + "def trim_values(element):\n", + " return {\"description\": element[\"description\"][:sequence_length]}\n", + "\n", + "trimmed_ds = sharded_ds.map(trim_values).to_iter_dataset(grain.ReadOptions(num_threads=0))\n", + "packed_ds = grain.experimental.FirstFitPackIterDataset(\n", + " trimmed_ds,\n", + " length_struct={\"description\": sequence_length},\n", + " num_packing_bins=30\n", + " )\n", + "\n", + "pprint(next(iter(packed_ds)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FJwL_9xxlsjy" + }, + "source": [ + "## Batch\n", + "\n", + "Now that the example sizes are fixed, we can batch the data for training!" 
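Before batching, a short aside on how the packed fields above are typically consumed: tokens from different documents share a row, so the model must not attend across segment boundaries. The sketch below derives a block-diagonal attention mask from the segment IDs; `make_packed_attention_mask` is a hypothetical helper, not part of Grain or of this pipeline.

```python
import numpy as np

def make_packed_attention_mask(segment_ids: np.ndarray) -> np.ndarray:
  """Builds a [length, length] mask that blocks attention across segments.

  segment_ids uses 0 for padding and 1, 2, ... for packed documents.
  """
  same_segment = segment_ids[:, None] == segment_ids[None, :]
  not_padding = (segment_ids != 0)[:, None] & (segment_ids != 0)[None, :]
  return same_segment & not_padding

# Example: mask for the first packed element produced above.
element = next(iter(packed_ds))
mask = make_packed_attention_mask(element["description_segment_ids"])
print(mask.shape)  # (128, 128)
```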
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Y3K-GvDAl2t7", + "outputId": "1c49cef6-7de0-4740-efa4-c78b36b614ec" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'description': array([[ 101, 26665, 1011, ..., 2086, 1012, 102],\n", + " [ 101, 10884, 1006, ..., 0, 0, 0],\n", + " [ 101, 9706, 1011, ..., 0, 0, 0],\n", + " ...,\n", + " [ 101, 14497, 1010, ..., 0, 0, 0],\n", + " [ 101, 9838, 1001, ..., 0, 0, 0],\n", + " [ 101, 1996, 7794, ..., 0, 0, 0]]),\n", + " 'description_positions': array([[ 0, 1, 2, ..., 34, 35, 36],\n", + " [ 0, 1, 2, ..., 0, 0, 0],\n", + " [ 0, 1, 2, ..., 0, 0, 0],\n", + " ...,\n", + " [ 0, 1, 2, ..., 0, 0, 0],\n", + " [ 0, 1, 2, ..., 0, 0, 0],\n", + " [ 0, 1, 2, ..., 0, 0, 0]], dtype=int32),\n", + " 'description_segment_ids': array([[1, 1, 1, ..., 3, 3, 3],\n", + " [1, 1, 1, ..., 0, 0, 0],\n", + " [1, 1, 1, ..., 0, 0, 0],\n", + " ...,\n", + " [1, 1, 1, ..., 0, 0, 0],\n", + " [1, 1, 1, ..., 0, 0, 0],\n", + " [1, 1, 1, ..., 0, 0, 0]], dtype=int32)}\n" + ] + } + ], + "source": [ + "batch_size = 512\n", + "\n", + "batched_ds = packed_ds.batch(batch_size, drop_remainder=True)\n", + "\n", + "pprint(next(iter(batched_ds)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iWWsUBUNe8mN" + }, + "source": [ + "## Enable visualization mode\n", + "\n", + "In order to understand the sequence of transformations and their outputs better, Grain offers a pipeline visualization mode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xPGakC3UfDLF", + "outputId": "1768549f-b552-4724-c52b-5f16dbde9513" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Grain Dataset graph:\n", + "\n", + "SourceMapDataset(source=ArrayRecordDataSource)\n", + " ││\n", + " ││ \n", + " ││\n", + " ╲╱\n", + "'bytes[]'\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=loads @ .../python3.12/json/__init__.py:299)\n", + " ││\n", + " ╲╱\n", + "{'description': 'str[]', 'label': 'int[]', 'title': 'str[]'}\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=tokenize @ ...//tmp/ipython-input-1246250811.py:6)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ ShuffleMapDataset\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ RepeatMapDataset(num_epochs=100)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ SliceMapDataset[1:12000000:16]\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=trim_values @ ...//tmp/ipython-input-4160003364.py:3)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ PrefetchDatasetIterator(read_options=ReadOptions(num_threads=0, prefetch_buffer_size=500), allow_nones=False)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + "Grain Dataset graph:\n", + "\n", + "SourceMapDataset(source=ArrayRecordDataSource)\n", + " ││\n", + " ││ \n", + " ││\n", + " ╲╱\n", + "'bytes[]'\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=loads @ .../python3.12/json/__init__.py:299)\n", + " ││\n", + " ╲╱\n", + "{'description': 'str[]', 'label': 'int[]', 'title': 'str[]'}\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=tokenize @ ...//tmp/ipython-input-1246250811.py:6)\n", + " ││\n", + " ╲╱\n", + 
"{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ ShuffleMapDataset\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ RepeatMapDataset(num_epochs=100)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ SliceMapDataset[1:12000000:16]\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=trim_values @ ...//tmp/ipython-input-4160003364.py:3)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ PrefetchDatasetIterator(read_options=ReadOptions(num_threads=0, prefetch_buffer_size=500), allow_nones=False)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ \n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[128]',\n", + " 'description_positions': 'int32[128]',\n", + " 'description_segment_ids': 'int32[128]'}\n", + "\n", + "Grain Dataset graph:\n", + "\n", + "SourceMapDataset(source=ArrayRecordDataSource)\n", + " ││\n", + " ││ \n", + " ││\n", + " ╲╱\n", + "'bytes[]'\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=loads @ .../python3.12/json/__init__.py:299)\n", + " ││\n", + " ╲╱\n", + "{'description': 'str[]', 'label': 'int[]', 'title': 'str[]'}\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=tokenize @ ...//tmp/ipython-input-1246250811.py:6)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ ShuffleMapDataset\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ RepeatMapDataset(num_epochs=100)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ SliceMapDataset[1:12000000:16]\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ MapMapDataset(transform=trim_values @ ...//tmp/ipython-input-4160003364.py:3)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ PrefetchDatasetIterator(read_options=ReadOptions(num_threads=0, prefetch_buffer_size=500), allow_nones=False)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[29]'}\n", + "\n", + " ││\n", + " ││ \n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[128]',\n", + " 'description_positions': 'int32[128]',\n", + " 'description_segment_ids': 'int32[128]'}\n", + "\n", + " ││\n", + " ││ BatchDatasetIterator(batch_size=512, drop_remainder=True)\n", + " ││\n", + " ╲╱\n", + "{'description': 'int64[512, 128]',\n", + " 'description_positions': 'int32[512, 128]',\n", + " 'description_segment_ids': 'int32[512, 128]'}\n", + "\n" + ] + } + ], + "source": [ + "from absl import flags\n", + "\n", + "# Enable visualization.\n", + "flags.FLAGS.mark_as_parsed()\n", + "grain.config.update(\"py_dataset_visualization_output_dir\", \"\")\n", + "\n", + "next(iter(batched_ds))\n", + "\n", + "# Disable visualization -- we don't need it for following sections.\n", + "grain.config.update(\"py_dataset_visualization_output_dir\", None)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-GQgpeOamjto" + }, + "source": [ + "# Measure throughput\n", + "\n", + "We have the necessary transformations, let's make sure that we're utilizing the training accelerator efficiently!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pIUpr7eTI5Ga" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "num_batches = 200\n", + "\n", + "\n", + "def time_iterator(it) -> None:\n", + " start_time = time.perf_counter()\n", + " next(it)\n", + " time_to_first_batch = time.perf_counter() - start_time\n", + " print(f\"Time to get the first batch: {time_to_first_batch:.02f} sec.\")\n", + " start_time = time.perf_counter()\n", + " for _ in range(num_batches):\n", + " next(it)\n", + " total_time = time.perf_counter() - start_time\n", + " print(\n", + " \"Iterator throughput:\"\n", + " f\" {(num_batches * batch_size / total_time):.02f} examples/sec; \"\n", + " f\"{(num_batches / total_time):.02f} batches/sec.\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "G5_1dar_mxMi", + "outputId": "6137ed74-cb03-4360-ed48-8349b05dd7fb" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Time to get the first batch: 0.81 sec.\n", + "Iterator throughput: 656.87 examples/sec; 1.28 batches/sec.\n" + ] + } + ], + "source": [ + "time_iterator(iter(batched_ds))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6pt09OkIYE8q" + }, + "source": [ + "Once you've measured your pipeline's throughput, there's two possible scenarios: it is either faster or slower than your training step." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7GpbuY7UXZjX" + }, + "source": [ + "## Data loading is slower than training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dck6aSeYnyGo" + }, + "source": [ + "### Enable performance debug mode\n", + "\n", + "Grain offers a debug mode in which each transformation execution time is tracked and periodically logged into a table." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d_tXwhK4n4nX", + "outputId": "5a3af2d3-7144-4a64-b674-a070dc213595" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Time to get the first batch: 0.83 sec.\n", + "Grain Dataset Execution Summary:\n", + "\n", + "NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. 
The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.\n", + "\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 7 | SourceMapDataset(source=ArrayR | [] | 58.36% | 23.53ms | 259.06us | 450.77us | 276.86us | 85 | 0 bytes |\n", + "| | ecordDataSource) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 5 | MapMapDataset(transform=tokeni | [6] | 35.37% | 14.26ms | 106.08us | 272.46us | 167.81us | 85 | 0 bytes |\n", + "| | ze @ ...//tmp/ipython-input-12 | | | | | | | | |\n", + "| | 46250811.py:6) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 4 | ShuffleMapDataset | [5] | 3.33% | 1.34ms | 13.77us | 37.11us | 15.79us | 85 | 0 bytes |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 6 | MapMapDataset(transform=loads | [7] | 2.15% | 868.73us | 9.03us | 22.82us | 10.22us | 85 | 0 bytes |\n", + "| | @ .../python3.12/json/__init__ | | | | | | | | |\n", + "| | .py:299) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 2 | SliceMapDataset[1:12000000:16] | [3] | 0.39% | 155.65us | 1.62us | 2.95us | 1.83us | 85 | 0 bytes |\n", + "| | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 1 | MapMapDataset(transform=trim_v | [2] | 0.39% | 158.96us | 1.68us | 3.93us | 1.87us | 85 | 0 bytes |\n", + "| | alues @ ...//tmp/ipython-input | | | | | | | | |\n", + "| | -4160003364.py:3) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 3 | RepeatMapDataset(num_epochs=10 | [4] | 0.00% | N/A | N/A | N/A | N/A | 0 | 0 bytes |\n", + "| | 0) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 0 | PrefetchDatasetIterator(read_o | [1] | N/A | 43.69ms | 436.19us | 1.15ms | 513.97us | 85 | 0 bytes |\n", + "| | ptions=ReadOptions(num_threads | | | | | | | | |\n", + "| 
| =0, prefetch_buffer_size=500), | | | | | | | | |\n", + "| | allow_nones=False) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "\n", + "WARNING: Your source is likely the bottleneck. Please ensure if you have enough spindle quota or if your data is co-located with the computation. \n", + "Grain Dataset Execution Summary:\n", + "\n", + "NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.\n", + "\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 8 | SourceMapDataset(source=ArrayR | [] | 58.65% | 367.58ms | 236.70us | 523.65us | 268.89us | 1367 | 0 bytes |\n", + "| | ecordDataSource) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 6 | MapMapDataset(transform=tokeni | [7] | 34.79% | 218.04ms | 91.24us | 381.58us | 159.50us | 1367 | 0 bytes |\n", + "| | ze @ ...//tmp/ipython-input-12 | | | | | | | | |\n", + "| | 46250811.py:6) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 5 | ShuffleMapDataset | [6] | 3.27% | 20.51ms | 12.38us | 39.89us | 15.00us | 1367 | 0 bytes |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 7 | MapMapDataset(transform=loads | [8] | 2.18% | 13.66ms | 8.14us | 31.67us | 9.99us | 1367 | 0 bytes |\n", + "| | @ .../python3.12/json/__init__ | | | | | | | | |\n", + "| | .py:299) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 3 | SliceMapDataset[1:12000000:16] | [4] | 0.39% | 2.45ms | 1.40us | 23.93us | 1.79us | 1367 | 0 bytes |\n", + "| | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 2 | MapMapDataset(transform=trim_v | [3] | 0.39% | 2.45ms | 1.43us | 14.95us | 1.79us | 1367 | 0 bytes |\n", + "| | alues @ ...//tmp/ipython-input | | | | | | | | |\n", + "| | -4160003364.py:3) | | | | | | | | |\n", + 
"|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 0 | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 4 | RepeatMapDataset(num_epochs=10 | [5] | 0.00% | N/A | N/A | N/A | N/A | 0 | 0 bytes |\n", + "| | 0) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 1 | PrefetchDatasetIterator(read_o | [2] | N/A | 671.95ms | 396.21us | 907.27us | 491.55us | 1367 | 0 bytes |\n", + "| | ptions=ReadOptions(num_threads | | | | | | | | |\n", + "| | =0, prefetch_buffer_size=500), | | | | | | | | |\n", + "| | allow_nones=False) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "\n", + "WARNING: Your source is likely the bottleneck. Please ensure if you have enough spindle quota or if your data is co-located with the computation. \n", + "Grain Dataset Execution Summary:\n", + "\n", + "NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.\n", + "\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 9 | SourceMapDataset(source=ArrayR | [] | 58.55% | 26.46s | 236.08us | 5.54ms | 270.03us | 98006 | 0 bytes |\n", + "| | ecordDataSource) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 7 | MapMapDataset(transform=tokeni | [8] | 34.76% | 15.71s | 75.29us | 5.48ms | 160.33us | 98004 | 0 bytes |\n", + "| | ze @ ...//tmp/ipython-input-12 | | | | | | | | |\n", + "| | 46250811.py:6) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 6 | ShuffleMapDataset | [7] | 3.25% | 1.47s | 12.31us | 55.97us | 14.97us | 98004 | 0 bytes |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 8 | MapMapDataset(transform=loads | [9] | 2.17% | 979.12ms 
| 7.90us | 85.28us | 9.99us | 98005 | 0 bytes |\n", + "| | @ .../python3.12/json/__init__ | | | | | | | | |\n", + "| | .py:299) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 3 | MapMapDataset(transform=trim_v | [4] | 0.39% | 177.58ms | 1.36us | 45.03us | 1.81us | 98002 | 0 bytes |\n", + "| | alues @ ...//tmp/ipython-input | | | | | | | | |\n", + "| | -4160003364.py:3) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 4 | SliceMapDataset[1:12000000:16] | [5] | 0.38% | 171.85ms | 1.34us | 42.37us | 1.75us | 98003 | 0 bytes |\n", + "| | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 1 | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 0 | BatchDatasetIterator(batch_siz | [1] | 0.19% | 94.05ms | 1.18ms | 1.79ms | 1.34ms | 70 | 0 bytes |\n", + "| | e=512, drop_remainder=True) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 5 | RepeatMapDataset(num_epochs=10 | [6] | 0.00% | N/A | N/A | N/A | N/A | 0 | 0 bytes |\n", + "| | 0) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 2 | PrefetchDatasetIterator(read_o | [3] | N/A | 48.23s | 377.56us | 11.05ms | 492.12us | 98001 | 0 bytes |\n", + "| | ptions=ReadOptions(num_threads | | | | | | | | |\n", + "| | =0, prefetch_buffer_size=500), | | | | | | | | |\n", + "| | allow_nones=False) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "\n", + "WARNING: Your source is likely the bottleneck. Please ensure if you have enough spindle quota or if your data is co-located with the computation. \n", + "Grain Dataset Execution Summary:\n", + "\n", + "NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. 
The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.\n", + "\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 9 | SourceMapDataset(source=ArrayR | [] | 58.55% | 55.34s | 236.08us | 5.95ms | 270.53us | 204556 | 0 bytes |\n", + "| | ecordDataSource) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 7 | MapMapDataset(transform=tokeni | [8] | 34.76% | 32.86s | 75.15us | 5.53ms | 160.62us | 204554 | 0 bytes |\n", + "| | ze @ ...//tmp/ipython-input-12 | | | | | | | | |\n", + "| | 46250811.py:6) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 6 | ShuffleMapDataset | [7] | 3.24% | 3.07s | 12.29us | 102.74us | 14.99us | 204554 | 0 bytes |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 8 | MapMapDataset(transform=loads | [9] | 2.16% | 2.04s | 7.90us | 91.96us | 9.99us | 204555 | 0 bytes |\n", + "| | @ .../python3.12/json/__init__ | | | | | | | | |\n", + "| | .py:299) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 3 | MapMapDataset(transform=trim_v | [4] | 0.39% | 371.03ms | 1.35us | 45.03us | 1.81us | 204552 | 0 bytes |\n", + "| | alues @ ...//tmp/ipython-input | | | | | | | | |\n", + "| | -4160003364.py:3) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 4 | SliceMapDataset[1:12000000:16] | [5] | 0.38% | 359.91ms | 1.29us | 42.37us | 1.76us | 204553 | 0 bytes |\n", + "| | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 1 | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 0 | BatchDatasetIterator(batch_siz | [1] | 0.19% | 195.66ms | 1.18ms | 1.79ms | 1.34ms | 146 | 0 bytes |\n", + "| | e=512, drop_remainder=True) | | | | | | | | |\n", + 
"|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 5 | RepeatMapDataset(num_epochs=10 | [6] | 0.00% | N/A | N/A | N/A | N/A | 0 | 0 bytes |\n", + "| | 0) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 2 | PrefetchDatasetIterator(read_o | [3] | N/A | 100.88s | 374.09us | 11.09ms | 493.18us | 204551 | 0 bytes |\n", + "| | ptions=ReadOptions(num_threads | | | | | | | | |\n", + "| | =0, prefetch_buffer_size=500), | | | | | | | | |\n", + "| | allow_nones=False) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "\n", + "WARNING: Your source is likely the bottleneck. Please ensure if you have enough spindle quota or if your data is co-located with the computation. \n", + "Iterator throughput: 639.62 examples/sec; 1.25 batches/sec.\n" + ] + } + ], + "source": [ + "grain.config.update(\"py_debug_mode\", True)\n", + "\n", + "time_iterator(iter(batched_ds))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I_h0KSLJrSxt" + }, + "source": [ + "### GIL-free bottleneck\n", + "\n", + "A subclass of bottlenecks that are executed without Python's GIL can be dealt with relatively easy.\n", + "\n", + "Some examples of such transformations: IO, numpy, JAX, PIL, C/C++ extension modules.\n", + "\n", + "Increase # of threads! But keep it lower than the number of available cores." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KgGZqT6MrYrt", + "outputId": "9a95ce58-c10d-4092-89d1-94e99bc75f29" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Time to get the first batch: 0.52 sec.\n", + "Grain Dataset Execution Summary:\n", + "\n", + "NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. 
The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.\n", + "\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 7 | SourceMapDataset(source=ArrayR | [] | 56.78% | 213.35ms | 281.13us | 6.46ms | 1.81ms | 118 | 0 bytes |\n", + "| | ecordDataSource) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 5 | MapMapDataset(transform=tokeni | [6] | 42.05% | 157.98ms | 124.42us | 6.42ms | 1.49ms | 106 | 0 bytes |\n", + "| | ze @ ...//tmp/ipython-input-12 | | | | | | | | |\n", + "| | 46250811.py:6) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 4 | ShuffleMapDataset | [5] | 0.61% | 2.29ms | 13.92us | 54.84us | 19.37us | 118 | 0 bytes |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 6 | MapMapDataset(transform=loads | [7] | 0.40% | 1.52ms | 10.11us | 41.70us | 14.34us | 106 | 0 bytes |\n", + "| | @ .../python3.12/json/__init__ | | | | | | | | |\n", + "| | .py:299) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 2 | SliceMapDataset[1:12000000:16] | [3] | 0.08% | 305.54us | 1.67us | 8.18us | 2.59us | 118 | 0 bytes |\n", + "| | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 1 | MapMapDataset(transform=trim_v | [2] | 0.07% | 274.52us | 1.54us | 15.30us | 2.69us | 102 | 0 bytes |\n", + "| | alues @ ...//tmp/ipython-input | | | | | | | | |\n", + "| | -4160003364.py:3) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 3 | RepeatMapDataset(num_epochs=10 | [4] | 0.00% | N/A | N/A | N/A | N/A | 0 | 0 bytes |\n", + "| | 0) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 0 | PrefetchDatasetIterator(read_o | [1] | N/A | 33.47ms | 17.22us | 21.27ms | 393.81us | 85 | 173.78 KiB |\n", + "| | ptions=ReadOptions(num_threads | | | | | | | | |\n", + 
"| | =24, prefetch_buffer_size=500) | | | | | | | | |\n", + "| | , allow_nones=False) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "\n", + "WARNING: Your source is likely the bottleneck. Please ensure if you have enough spindle quota or if your data is co-located with the computation. \n", + "Grain Dataset Execution Summary:\n", + "\n", + "NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.\n", + "\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 8 | SourceMapDataset(source=ArrayR | [] | 49.79% | 5.67s | 263.72us | 21.56ms | 3.86ms | 1468 | 0 bytes |\n", + "| | ecordDataSource) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 6 | MapMapDataset(transform=tokeni | [7] | 48.06% | 5.47s | 135.48us | 18.85ms | 3.71ms | 1474 | 0 bytes |\n", + "| | ze @ ...//tmp/ipython-input-12 | | | | | | | | |\n", + "| | 46250811.py:6) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 0 | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 5 | ShuffleMapDataset | [6] | 0.22% | 25.08ms | 12.23us | 47.86us | 17.09us | 1468 | 0 bytes |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 7 | MapMapDataset(transform=loads | [8] | 0.18% | 20.17ms | 8.75us | 58.03us | 13.68us | 1474 | 0 bytes |\n", + "| | @ .../python3.12/json/__init__ | | | | | | | | |\n", + "| | .py:299) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 3 | SliceMapDataset[1:12000000:16] | [4] | 0.03% | 3.33ms | 1.43us | 15.69us | 2.27us | 1468 | 0 bytes |\n", + "| | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 2 | 
MapMapDataset(transform=trim_v | [3] | 0.03% | 3.65ms | 1.74us | 20.64us | 2.50us | 1460 | 0 bytes |\n", + "| | alues @ ...//tmp/ipython-input | | | | | | | | |\n", + "| | -4160003364.py:3) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 4 | RepeatMapDataset(num_epochs=10 | [5] | 0.00% | N/A | N/A | N/A | N/A | 0 | 0 bytes |\n", + "| | 0) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 1 | PrefetchDatasetIterator(read_o | [2] | N/A | 140.84ms | 15.33us | 12.83ms | 103.03us | 1367 | 0 bytes |\n", + "| | ptions=ReadOptions(num_threads | | | | | | | | |\n", + "| | =24, prefetch_buffer_size=500) | | | | | | | | |\n", + "| | , allow_nones=False) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "\n", + "WARNING: Your source is likely the bottleneck. Please ensure if you have enough spindle quota or if your data is co-located with the computation. \n", + "Grain Dataset Execution Summary:\n", + "\n", + "NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.\n", + "\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 9 | SourceMapDataset(source=ArrayR | [] | 48.61% | 693.30s | 254.94us | 53.30ms | 4.22ms | 164135 | 0 bytes |\n", + "| | ecordDataSource) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 7 | MapMapDataset(transform=tokeni | [8] | 46.82% | 667.76s | 119.02us | 47.30ms | 4.07ms | 164049 | 0 bytes |\n", + "| | ze @ ...//tmp/ipython-input-12 | | | | | | | | |\n", + "| | 46250811.py:6) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 0 | BatchDatasetIterator(batch_siz | [1] | 2.09% | 292.44ms | 1.67ms | 3.39ms | 2.50ms | 117 | 0 bytes |\n", + "| | e=512, drop_remainder=True) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 1 | | | | 
| | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 6 | ShuffleMapDataset | [7] | 0.21% | 2.93s | 11.65us | 189.04us | 17.83us | 164040 | 0 bytes |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 8 | MapMapDataset(transform=loads | [9] | 0.16% | 2.35s | 8.09us | 112.77us | 14.30us | 164085 | 0 bytes |\n", + "| | @ .../python3.12/json/__init__ | | | | | | | | |\n", + "| | .py:299) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 4 | SliceMapDataset[1:12000000:16] | [5] | 0.03% | 402.22ms | 1.44us | 47.27us | 2.45us | 163984 | 0 bytes |\n", + "| | | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 3 | MapMapDataset(transform=trim_v | [4] | 0.03% | 433.33ms | 1.59us | 82.99us | 2.64us | 163966 | 0 bytes |\n", + "| | alues @ ...//tmp/ipython-input | | | | | | | | |\n", + "| | -4160003364.py:3) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 5 | RepeatMapDataset(num_epochs=10 | [6] | 0.00% | N/A | N/A | N/A | N/A | 0 | 0 bytes |\n", + "| | 0) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 2 | PrefetchDatasetIterator(read_o | [3] | N/A | 13.40s | 15.56us | 23.76ms | 81.80us | 163770 | 0 bytes |\n", + "| | ptions=ReadOptions(num_threads | | | | | | | | |\n", + "| | =24, prefetch_buffer_size=500) | | | | | | | | |\n", + "| | , allow_nones=False) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "\n", + "WARNING: Your source is likely the bottleneck. Please ensure if you have enough spindle quota or if your data is co-located with the computation. 
\n", + "Iterator throughput: 1041.19 examples/sec; 2.03 batches/sec.\n" + ] + } + ], + "source": [ + "read_options = grain.ReadOptions(num_threads=24)\n", + "\n", + "ds = grain.MapDataset.source(source).map(json.loads).map(tokenize)\n", + "ds = ds.shuffle(seed=42).repeat(num_epochs=100)[shard_index::shard_count]\n", + "ds = ds.map(trim_values).to_iter_dataset(read_options)\n", + "ds = grain.experimental.FirstFitPackIterDataset(\n", + " ds,\n", + " length_struct={\"description\": sequence_length},\n", + " num_packing_bins=30\n", + " )\n", + "ds = ds.batch(batch_size, drop_remainder=True)\n", + "\n", + "time_iterator(iter(ds))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DZMQDk9XsNJb" + }, + "source": [ + "### Bottleneck with GIL\n", + "\n", + "These bottlenecks do not allow to take advantage of multithreading in Python and therefore require multiprocessing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QHmtJi1UsWsP", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "4c4d1bcb-e679-4500-ddef-887e00d9e8e4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Time to get the first batch: 49.06 sec.\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:absl:Couldn't get execution summary from the child process 0\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Grain Dataset Execution Summary:\n", + "\n", + "NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.\n", + "\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| 0 | MultiprocessPrefetchDatasetIte | [] | N/A | 668.47us | 1.75us | 22.36us | 7.96us | 84 | 0 bytes |\n", + "| | rator(multiprocessing_options= | | | | | | | | |\n", + "| | MultiprocessingOptions(num_wor | | | | | | | | |\n", + "| | kers=24, per_worker_buffer_siz | | | | | | | | |\n", + "| | e=1, enable_profiling=False)) | | | | | | | | |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "\n" + ] + } + ], + "source": [ + "ds = grain.MapDataset.source(source).map(json.loads).map(tokenize)\n", + "ds = ds.shuffle(seed=42).repeat(num_epochs=100)[shard_index::shard_count]\n", + "ds = ds.map(trim_values).to_iter_dataset(grain.ReadOptions(num_threads=2))\n", + "ds = grain.experimental.FirstFitPackIterDataset(\n", + " ds, length_struct={\"description\": sequence_length}, num_packing_bins=30\n", + ")\n", + "ds = ds.batch(batch_size, drop_remainder=True)\n", + "ds = ds.mp_prefetch(\n", + " grain.multiprocessing.MultiprocessingOptions(num_workers=16)\n", + ")\n", + "\n", + 
"time_iterator(iter(ds))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-xOjTwk7Tbzj" + }, + "source": [ + "## Data loading is faster than training\n", + "\n", + "Don't spend time optimizing your pipeline, just hide its latency behind the training!\n", + "\n", + "In this example data fetching and training are synchronous:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LjDV4xedThOt" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "train_step_latency_s = 1\n", + "num_steps = 100\n", + "\n", + "\n", + "def train_step(data_batch):\n", + " del data_batch\n", + " time.sleep(train_step_latency_s)\n", + "\n", + "\n", + "def train(dataset):\n", + " it = iter(dataset)\n", + " start_time = time.perf_counter()\n", + " for _ in range(num_steps):\n", + " data_batch = next(it)\n", + " train_step(data_batch)\n", + "\n", + " training_time = time.perf_counter() - start_time\n", + " idle_ratio = (\n", + " training_time - num_steps * train_step_latency_s\n", + " ) / training_time\n", + " print(f\"Spent {(idle_ratio * 100):.2f}% of time waiting for data\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7CuYBoPkV5DD", + "outputId": "252f8b9d-ba63-44bb-d344-8b4bac978c9e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Spent 46.24% of time waiting for data\n" + ] + } + ], + "source": [ + "train(batched_ds)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BgmuptqxVGnC" + }, + "source": [ + "### Add background prefetching\n", + "\n", + "Background thread prefetching allows to asynchronously process data before it is requested and thus hides majority of the data processing latency." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RpQQVhcnVD_s", + "outputId": "f01d2b01-e5ef-4ea0-8ee8-16911114deee" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Spent 1.60% of time waiting for data\n" + ] + } + ], + "source": [ + "prefetched_ds = grain.experimental.ThreadPrefetchIterDataset(\n", + " batched_ds, prefetch_buffer_size=3\n", + ")\n", + "\n", + "train(prefetched_ds)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "InoprIqgWGBL" + }, + "source": [ + "### Hide first batch processing behind checkpoint recovery\n", + "\n", + "Another (often ovelooked) optimization is to overlap first batch processing with the model checkpoint recovery." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "omBeviSJWSeC" + }, + "outputs": [], + "source": [ + "model_restore_latency_s = 5\n", + "\n", + "\n", + "def restore_model():\n", + " time.sleep(model_restore_latency_s)\n", + "\n", + "\n", + "def train_from_checkpoint(dataset):\n", + " it = iter(dataset)\n", + " it.start_prefetch()\n", + " start_time = time.perf_counter()\n", + " restore_model()\n", + " for _ in range(num_steps):\n", + " data_batch = next(it)\n", + " train_step(data_batch)\n", + "\n", + " training_time = time.perf_counter() - start_time\n", + " idle_ratio = (\n", + " training_time - num_steps * train_step_latency_s - model_restore_latency_s\n", + " ) / training_time\n", + " print(f\"Spent {(idle_ratio * 100):.2f}% of time waiting for data\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bzEA6H_8rk0j", + "outputId": "da4d4a81-8cf0-459e-eaeb-369d514cd797" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Spent 0.48% of time waiting for data\n" + ] + } + ], + "source": [ + "train_from_checkpoint(prefetched_ds)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "gpuType": "V5E1", + "collapsed_sections": [ + "wpQU648ycXYj" + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "TPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file