diff --git a/myst.yml b/myst.yml index b048b5ed..842a87e3 100644 --- a/myst.yml +++ b/myst.yml @@ -13,6 +13,10 @@ project: - title: Preamble children: - file: notebooks/how-to-cite.md + - title: Xbatcher fundamentals + children: + - file: notebooks/xbatcher_dataloading.ipynb + - file: notebooks/xbatcher_reconstruction.ipynb - title: Testing model inference children: - file: notebooks/inference-testing.ipynb diff --git a/notebooks/xbatcher_dataloading.ipynb b/notebooks/xbatcher_dataloading.ipynb new file mode 100644 index 00000000..3dfeee52 --- /dev/null +++ b/notebooks/xbatcher_dataloading.ipynb @@ -0,0 +1,2033 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dataloading from Xarray Datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Working with large, multi-dimensional datasets, common in fields like climate science and oceanography, presents a significant challenge when preparing data for machine learning models. The `xbatcher` library is designed to simplify this crucial preprocessing step.\n", + "\n", + "`xbatcher` is a Python package that facilitates the generation of data batches from `xarray` objects for machine learning. It serves as a bridge between the labeled, multi-dimensional data structures of `xarray` and the tensor-based inputs required by deep learning frameworks such as PyTorch and TensorFlow.\n", + "\n", + "This guide provides an introduction to the fundamentals of `xbatcher`. We will cover how to create a `BatchGenerator`, customize it for specific needs, and prepare the resulting data for integration with a PyTorch model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "import torch\n", + "import xbatcher\n", + "from xbatcher.loaders.torch import MapDataset, IterableDataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a Sample Dataset\n", + "\n", + "To begin, we will create a sample `xarray.Dataset`. This allows us to focus on the mechanics of `xbatcher` without the overhead of a specific real-world dataset. This sample can be replaced by any `xarray.Dataset` loaded from a file (e.g., NetCDF, Zarr)." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
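The synthetic `ds` built in this notebook stands in for real data; as noted above, any `xarray.Dataset` opened from disk or object storage works the same way. A minimal sketch, assuming a hypothetical NetCDF file and Zarr store (the paths are placeholders, not files shipped with this cookbook):

    import xarray as xr

    # Either call returns an xarray.Dataset that can be handed to
    # xbatcher.BatchGenerator exactly like the synthetic `ds` used here.
    ds_from_netcdf = xr.open_dataset("path/to/my_data.nc")   # hypothetical NetCDF file
    ds_from_zarr = xr.open_zarr("path/to/my_store.zarr")     # hypothetical Zarr store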
<xarray.Dataset> Size: 8MB\n",
+       "Dimensions:        (x: 100, y: 100, time: 50)\n",
+       "Coordinates:\n",
+       "  * x              (x) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n",
+       "  * y              (y) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n",
+       "  * time           (time) int64 400B 0 1 2 3 4 5 6 7 ... 42 43 44 45 46 47 48 49\n",
+       "Data variables:\n",
+       "    temperature    (x, y, time) float64 4MB 0.6357 0.8989 ... 0.1376 0.1089\n",
+       "    precipitation  (x, y, time) float64 4MB 0.05915 0.2899 ... 0.0906 0.969
" + ], + "text/plain": [ + " Size: 8MB\n", + "Dimensions: (x: 100, y: 100, time: 50)\n", + "Coordinates:\n", + " * x (x) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n", + " * y (y) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n", + " * time (time) int64 400B 0 1 2 3 4 5 6 7 ... 42 43 44 45 46 47 48 49\n", + "Data variables:\n", + " temperature (x, y, time) float64 4MB 0.6357 0.8989 ... 0.1376 0.1089\n", + " precipitation (x, y, time) float64 4MB 0.05915 0.2899 ... 0.0906 0.969" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds = xr.Dataset(\n", + " {\n", + " \"temperature\": ((\"x\", \"y\", \"time\"), np.random.rand(100, 100, 50)),\n", + " \"precipitation\": ((\"x\", \"y\", \"time\"), np.random.rand(100, 100, 50)),\n", + " },\n", + " coords={\n", + " \"x\": np.arange(100),\n", + " \"y\": np.arange(100),\n", + " \"time\": np.arange(50),\n", + " },\n", + ")\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The dataset contains two variables, `temperature` and `precipitation`, and three dimensions: `x`, `y`, and `time`. We will now use `xbatcher` to generate batches from this dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `BatchGenerator`\n", + "\n", + "The `BatchGenerator` is the core component of `xbatcher`. It is a Python generator that yields batches of data from an `xarray` object." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "bgen = xbatcher.BatchGenerator(ds, input_dims={\"x\": 10, \"y\": 10})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `BatchGenerator` is initialized with the dataset and the `input_dims` parameter. `input_dims` specifies the size of the batches along each dimension. In this case, we are creating batches of size 10x10 along the `x` and `y` dimensions. The `time` dimension is not specified, so `xbatcher` will yield batches that include all time steps.\n", + "\n", + "Let's inspect the first batch generated." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<xarray.Dataset> Size: 81kB\n",
+       "Dimensions:        (x: 10, y: 10, time: 50)\n",
+       "Coordinates:\n",
+       "  * x              (x) int64 80B 0 1 2 3 4 5 6 7 8 9\n",
+       "  * y              (y) int64 80B 0 1 2 3 4 5 6 7 8 9\n",
+       "  * time           (time) int64 400B 0 1 2 3 4 5 6 7 ... 42 43 44 45 46 47 48 49\n",
+       "Data variables:\n",
+       "    temperature    (x, y, time) float64 40kB 0.6357 0.8989 ... 0.7347 0.4043\n",
+       "    precipitation  (x, y, time) float64 40kB 0.05915 0.2899 ... 0.1648 0.06016
" + ], + "text/plain": [ + " Size: 81kB\n", + "Dimensions: (x: 10, y: 10, time: 50)\n", + "Coordinates:\n", + " * x (x) int64 80B 0 1 2 3 4 5 6 7 8 9\n", + " * y (y) int64 80B 0 1 2 3 4 5 6 7 8 9\n", + " * time (time) int64 400B 0 1 2 3 4 5 6 7 ... 42 43 44 45 46 47 48 49\n", + "Data variables:\n", + " temperature (x, y, time) float64 40kB 0.6357 0.8989 ... 0.7347 0.4043\n", + " precipitation (x, y, time) float64 40kB 0.05915 0.2899 ... 0.1648 0.06016" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "first_batch = next(iter(bgen))\n", + "first_batch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first batch has dimensions `x=10`, `y=10`, and `time=50`, as expected. The `BatchGenerator` will yield 100 batches in total (10 batches in the x-direction * 10 batches in the y-direction)." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The BatchGenerator contains 100 batches.\n" + ] + } + ], + "source": [ + "print(f\"The BatchGenerator contains {len(bgen)} batches.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overlapping Batches with `input_overlap`\n", + "\n", + "In many applications, it is useful to have overlapping batches to provide context from neighboring data points. The `input_overlap` parameter allows for this." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<xarray.Dataset> Size: 81kB\n",
+       "Dimensions:        (x: 10, y: 10, time: 50)\n",
+       "Coordinates:\n",
+       "  * x              (x) int64 80B 0 1 2 3 4 5 6 7 8 9\n",
+       "  * y              (y) int64 80B 0 1 2 3 4 5 6 7 8 9\n",
+       "  * time           (time) int64 400B 0 1 2 3 4 5 6 7 ... 42 43 44 45 46 47 48 49\n",
+       "Data variables:\n",
+       "    temperature    (x, y, time) float64 40kB 0.6357 0.8989 ... 0.7347 0.4043\n",
+       "    precipitation  (x, y, time) float64 40kB 0.05915 0.2899 ... 0.1648 0.06016
" + ], + "text/plain": [ + " Size: 81kB\n", + "Dimensions: (x: 10, y: 10, time: 50)\n", + "Coordinates:\n", + " * x (x) int64 80B 0 1 2 3 4 5 6 7 8 9\n", + " * y (y) int64 80B 0 1 2 3 4 5 6 7 8 9\n", + " * time (time) int64 400B 0 1 2 3 4 5 6 7 ... 42 43 44 45 46 47 48 49\n", + "Data variables:\n", + " temperature (x, y, time) float64 40kB 0.6357 0.8989 ... 0.7347 0.4043\n", + " precipitation (x, y, time) float64 40kB 0.05915 0.2899 ... 0.1648 0.06016" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bgen_overlap = xbatcher.BatchGenerator(\n", + " ds, \n", + " input_dims={\"x\": 10, \"y\": 10}, \n", + " input_overlap={\"x\": 2, \"y\": 2}\n", + ")\n", + "first_batch_overlap = next(iter(bgen_overlap))\n", + "first_batch_overlap" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `input_overlap` parameter specifies the number of elements to overlap between consecutive batches. The size of the batches themselves does not change. Let's verify this by inspecting the coordinates of the first two batches." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch 1 y-coords: [0 1 2 3 4 5 6 7 8 9], Batch 1 x-coords: [0 1 2 3 4 5 6 7 8 9]\n", + "Batch 2 y-coords: [ 8 9 10 11 12 13 14 15 16 17], Batch 2 x-coords: [0 1 2 3 4 5 6 7 8 9]\n", + "Batch 3 y-coords: [16 17 18 19 20 21 22 23 24 25], Batch 3 x-coords: [0 1 2 3 4 5 6 7 8 9]\n", + "Batch 13 y-coords: [0 1 2 3 4 5 6 7 8 9], Batch 13 x-coords: [ 8 9 10 11 12 13 14 15 16 17]\n", + "Batch 14 y-coords: [ 8 9 10 11 12 13 14 15 16 17], Batch 11 x-coords: [ 8 9 10 11 12 13 14 15 16 17]\n" + ] + } + ], + "source": [ + "print(f\"Batch 1 y-coords: {bgen_overlap[0].y.values}, Batch 1 x-coords: {bgen_overlap[0].x.values}\")\n", + "print(f\"Batch 2 y-coords: {bgen_overlap[1].y.values}, Batch 2 x-coords: {bgen_overlap[1].x.values}\")\n", + "print(f\"Batch 3 y-coords: {bgen_overlap[2].y.values}, Batch 3 x-coords: {bgen_overlap[2].x.values}\")\n", + "print(f\"Batch 13 y-coords: {bgen_overlap[12].y.values}, Batch 13 x-coords: {bgen_overlap[12].x.values}\")\n", + "print(f\"Batch 14 y-coords: {bgen_overlap[13].y.values}, Batch 11 x-coords: {bgen_overlap[13].x.values}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see, the second batch starts at `y=8`, which is an overlap of 2 elements with the first batch, which ends at `y=9`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Integration with PyTorch\n", + "\n", + "`xbatcher` provides `MapDataset` and `IterableDataset` to wrap the `BatchGenerator` for use with PyTorch." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `MapDataset` vs. `IterableDataset`\n", + "\n", + "- `MapDataset`: Implements `__getitem__` and `__len__`, allowing for random access to data samples. This is the most common type of dataset in PyTorch.\n", + "- `IterableDataset`: Implements `__iter__`, and is suitable for very large datasets that may not fit into memory, as it streams data.\n", + "\n", + "We will use `MapDataset` for this example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(10, 10, 50)" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bgen[0].temperature.shape\n", + "bgen[0].precipitation.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_to_tensor(patch):\n", + " temp_patch = torch.tensor(patch.temperature.data)\n", + " prcp_patch = torch.tensor(patch.precipitation.data)\n", + " stacked_patch = torch.stack((temp_patch, prcp_patch), dim=0)\n", + " patch = stacked_patch\n", + " patch = torch.nan_to_num(patch)\n", + " # patch = torch.unsqueeze(patch, 0)\n", + " patch = patch.float()\n", + " return patch\n", + "\n", + "map_ds = MapDataset(bgen, transform=patch_to_tensor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `MapDataset` can then be used with a PyTorch `DataLoader`, which provides utilities for shuffling, batching, and multiprocessing." + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(map_ds, batch_size=4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspecting a batch from the `DataLoader` reveals a batch of PyTorch tensors." + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 2, 10, 10, 50])" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_batch = next(iter(dataloader))\n", + "torch_batch.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `DataLoader` has stacked 4 of the `xbatcher` batches, creating a new `batch` dimension of size 4. The data is now ready for use in a PyTorch model.\n", + "\n", + "In the next notebook, we will explore how to reconstruct an `xarray.Dataset` from a model's output." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/xbatcher_reconstruction.ipynb b/notebooks/xbatcher_reconstruction.ipynb new file mode 100644 index 00000000..11063f85 --- /dev/null +++ b/notebooks/xbatcher_reconstruction.ipynb @@ -0,0 +1,1457 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reconstructing Xarray Datasets from Model Outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook addresses the process of reconstructing an `xarray.DataArray` from the output of a machine learning model. While the previous notebook focused on generating batches from `xarray` objects, this guide details the reverse process: assembling model outputs back into a coherent, labeled `xarray` object. 
This is a common requirement in scientific machine learning workflows, where the model output needs to be analyzed in its original spatial or temporal context.\n", + "\n", + "We will examine a function that reassembles model outputs, including a detailed look at how an internal API of `xbatcher` can be used to map batch outputs back to their original coordinates." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "import torch\n", + "import xbatcher\n", + "from xbatcher.loaders.torch import MapDataset\n", + "from typing import Literal\n", + "\n", + "from dummy_models import ExpandAlongAxis" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup: Data, Batches, and a Dummy Model\n", + "\n", + "We will begin by creating a sample `xarray.DataArray` and a `BatchGenerator`. We will also instantiate a dummy model that transforms the data, simulating a common machine learning scenario where the output dimensions differ from the input dimensions (e.g., super-resolution)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<xarray.DataArray (x: 50, y: 40)> Size: 8kB\n",
+       "array([[0.94426095, 0.7027894 , 0.02029528, ..., 0.16328041, 0.5883387 ,\n",
+       "        0.8879921 ],\n",
+       "       [0.6830533 , 0.8331848 , 0.44004276, ..., 0.6508039 , 0.8455495 ,\n",
+       "        0.66443324],\n",
+       "       [0.36509654, 0.9623709 , 0.44621307, ..., 0.66530186, 0.31605566,\n",
+       "        0.9226282 ],\n",
+       "       ...,\n",
+       "       [0.2908776 , 0.3381197 , 0.7494014 , ..., 0.19071114, 0.10994843,\n",
+       "        0.17150152],\n",
+       "       [0.6378889 , 0.95425236, 0.51718473, ..., 0.52702767, 0.9290716 ,\n",
+       "        0.819217  ],\n",
+       "       [0.59220934, 0.6537968 , 0.06189981, ..., 0.75576884, 0.0942427 ,\n",
+       "        0.36704108]], shape=(50, 40), dtype=float32)\n",
+       "Coordinates:\n",
+       "  * x        (x) int64 400B 0 1 2 3 4 5 6 7 8 9 ... 41 42 43 44 45 46 47 48 49\n",
+       "  * y        (y) int64 320B 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37 38 39
" + ], + "text/plain": [ + " Size: 8kB\n", + "array([[0.94426095, 0.7027894 , 0.02029528, ..., 0.16328041, 0.5883387 ,\n", + " 0.8879921 ],\n", + " [0.6830533 , 0.8331848 , 0.44004276, ..., 0.6508039 , 0.8455495 ,\n", + " 0.66443324],\n", + " [0.36509654, 0.9623709 , 0.44621307, ..., 0.66530186, 0.31605566,\n", + " 0.9226282 ],\n", + " ...,\n", + " [0.2908776 , 0.3381197 , 0.7494014 , ..., 0.19071114, 0.10994843,\n", + " 0.17150152],\n", + " [0.6378889 , 0.95425236, 0.51718473, ..., 0.52702767, 0.9290716 ,\n", + " 0.819217 ],\n", + " [0.59220934, 0.6537968 , 0.06189981, ..., 0.75576884, 0.0942427 ,\n", + " 0.36704108]], shape=(50, 40), dtype=float32)\n", + "Coordinates:\n", + " * x (x) int64 400B 0 1 2 3 4 5 6 7 8 9 ... 41 42 43 44 45 46 47 48 49\n", + " * y (y) int64 320B 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37 38 39" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da = xr.DataArray(\n", + " data=np.random.rand(50, 40).astype(np.float32),\n", + " dims=(\"x\", \"y\"),\n", + " coords={\"x\": np.arange(50), \"y\": np.arange(40)},\n", + ")\n", + "da" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we create the `BatchGenerator`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "bgen = xbatcher.BatchGenerator(da, input_dims={\"x\": 10, \"y\": 10})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the model, we will use `ExpandAlongAxis` from `dummy_models.py`. This model upsamples the input along a specified axis, changing the dimensions of the data." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# The model will expand the 'x' dimension by a factor of 2\n", + "model = ExpandAlongAxis(ax=1, n_repeats=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `predict_on_array` Function\n", + "\n", + "The `predict_on_array` function (from `functions.py`) is designed to take batches from a `BatchGenerator`, pass them through a model, and reassemble the outputs. The following sections will break down this function and its helpers." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def _get_resample_factor(\n", + " bgen: xbatcher.BatchGenerator,\n", + " output_tensor_dim: dict[str, int],\n", + " resample_dim: list[str]\n", + "):\n", + " resample_factor = {}\n", + " for dim in resample_dim:\n", + " r = output_tensor_dim[dim] / bgen.input_dims[dim]\n", + " is_int = (r == int(r))\n", + " is_inv_int = (1/r == int(1/r)) if r != 0 else False\n", + " assert is_int or is_inv_int, f\"Resample ratio for dim '{dim}' must be an integer or its inverse.\"\n", + " resample_factor[dim] = r\n", + "\n", + " return resample_factor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `_get_resample_factor`\n", + "\n", + "This helper function calculates the resampling factor for each dimension. For example, if input batches have `x=10` and the model outputs tensors with `x=20`, the resampling factor for `x` is 2. This is used to determine the dimensions of the final reconstructed array." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def _get_output_array_size(\n", + " bgen: xbatcher.BatchGenerator,\n", + " output_tensor_dim: dict[str, int],\n", + " new_dim: list[str],\n", + " core_dim: list[str],\n", + " resample_dim: list[str]\n", + "):\n", + " resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)\n", + " output_size = {}\n", + " for key, size in output_tensor_dim.items():\n", + " if key in new_dim:\n", + " output_size[key] = output_tensor_dim[key]\n", + " elif key in core_dim:\n", + " if output_tensor_dim[key] != bgen.ds.sizes[key]:\n", + " raise ValueError(\n", + " f\"Axis {key} is a core dim, but the tensor size\"\n", + " f\"({output_tensor_dim[key]}) does not equal the \"\n", + " f\"source data array size ({bgen.ds.sizes[key]}).\"\n", + " )\n", + " output_size[key] = bgen.ds.sizes[key]\n", + " elif key in resample_dim:\n", + " temp_output_size = bgen.ds.sizes[key] * resample_factor[key]\n", + " assert temp_output_size.is_integer(), f\"Resampling for dim '{key}' results in non-integer size.\"\n", + " output_size[key] = int(temp_output_size)\n", + " else:\n", + " raise ValueError(f\"Axis {key} must be specified in one of new_dim, core_dim, or resample_dim\") \n", + " return output_size" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `_get_output_array_size`\n", + "\n", + "This function determines the final size of the reconstructed array. It uses the resampling factor and also considers `new_dim` (dimensions that are new in the output) and `core_dim` (dimensions that are not batched over and remain unchanged)." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def _resample_coordinate(\n", + " coord: xr.DataArray,\n", + " factor: float,\n", + " mode: Literal[\"centers\", \"edges\"]=\"edges\"\n", + ") -> np.ndarray:\n", + " assert len(coord.shape) == 1 and coord.shape[0] > 1\n", + " assert (coord.shape[0] * factor).is_integer()\n", + " old_step = (coord.data[1] - coord.data[0])\n", + " offset = 0 if mode == \"edges\" else old_step / 2\n", + " new_step = old_step / factor\n", + " coord = coord - offset\n", + " new_coord_end = coord.max().item() + old_step\n", + " return np.arange(coord.min().item(), new_coord_end, step=new_step) + offset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `_resample_coordinate`\n", + "\n", + "If the size of a dimension is changed, its coordinates must also be updated. This function handles the resampling of coordinates." 
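For intuition, this is roughly what the helper produces for this notebook's integer `x` coordinate when that dimension is doubled (a sketch, assuming the `da` and `_resample_coordinate` defined above):

    # Doubling a coordinate that runs 0..49 in steps of 1 yields 100 values
    # running 0.0..49.5 in steps of 0.5 -- the x coordinate that appears on the
    # reconstructed array at the end of this notebook.
    _resample_coordinate(da["x"], factor=2.0, mode="edges")
    # -> array([ 0. ,  0.5,  1. , ..., 49. , 49.5])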
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def _get_output_array_coordinates(\n", + " src_da: xr.DataArray,\n", + " output_array_dim: list[str],\n", + " resample_factor: dict[str, int],\n", + " resample_mode: Literal[\"centers\", \"edges\"]=\"edges\"\n", + ") -> dict[str, np.ndarray]:\n", + " output_coords = {}\n", + " for dim in output_array_dim:\n", + " if dim in src_da.coords and dim in resample_factor:\n", + " output_coords[dim] = _resample_coordinate(src_da[dim], resample_factor[dim], resample_mode)\n", + " elif dim in src_da.coords:\n", + " output_coords[dim] = src_da[dim].copy(deep=True).data\n", + " else:\n", + " continue\n", + " return output_coords" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `_get_output_array_coordinates`\n", + "\n", + "This function generates a dictionary of the new coordinates for the output array." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def predict_on_array(\n", + " dataset: MapDataset,\n", + " model: torch.nn.Module,\n", + " output_tensor_dim: dict[str, int],\n", + " new_dim: list[str],\n", + " core_dim: list[str],\n", + " resample_dim: list[str],\n", + " resample_mode: Literal[\"centers\", \"edges\"]=\"edges\",\n", + " batch_size: int=16\n", + ") -> xr.DataArray:\n", + " s_new = set(new_dim)\n", + " s_core = set(core_dim)\n", + " s_resample = set(resample_dim)\n", + "\n", + " if s_new & s_core or s_new & s_resample or s_core & s_resample:\n", + " raise ValueError(\"new_dim, core_dim, and resample_dim must be disjoint sets.\")\n", + "\n", + " bgen = dataset.X_generator\n", + "\n", + " resample_factor = _get_resample_factor(\n", + " bgen,\n", + " output_tensor_dim, \n", + " resample_dim\n", + " )\n", + " \n", + " output_size = _get_output_array_size(\n", + " bgen,\n", + " output_tensor_dim,\n", + " new_dim,\n", + " core_dim,\n", + " resample_dim\n", + " )\n", + " \n", + " output_da = xr.DataArray(\n", + " data=np.zeros(tuple(output_size.values())),\n", + " dims=tuple(output_size.keys()),\n", + " )\n", + " output_n = xr.full_like(output_da, 0)\n", + " \n", + " loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n", + "\n", + " for i, batch in enumerate(loader):\n", + " input_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch\n", + " out_batch = model(input_tensor).detach().numpy()\n", + "\n", + " for ib in range(out_batch.shape[0]):\n", + " global_index = (i * batch_size) + ib\n", + " old_indexer = bgen._batch_selectors.selectors[global_index][0]\n", + " new_indexer = {}\n", + " for key in old_indexer:\n", + " if key in resample_dim:\n", + " new_indexer[key] = slice(\n", + " int(old_indexer[key].start * resample_factor[key]),\n", + " int(old_indexer[key].stop * resample_factor[key])\n", + " )\n", + "\n", + " output_da.loc[new_indexer] += out_batch[ib, ...]\n", + " output_n.loc[new_indexer] += 1\n", + "\n", + " output_da = output_da / output_n\n", + "\n", + " output_da = output_da.assign_coords(\n", + " _get_output_array_coordinates(\n", + " dataset.X_generator.ds, \n", + " list(output_tensor_dim.keys()), \n", + " resample_factor, \n", + " resample_mode\n", + " )\n", + " )\n", + "\n", + " return output_da" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `predict_on_array` Internals\n", + "\n", + "The key steps of this function are as follows:\n", + "\n", + "1. 
**Initialization**: An empty `DataArray` (`output_da`) is created with the final dimensions, along with a corresponding `DataArray` (`output_n`) to track the number of predictions for each element (for averaging in case of overlaps).\n", + "2. **Iteration**: The function iterates through the `DataLoader`.\n", + "3. **The Internal API**: The core of the reconstruction is `bgen._batch_selectors.selectors[global_index]`. This internal attribute of the `BatchGenerator` stores the slice objects for each batch, providing a map from the batch to the original `DataArray`'s coordinate space.\n", + "4. **Disclaimer**: Accessing internal attributes such as `_batch_selectors` is not part of the public API and may change in future versions of `xbatcher`.\n", + "5. **Rescaling and Placing**: The resampling factor is used to scale the slices, and `.loc` is used to place the model's output into the correct location in `output_da`.\n", + "6. **Averaging and Coordinates**: Finally, the predictions are averaged (if there were overlaps) and the new coordinates are assigned." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reconstructing the Dataset\n", + "\n", + "We will now use the `predict_on_array` function to reconstruct the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<xarray.DataArray (x: 100, y: 40)> Size: 32kB\n",
+       "array([[0.94426095, 0.70278943, 0.02029528, ..., 0.16328041, 0.58833867,\n",
+       "        0.88799208],\n",
+       "       [0.94426095, 0.70278943, 0.02029528, ..., 0.16328041, 0.58833867,\n",
+       "        0.88799208],\n",
+       "       [0.68305331, 0.83318478, 0.44004276, ..., 0.65080392, 0.84554952,\n",
+       "        0.66443324],\n",
+       "       ...,\n",
+       "       [0.63788891, 0.95425236, 0.51718473, ..., 0.52702767, 0.92907161,\n",
+       "        0.81921703],\n",
+       "       [0.59220934, 0.65379679, 0.06189981, ..., 0.75576884, 0.0942427 ,\n",
+       "        0.36704108],\n",
+       "       [0.59220934, 0.65379679, 0.06189981, ..., 0.75576884, 0.0942427 ,\n",
+       "        0.36704108]], shape=(100, 40))\n",
+       "Coordinates:\n",
+       "  * x        (x) float64 800B 0.0 0.5 1.0 1.5 2.0 ... 47.5 48.0 48.5 49.0 49.5\n",
+       "  * y        (y) float64 320B 0.0 1.0 2.0 3.0 4.0 ... 35.0 36.0 37.0 38.0 39.0
" + ], + "text/plain": [ + " Size: 32kB\n", + "array([[0.94426095, 0.70278943, 0.02029528, ..., 0.16328041, 0.58833867,\n", + " 0.88799208],\n", + " [0.94426095, 0.70278943, 0.02029528, ..., 0.16328041, 0.58833867,\n", + " 0.88799208],\n", + " [0.68305331, 0.83318478, 0.44004276, ..., 0.65080392, 0.84554952,\n", + " 0.66443324],\n", + " ...,\n", + " [0.63788891, 0.95425236, 0.51718473, ..., 0.52702767, 0.92907161,\n", + " 0.81921703],\n", + " [0.59220934, 0.65379679, 0.06189981, ..., 0.75576884, 0.0942427 ,\n", + " 0.36704108],\n", + " [0.59220934, 0.65379679, 0.06189981, ..., 0.75576884, 0.0942427 ,\n", + " 0.36704108]], shape=(100, 40))\n", + "Coordinates:\n", + " * x (x) float64 800B 0.0 0.5 1.0 1.5 2.0 ... 47.5 48.0 48.5 49.0 49.5\n", + " * y (y) float64 320B 0.0 1.0 2.0 3.0 4.0 ... 35.0 36.0 37.0 38.0 39.0" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "map_dataset = MapDataset(bgen)\n", + "reconstructed_da = predict_on_array(\n", + " dataset=map_dataset,\n", + " model=model,\n", + " output_tensor_dim={\"x\": 20, \"y\": 10}, # The model doubles the x-dimension\n", + " new_dim=[],\n", + " core_dim=[],\n", + " resample_dim=[\"x\", \"y\"],\n", + " batch_size=4\n", + ")\n", + "reconstructed_da" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The reconstructed `DataArray` has the upsampled `x` dimension. We can compare its shape to the original." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original shape: (50, 40)\n", + "Reconstructed shape: (100, 40)\n" + ] + } + ], + "source": [ + "print(f\"Original shape: {da.shape}\")\n", + "print(f\"Reconstructed shape: {reconstructed_da.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The reconstructed array has twice the number of elements in the `x` dimension, as expected. This concludes the demonstration of reconstructing an `xarray.Dataset` from model outputs using `xbatcher`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cookbook-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}