diff --git a/environment.yml b/environment.yml index 2acbf791..27b43db1 100644 --- a/environment.yml +++ b/environment.yml @@ -11,3 +11,4 @@ dependencies: - matplotlib - pytorch-cpu - xbatcher + - pytest diff --git a/notebooks/functions.py b/notebooks/functions.py index b7f01c35..c5704fb4 100644 --- a/notebooks/functions.py +++ b/notebooks/functions.py @@ -14,8 +14,10 @@ def _get_resample_factor( resample_factor = {} for dim in resample_dim: r = output_tensor_dim[dim] / bgen.input_dims[dim] - assert r.is_integer() or (r ** -1).is_integer() - resample_factor[dim] = output_tensor_dim[dim] / bgen.input_dims[dim] + is_int = (r == int(r)) + is_inv_int = (1/r == int(1/r)) if r != 0 else False + assert is_int or is_inv_int, f"Resample ratio for dim '{dim}' must be an integer or its inverse." + resample_factor[dim] = r return resample_factor @@ -39,8 +41,8 @@ def _get_output_array_size( # determined by the source array if output_tensor_dim[key] != bgen.ds.sizes[key]: raise ValueError( - f"Axis {key} is a core dim, but the tensor size " - f"({output_tensor_dim[key]}) does not equal the" + f"Axis {key} is a core dim, but the tensor size" + f"({output_tensor_dim[key]}) does not equal the " f"source data array size ({bgen.ds.sizes[key]})." ) output_size[key] = bgen.ds.sizes[key] @@ -48,10 +50,10 @@ def _get_output_array_size( # This is a resampled axis, determine the new size # by the resample factor. temp_output_size = bgen.ds.sizes[key] * resample_factor[key] - assert temp_output_size.is_integer() + assert temp_output_size.is_integer(), f"Resampling for dim '{key}' results in non-integer size." output_size[key] = int(temp_output_size) else: - raise ValueError(f"Axis {dim} must be specified in one of new_dim, core_dim, or resample_dim") + raise ValueError(f"Axis {key} must be specified in one of new_dim, core_dim, or resample_dim") return output_size diff --git a/notebooks/inference-testing.ipynb b/notebooks/inference-testing.ipynb index 3784d6e4..3e72ec22 100644 --- a/notebooks/inference-testing.ipynb +++ b/notebooks/inference-testing.ipynb @@ -23,794 +23,632 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import xbatcher\n", "import xarray as xr\n", - "import numpy as np" + "import numpy as np\n", + "import pytest\n", + "\n", + "from functions import _get_output_array_size, predict_on_array" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Toy data" + "## Testing the array size function" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 24, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataArray (x: 100, y: 100, t: 10)> Size: 800kB\n",
-       "array([[[0.80102402, 0.31982866, 0.25465188, ..., 0.31021283,\n",
-       "         0.03268302, 0.7941252 ],\n",
-       "        [0.4382327 , 0.60519493, 0.86055512, ..., 0.91560426,\n",
-       "         0.75151732, 0.045112  ],\n",
-       "        [0.56324046, 0.16672896, 0.95892792, ..., 0.75294859,\n",
-       "         0.62661621, 0.9507573 ],\n",
-       "        ...,\n",
-       "        [0.8942051 , 0.67152092, 0.5936959 , ..., 0.88955481,\n",
-       "         0.00104593, 0.25334781],\n",
-       "        [0.20207265, 0.89166716, 0.11930909, ..., 0.44454162,\n",
-       "         0.16769459, 0.13463857],\n",
-       "        [0.52366905, 0.77434724, 0.18922243, ..., 0.60982095,\n",
-       "         0.81261007, 0.35591865]],\n",
-       "\n",
-       "       [[0.36699177, 0.29720945, 0.47093257, ..., 0.08379803,\n",
-       "         0.52233317, 0.62779805],\n",
-       "        [0.78408776, 0.29027363, 0.06707094, ..., 0.83199475,\n",
-       "         0.26807477, 0.93323152],\n",
-       "        [0.70058356, 0.83639276, 0.82536327, ..., 0.94036962,\n",
-       "         0.79252184, 0.83713594],\n",
-       "...\n",
-       "        [0.39766555, 0.35994714, 0.50271116, ..., 0.46796155,\n",
-       "         0.97140383, 0.51062791],\n",
-       "        [0.70055487, 0.81950773, 0.76781762, ..., 0.05735074,\n",
-       "         0.73039909, 0.17101624],\n",
-       "        [0.8547392 , 0.21515083, 0.40187337, ..., 0.4583594 ,\n",
-       "         0.53213994, 0.15382577]],\n",
-       "\n",
-       "       [[0.6465284 , 0.88463142, 0.33543765, ..., 0.49702278,\n",
-       "         0.40990036, 0.42300881],\n",
-       "        [0.20363819, 0.02142565, 0.66128113, ..., 0.56253861,\n",
-       "         0.33799959, 0.16641441],\n",
-       "        [0.67934574, 0.38965227, 0.68431764, ..., 0.56820253,\n",
-       "         0.18640022, 0.111049  ],\n",
-       "        ...,\n",
-       "        [0.91296116, 0.6995538 , 0.50325961, ..., 0.40427649,\n",
-       "         0.14176453, 0.26158606],\n",
-       "        [0.5348746 , 0.67937558, 0.22366613, ..., 0.48889003,\n",
-       "         0.50030053, 0.1855764 ],\n",
-       "        [0.8714776 , 0.9552773 , 0.06885801, ..., 0.95366137,\n",
-       "         0.73931089, 0.90909682]]], shape=(100, 100, 10))\n",
-       "Dimensions without coordinates: x, y, t
" - ], - "text/plain": [ - " Size: 800kB\n", - "array([[[0.80102402, 0.31982866, 0.25465188, ..., 0.31021283,\n", - " 0.03268302, 0.7941252 ],\n", - " [0.4382327 , 0.60519493, 0.86055512, ..., 0.91560426,\n", - " 0.75151732, 0.045112 ],\n", - " [0.56324046, 0.16672896, 0.95892792, ..., 0.75294859,\n", - " 0.62661621, 0.9507573 ],\n", - " ...,\n", - " [0.8942051 , 0.67152092, 0.5936959 , ..., 0.88955481,\n", - " 0.00104593, 0.25334781],\n", - " [0.20207265, 0.89166716, 0.11930909, ..., 0.44454162,\n", - " 0.16769459, 0.13463857],\n", - " [0.52366905, 0.77434724, 0.18922243, ..., 0.60982095,\n", - " 0.81261007, 0.35591865]],\n", - "\n", - " [[0.36699177, 0.29720945, 0.47093257, ..., 0.08379803,\n", - " 0.52233317, 0.62779805],\n", - " [0.78408776, 0.29027363, 0.06707094, ..., 0.83199475,\n", - " 0.26807477, 0.93323152],\n", - " [0.70058356, 0.83639276, 0.82536327, ..., 0.94036962,\n", - " 0.79252184, 0.83713594],\n", - "...\n", - " [0.39766555, 0.35994714, 0.50271116, ..., 0.46796155,\n", - " 0.97140383, 0.51062791],\n", - " [0.70055487, 0.81950773, 0.76781762, ..., 0.05735074,\n", - " 0.73039909, 0.17101624],\n", - " [0.8547392 , 0.21515083, 0.40187337, ..., 0.4583594 ,\n", - " 0.53213994, 0.15382577]],\n", - "\n", - " [[0.6465284 , 0.88463142, 0.33543765, ..., 0.49702278,\n", - " 0.40990036, 0.42300881],\n", - " [0.20363819, 0.02142565, 0.66128113, ..., 0.56253861,\n", - " 0.33799959, 0.16641441],\n", - " [0.67934574, 0.38965227, 0.68431764, ..., 0.56820253,\n", - " 0.18640022, 0.111049 ],\n", - " ...,\n", - " [0.91296116, 0.6995538 , 0.50325961, ..., 0.40427649,\n", - " 0.14176453, 0.26158606],\n", - " [0.5348746 , 0.67937558, 0.22366613, ..., 0.48889003,\n", - " 0.50030053, 0.1855764 ],\n", - " [0.8714776 , 0.9552773 , 0.06885801, ..., 0.95366137,\n", - " 0.73931089, 0.90909682]]], shape=(100, 100, 10))\n", - "Dimensions without coordinates: x, y, t" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting test_get_array_size.py\n" + ] } ], "source": [ - "data = xr.DataArray(\n", - " data=np.random.rand(100, 100, 10),\n", - " dims=(\"x\", \"y\", \"t\")\n", - ")\n", - "data" + "%%writefile test_get_array_size.py\n", + "import torch\n", + "import xbatcher\n", + "import xarray as xr\n", + "import numpy as np\n", + "import pytest\n", + "\n", + "from functions import _get_output_array_size, _get_resample_factor" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 28, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Appending to test_get_array_size.py\n" + ] + } + ], "source": [ - "## Simple model" + "%%writefile -a test_get_array_size.py\n", + "\n", + "@pytest.fixture\n", + "def bgen_fixture() -> xbatcher.BatchGenerator:\n", + " data = xr.DataArray(\n", + " data=np.random.rand(100, 100, 10),\n", + " dims=(\"x\", \"y\", \"t\"),\n", + " coords={\n", + " \"x\": np.arange(100),\n", + " \"y\": np.arange(100),\n", + " \"t\": np.arange(10),\n", + " }\n", + " )\n", + " \n", + " bgen = xbatcher.BatchGenerator(\n", + " data,\n", + " input_dims=dict(x=10, y=10),\n", + " input_overlap=dict(x=5, y=5),\n", + " )\n", + " return bgen\n", + "\n", + "@pytest.mark.parametrize(\n", + " \"case_description, output_tensor_dim, new_dim, core_dim, resample_dim, expected_output\",\n", + " [\n", + " (\n", + " \"Resampling only: Downsample x, Upsample y\",\n", + " {'x': 5, 'y': 20}, \n", + " [],\n", + " [],\n", + " ['x', 'y'],\n", + " {'x': 50, 'y': 200} \n", + " ),\n", + " (\n", + " \"New dimensions only: Add a 'channel' dimension\",\n", + " {'channel': 3},\n", + " ['channel'],\n", + " [],\n", + " [],\n", + " {'channel': 3}\n", + " ),\n", + " (\n", + " \"Mixed: Resample x, add new channel dimension and keep t as core\",\n", + " {'x': 30, 'channel': 12}, \n", + " ['channel'],\n", + " ['t'],\n", + " ['x'],\n", + " {'x': 300, 'channel': 12} \n", + " ),\n", + " (\n", + " \"Identity resampling (ratio=1)\",\n", + " {'x': 10, 'y': 10},\n", + " [],\n", + " [],\n", + " ['x', 'y'],\n", + " {'x': 100, 'y': 100} \n", + " ),\n", + " (\n", + " \"Core dims only: 't' is a core dim\",\n", + " {'t': 10},\n", + " [], \n", + " ['t'], \n", + " [],\n", + " {'t': 10}\n", + " ),\n", + " ]\n", + ")\n", + "def test_get_output_array_size_scenarios(\n", + " bgen_fixture, # The fixture is passed as an argument\n", + " case_description,\n", + " output_tensor_dim,\n", + " new_dim,\n", + " core_dim,\n", + " resample_dim,\n", + " expected_output\n", + "):\n", + " \"\"\"\n", + " Tests various valid scenarios for calculating the output array size.\n", + " The `case_description` parameter is not used in the code but helps make\n", + " test results more readable.\n", + " \"\"\"\n", + " # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture\n", + " result = _get_output_array_size(\n", + " bgen=bgen_fixture,\n", + " output_tensor_dim=output_tensor_dim,\n", + " new_dim=new_dim,\n", + " core_dim=core_dim,\n", + " resample_dim=resample_dim\n", + " )\n", + " \n", + " assert result == expected_output, f\"Failed on case: {case_description}\"" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 29, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Appending to test_get_array_size.py\n" + ] + } + ], "source": [ - "class MeanAlongDim(torch.nn.Module):\n", - " def __init__(self, ax):\n", - " super(MeanAlongDim, self).__init__()\n", - " self.ax = ax\n", + "%%writefile -a test_get_array_size.py\n", + "\n", + "def test_get_output_array_size_raises_error_on_mismatched_core_dim(bgen_fixture):\n", + " \"\"\"Tests ValueError when a core_dim size doesn't match the source.\"\"\"\n", + " with pytest.raises(ValueError, match=\"does not equal the source data array size\"):\n", + " _get_output_array_size(\n", + " bgen_fixture, output_tensor_dim={'t': 99}, new_dim=[], core_dim=['t'], resample_dim=[]\n", + " )\n", "\n", - " def forward(self, x):\n", - " return torch.mean(x, self.ax)" + "def test_get_output_array_size_raises_error_on_unspecified_dim(bgen_fixture):\n", + " \"\"\"Tests ValueError when a dimension is not specified in any category.\"\"\"\n", + " with pytest.raises(ValueError, match=\"must be specified in one of\"):\n", + " _get_output_array_size(\n", + " bgen_fixture, output_tensor_dim={'x': 10}, new_dim=[], core_dim=[], resample_dim=[]\n", + " )\n", + "\n", + "def test_get_resample_factor_raises_error_on_invalid_ratio(bgen_fixture):\n", + " \"\"\"Tests AssertionError when the resample ratio is not an integer or its inverse.\"\"\"\n", + " with pytest.raises(AssertionError, match=\"must be an integer or its inverse\"):\n", + " # 15 / 10 = 1.5, which is not a valid ratio\n", + " _get_resample_factor(bgen_fixture, output_tensor_dim={'x': 15}, resample_dim=['x'])" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m============================= test session starts ==============================\u001b[0m\n", + "platform darwin -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0 -- /Users/nkalauni/miniconda3/envs/cookbook-dev/bin/python3.13\n", + "cachedir: .pytest_cache\n", + "rootdir: /Users/nkalauni/Documents/Cline/xbatcher-deep-learning/notebooks\n", + "plugins: anyio-4.10.0\n", + "collected 8 items \u001b[0m\u001b[1m\n", + "\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[Resampling only: Downsample x, Upsample y-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-expected_output0] \u001b[32mPASSED\u001b[0m\u001b[32m [ 12%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[New dimensions only: Add a 'channel' dimension-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-expected_output1] \u001b[32mPASSED\u001b[0m\u001b[32m [ 25%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[Mixed: Resample x, add new channel dimension and keep t as core-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-expected_output2] \u001b[32mPASSED\u001b[0m\u001b[32m [ 37%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[Identity resampling (ratio=1)-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-expected_output3] \u001b[32mPASSED\u001b[0m\u001b[32m [ 50%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[Core dims only: 't' is a core dim-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-expected_output4] \u001b[32mPASSED\u001b[0m\u001b[32m [ 62%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_raises_error_on_mismatched_core_dim \u001b[32mPASSED\u001b[0m\u001b[32m [ 75%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_raises_error_on_unspecified_dim \u001b[32mPASSED\u001b[0m\u001b[32m [ 87%]\u001b[0m\n", + "test_get_array_size.py::test_get_resample_factor_raises_error_on_invalid_ratio \u001b[32mPASSED\u001b[0m\u001b[32m [100%]\u001b[0m\n", + "\n", + "\u001b[32m============================== \u001b[32m\u001b[1m8 passed\u001b[0m\u001b[32m in 1.23s\u001b[0m\u001b[32m ===============================\u001b[0m\n" + ] + } + ], + "source": [ + "!pytest -v test_get_array_size.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Batch generator, dataset" + "## Testing the predict_on_array function" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Input shape: torch.Size([10, 10, 10])\n", - "Output shape: torch.Size([10, 10])\n" + "Overwriting test_predict_on_array.py\n" ] } ], "source": [ + "%%writefile test_predict_on_array.py\n", + "import xarray as xr\n", + "import numpy as np\n", + "import torch\n", + "import xbatcher\n", + "import pytest\n", "from xbatcher.loaders.torch import MapDataset\n", "\n", - "bgen = xbatcher.BatchGenerator(\n", - " data,\n", - " input_dims=dict(x=10, y=10),\n", - " input_overlap=dict(x=5, y=5),\n", - ")\n", - "\n", - "ds = MapDataset(bgen)\n", - "\n", - "inp = next(iter(ds))\n", - "\n", - "# Check the input/output size of the first example\n", - "print(\"Input shape:\", inp.shape)\n", - "\n", - "mad = MeanAlongDim(-1)\n", - "print(\"Output shape:\", mad(inp).shape)" + "from functions import _get_output_array_size, predict_on_array\n", + "from dummy_models import *" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ - "assert torch.allclose(mad(inp), torch.mean(inp, -1))" + "import xarray as xr\n", + "import numpy as np\n", + "import torch\n", + "import xbatcher\n", + "import pytest\n", + "from xbatcher.loaders.torch import MapDataset\n", + "\n", + "from functions import _get_output_array_size, predict_on_array\n", + "from dummy_models import *" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 21, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0., 1., 2., 3., 4.])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "## Inference function" + "input_tensor = torch.arange(125).reshape((5, 5, 5)).to(torch.float32)\n", + "input_tensor[0,0,:]" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 2., 7., 12., 17., 22.],\n", + " [ 27., 32., 37., 42., 47.],\n", + " [ 52., 57., 62., 67., 72.],\n", + " [ 77., 82., 87., 92., 97.],\n", + " [102., 107., 112., 117., 122.]])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "%run ./functions.ipynb" + "model = MeanAlongDim(-1)\n", + "model(input_tensor)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'y': 100, 'x': 100, 't': 5}\n" + "Appending to test_predict_on_array.py\n" ] } ], "source": [ - "out_size_dict = _get_output_array_size(\n", - " bgen = ds.X_generator,\n", - " output_tensor_dim = dict(y=10, x=10, t=5),\n", - " new_dim = [\"t\"],\n", - " resample_dim = [\"y\", \"x\", \"t\"]\n", - ")\n", - "print(out_size_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial" + "%%writefile -a test_predict_on_array.py\n", + "\n", + "@pytest.fixture\n", + "def map_dataset_fixture() -> MapDataset:\n", + " \"\"\"\n", + " Creates a MapDataset with a predictable BatchGenerator for testing.\n", + " - Data is an xarray DataArray with dimensions x=20, y=10\n", + " - Values are a simple np.arange sequence for easy verification.\n", + " - Batches are size x=10, y=5 with overlap x=2, y=2\n", + " \"\"\"\n", + " # Using a smaller, more manageable dataset for testing\n", + " data = xr.DataArray(\n", + " data=np.arange(20 * 10).reshape(20, 10),\n", + " dims=(\"x\", \"y\"),\n", + " coords={\"x\": np.arange(20), \"y\": np.arange(10)}\n", + " ).astype(float)\n", + " \n", + " bgen = xbatcher.BatchGenerator(\n", + " data,\n", + " input_dims=dict(x=10, y=5),\n", + " input_overlap=dict(x=2, y=2),\n", + " )\n", + " return MapDataset(bgen)" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "get_array_size_partial = partial(_get_output_array_size, bgen=ds.X_generator)" + " data = xr.DataArray(\n", + " data=np.arange(20 * 10).reshape(20, 10),\n", + " dims=(\"x\", \"y\"),\n", + " coords={\"x\": np.arange(20), \"y\": np.arange(10)}\n", + " ).astype(float)\n", + " \n", + " bgen = xbatcher.BatchGenerator(\n", + " data,\n", + " input_dims=dict(x=10, y=5),\n", + " input_overlap=dict(x=2, y=2),\n", + " )" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 24, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Appending to test_predict_on_array.py\n" + ] + } + ], "source": [ - "test_cases = [\n", - " {\n", - " \"name\": \"Same dims and same input sizes 1\",\n", - " \"function_inputs\": {\n", - " \"output_tensor_dim\": dict(y=10, x=10),\n", - " \"new_dim\": [],\n", - " \"resample_dim\": [\"x\", \"y\"]\n", - " },\n", - " \"expected_output\": dict(y=100, x=100)\n", - " },\n", - " {\n", - " \"name\": \"New dim added\",\n", - " \"function_inputs\": {\n", - " \"output_tensor_dim\": dict(y=10, x=10, t=5),\n", - " \"new_dim\": [\"t\"],\n", - " \"resample_dim\": [\"y\", \"x\"]\n", - " },\n", - " \"expected_output\": dict(y=100, x=100, t=5)\n", - " }\n", - "] " + "%%writefile -a test_predict_on_array.py\n", + "\n", + "@pytest.mark.parametrize(\n", + " \"model, output_tensor_dim, new_dim, resample_dim, expected_transform\",\n", + " [\n", + " # Case 1: Resampling - Downsampling with a subset model\n", + " (\n", + " SubsetAlongAxis(ax=1, n=5), # Corresponds to 'x' dim in batch\n", + " {'x': 5, 'y': 5},\n", + " [],\n", + " ['x'],\n", + " lambda da: da.isel(x=slice(0, 5)) # Expected: take first 5 elements of original 'x'\n", + " ),\n", + " # Case 2: Dimension reduction with a mean model\n", + " (\n", + " MeanAlongDim(ax=2), # Corresponds to 'y' dim in batch\n", + " {'x': 10},\n", + " [],\n", + " ['x'],\n", + " lambda da: da.mean(dim='y') # Expected: mean along original 'y'\n", + " ),\n", + " ]\n", + ")\n", + "def test_predict_on_array_reassembly(\n", + " map_dataset_fixture,\n", + " model,\n", + " output_tensor_dim,\n", + " new_dim,\n", + " resample_dim,\n", + " expected_transform\n", + "):\n", + " \"\"\"\n", + " Tests that predict_on_array correctly reassembles batches from different models.\n", + " \"\"\"\n", + " # --- Run the function under test ---\n", + " # Using a small batch_size to ensure multiple iterations\n", + " predicted_da, predicted_n = predict_on_array(\n", + " dataset=map_dataset_fixture,\n", + " model=model,\n", + " output_tensor_dim=output_tensor_dim,\n", + " new_dim=new_dim,\n", + " resample_dim=resample_dim,\n", + " batch_size=4 \n", + " )\n", + "\n", + " # --- Manually calculate the expected result ---\n", + " bgen = map_dataset_fixture.generator\n", + " # 1. Create the expected output array structure\n", + " expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, resample_dim)\n", + " expected_da = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))\n", + " expected_n = xr.full_like(expected_da, 0)\n", + "\n", + " # 2. Manually iterate through batches and apply the same logic as the function\n", + " for i in range(len(map_dataset_fixture)):\n", + " batch_da = bgen[i]\n", + " \n", + " # Apply the same transformation the model would\n", + " transformed_batch = expected_transform(batch_da)\n", + " \n", + " # Get the rescaled indexer\n", + " old_indexer = bgen.batch_selectors[i]\n", + " new_indexer = {}\n", + " for key in old_indexer:\n", + " if key in resample_dim:\n", + " resample_ratio = output_tensor_dim[key] / bgen.input_dims[key]\n", + " new_indexer[key] = slice(\n", + " int(old_indexer[key].start * resample_ratio),\n", + " int(old_indexer[key].stop * resample_ratio)\n", + " )\n", + " \n", + " # Add the result to our manually calculated array\n", + " expected_da.loc[new_indexer] += transformed_batch.values\n", + " expected_n.loc[new_indexer] += 1\n", + "\n", + " # --- Assert that the results are identical ---\n", + " # We test the raw summed output and the overlap counter array\n", + " xr.testing.assert_allclose(predicted_da, expected_da)\n", + " xr.testing.assert_allclose(predicted_n, expected_n)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Test case 0 passed\n", - "Test case 1 passed\n" + "\u001b[1m============================= test session starts ==============================\u001b[0m\n", + "platform linux -- Python 3.10.16, pytest-8.4.1, pluggy-1.6.0 -- /srv/conda/envs/notebook/bin/python3.10\n", + "cachedir: .pytest_cache\n", + "rootdir: /home/jovyan/xbatcher-deep-learning/notebooks\n", + "plugins: anyio-4.8.0\n", + "collected 2 items \u001b[0m\u001b[1m\n", + "\n", + "test_predict_on_array.py::test_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-] \u001b[31mFAILED\u001b[0m\u001b[31m [ 50%]\u001b[0m\n", + "\u001b[31mFAILED\u001b[0m\u001b[31m [100%]\u001b[0mpredict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-] \n", + "\n", + "=================================== FAILURES ===================================\n", + "\u001b[31m\u001b[1m_ test_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-] _\u001b[0m\n", + "\n", + "map_dataset_fixture = \n", + "model = SubsetAlongAxis(), output_tensor_dim = {'x': 5, 'y': 5}, new_dim = []\n", + "resample_dim = ['x'], expected_transform = at 0x7f4d4a136cb0>\n", + "\n", + " \u001b[0m\u001b[37m@pytest\u001b[39;49;00m.mark.parametrize(\u001b[90m\u001b[39;49;00m\n", + " \u001b[33m\"\u001b[39;49;00m\u001b[33mmodel, output_tensor_dim, new_dim, resample_dim, expected_transform\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m,\u001b[90m\u001b[39;49;00m\n", + " [\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# Case 1: Resampling - Downsampling with a subset model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " (\u001b[90m\u001b[39;49;00m\n", + " SubsetAlongAxis(ax=\u001b[94m1\u001b[39;49;00m, n=\u001b[94m5\u001b[39;49;00m), \u001b[90m# Corresponds to 'x' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m, \u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n", + " [],\u001b[90m\u001b[39;49;00m\n", + " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", + " \u001b[94mlambda\u001b[39;49;00m da: da.isel(x=\u001b[96mslice\u001b[39;49;00m(\u001b[94m0\u001b[39;49;00m, \u001b[94m5\u001b[39;49;00m)) \u001b[90m# Expected: take first 5 elements of original 'x'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " ),\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# Case 2: Dimension reduction with a mean model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " (\u001b[90m\u001b[39;49;00m\n", + " MeanAlongDim(ax=\u001b[94m2\u001b[39;49;00m), \u001b[90m# Corresponds to 'y' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m10\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n", + " [],\u001b[90m\u001b[39;49;00m\n", + " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", + " \u001b[94mlambda\u001b[39;49;00m da: da.mean(dim=\u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m) \u001b[90m# Expected: mean along original 'y'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " ),\u001b[90m\u001b[39;49;00m\n", + " ]\u001b[90m\u001b[39;49;00m\n", + " )\u001b[90m\u001b[39;49;00m\n", + " \u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92mtest_predict_on_array_reassembly\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n", + " map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n", + " model,\u001b[90m\u001b[39;49;00m\n", + " output_tensor_dim,\u001b[90m\u001b[39;49;00m\n", + " new_dim,\u001b[90m\u001b[39;49;00m\n", + " resample_dim,\u001b[90m\u001b[39;49;00m\n", + " expected_transform\u001b[90m\u001b[39;49;00m\n", + " ):\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m \u001b[39;49;00m\u001b[33m\"\"\"\u001b[39;49;00m\n", + " \u001b[33m Tests that predict_on_array correctly reassembles batches from different models.\u001b[39;49;00m\n", + " \u001b[33m \"\"\"\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# --- Run the function under test ---\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# Using a small batch_size to ensure multiple iterations\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + "> predicted_da, predicted_n = predict_on_array(\u001b[90m\u001b[39;49;00m\n", + " dataset=map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n", + " model=model,\u001b[90m\u001b[39;49;00m\n", + " output_tensor_dim=output_tensor_dim,\u001b[90m\u001b[39;49;00m\n", + " new_dim=new_dim,\u001b[90m\u001b[39;49;00m\n", + " resample_dim=resample_dim,\u001b[90m\u001b[39;49;00m\n", + " batch_size=\u001b[94m4\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " )\u001b[90m\u001b[39;49;00m\n", + "\n", + "\u001b[1m\u001b[31mtest_predict_on_array.py\u001b[0m:67: \n", + "_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n", + "\u001b[1m\u001b[31mfunctions.py\u001b[0m:55: in predict_on_array\n", + " \u001b[0moutput_size = _get_output_array_size(dataset.X_generator, output_tensor_dim, new_dim, resample_dim)\u001b[90m\u001b[39;49;00m\n", + "_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n", + "\n", + "bgen = \n", + "output_tensor_dim = {'x': 5, 'y': 5}, new_dim = [], resample_dim = ['x']\n", + "\n", + " \u001b[0m\u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92m_get_output_array_size\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n", + " bgen: xbatcher.BatchGenerator,\u001b[90m\u001b[39;49;00m\n", + " output_tensor_dim: \u001b[96mdict\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m, \u001b[96mint\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", + " new_dim: \u001b[96mlist\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", + " resample_dim: \u001b[96mlist\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m]\u001b[90m\u001b[39;49;00m\n", + " ):\u001b[90m\u001b[39;49;00m\n", + " resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)\u001b[90m\u001b[39;49;00m\n", + " output_size = {}\u001b[90m\u001b[39;49;00m\n", + " \u001b[94mfor\u001b[39;49;00m key, size \u001b[95min\u001b[39;49;00m output_tensor_dim.items():\u001b[90m\u001b[39;49;00m\n", + " \u001b[94mif\u001b[39;49;00m key \u001b[95min\u001b[39;49;00m new_dim:\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# This is a new axis, size is determined\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# by the tensor size.\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " output_size[key] = output_tensor_dim[key]\u001b[90m\u001b[39;49;00m\n", + " \u001b[94melse\u001b[39;49;00m:\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# This is a resampled axis, determine the new size\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# by the ratio of the batchgen window to the tensor size.\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + "> temp_output_size = bgen.ds.sizes[key] * resample_factor[key]\u001b[90m\u001b[39;49;00m\n", + "\u001b[1m\u001b[31mE KeyError: 'y'\u001b[0m\n", + "\n", + "\u001b[1m\u001b[31mfunctions.py\u001b[0m:36: KeyError\n", + "\u001b[31m\u001b[1m_ test_predict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-] _\u001b[0m\n", + "\n", + "map_dataset_fixture = \n", + "model = MeanAlongDim(), output_tensor_dim = {'x': 10}, new_dim = []\n", + "resample_dim = ['x'], expected_transform = at 0x7f4d4a1371c0>\n", + "\n", + " \u001b[0m\u001b[37m@pytest\u001b[39;49;00m.mark.parametrize(\u001b[90m\u001b[39;49;00m\n", + " \u001b[33m\"\u001b[39;49;00m\u001b[33mmodel, output_tensor_dim, new_dim, resample_dim, expected_transform\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m,\u001b[90m\u001b[39;49;00m\n", + " [\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# Case 1: Resampling - Downsampling with a subset model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " (\u001b[90m\u001b[39;49;00m\n", + " SubsetAlongAxis(ax=\u001b[94m1\u001b[39;49;00m, n=\u001b[94m5\u001b[39;49;00m), \u001b[90m# Corresponds to 'x' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m, \u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n", + " [],\u001b[90m\u001b[39;49;00m\n", + " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", + " \u001b[94mlambda\u001b[39;49;00m da: da.isel(x=\u001b[96mslice\u001b[39;49;00m(\u001b[94m0\u001b[39;49;00m, \u001b[94m5\u001b[39;49;00m)) \u001b[90m# Expected: take first 5 elements of original 'x'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " ),\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# Case 2: Dimension reduction with a mean model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " (\u001b[90m\u001b[39;49;00m\n", + " MeanAlongDim(ax=\u001b[94m2\u001b[39;49;00m), \u001b[90m# Corresponds to 'y' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m10\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n", + " [],\u001b[90m\u001b[39;49;00m\n", + " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", + " \u001b[94mlambda\u001b[39;49;00m da: da.mean(dim=\u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m) \u001b[90m# Expected: mean along original 'y'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " ),\u001b[90m\u001b[39;49;00m\n", + " ]\u001b[90m\u001b[39;49;00m\n", + " )\u001b[90m\u001b[39;49;00m\n", + " \u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92mtest_predict_on_array_reassembly\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n", + " map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n", + " model,\u001b[90m\u001b[39;49;00m\n", + " output_tensor_dim,\u001b[90m\u001b[39;49;00m\n", + " new_dim,\u001b[90m\u001b[39;49;00m\n", + " resample_dim,\u001b[90m\u001b[39;49;00m\n", + " expected_transform\u001b[90m\u001b[39;49;00m\n", + " ):\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m \u001b[39;49;00m\u001b[33m\"\"\"\u001b[39;49;00m\n", + " \u001b[33m Tests that predict_on_array correctly reassembles batches from different models.\u001b[39;49;00m\n", + " \u001b[33m \"\"\"\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# --- Run the function under test ---\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " \u001b[90m# Using a small batch_size to ensure multiple iterations\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + "> predicted_da, predicted_n = predict_on_array(\u001b[90m\u001b[39;49;00m\n", + " dataset=map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n", + " model=model,\u001b[90m\u001b[39;49;00m\n", + " output_tensor_dim=output_tensor_dim,\u001b[90m\u001b[39;49;00m\n", + " new_dim=new_dim,\u001b[90m\u001b[39;49;00m\n", + " resample_dim=resample_dim,\u001b[90m\u001b[39;49;00m\n", + " batch_size=\u001b[94m4\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", + " )\u001b[90m\u001b[39;49;00m\n", + "\u001b[1m\u001b[31mE ValueError: too many values to unpack (expected 2)\u001b[0m\n", + "\n", + "\u001b[1m\u001b[31mtest_predict_on_array.py\u001b[0m:67: ValueError\n", + "\u001b[36m\u001b[1m=========================== short test summary info ============================\u001b[0m\n", + "\u001b[31mFAILED\u001b[0m test_predict_on_array.py::\u001b[1mtest_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-]\u001b[0m - KeyError: 'y'\n", + "\u001b[31mFAILED\u001b[0m test_predict_on_array.py::\u001b[1mtest_predict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-]\u001b[0m - ValueError: too many values to unpack (expected 2)\n", + "\u001b[31m============================== \u001b[31m\u001b[1m2 failed\u001b[0m\u001b[31m in 1.98s\u001b[0m\u001b[31m ===============================\u001b[0m\n" ] } ], "source": [ - "for i, test_case in enumerate(test_cases):\n", - " true_output = get_array_size_partial(**test_case[\"function_inputs\"])\n", - " success = true_output == test_case[\"expected_output\"]\n", - " message = \"passed\" if success else \"failed\"\n", - " print(f\"Test case {i} {message}\")" + "!pytest -v test_predict_on_array.py" ] }, { @@ -823,7 +661,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "cookbook-dev", "language": "python", "name": "python3" }, @@ -837,7 +675,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.11" + "version": "3.13.5" }, "nbdime-conflicts": { "local_diff": [ diff --git a/notebooks/test_get_array_size.py b/notebooks/test_get_array_size.py new file mode 100644 index 00000000..d54aea67 --- /dev/null +++ b/notebooks/test_get_array_size.py @@ -0,0 +1,225 @@ +import torch +import xbatcher +import xarray as xr +import numpy as np +import pytest + +from functions import _get_output_array_size, _get_resample_factor + +@pytest.fixture +def bgen_fixture() -> xbatcher.BatchGenerator: + data = xr.DataArray( + data=np.random.rand(100, 100, 10), + dims=("x", "y", "t"), + coords={ + "x": np.arange(100), + "y": np.arange(100), + "t": np.arange(10), + } + ) + + bgen = xbatcher.BatchGenerator( + data, + input_dims=dict(x=10, y=10), + input_overlap=dict(x=5, y=5), + ) + return bgen + +@pytest.mark.parametrize( + "case_description, output_tensor_dim, new_dim, core_dim, resample_dim, expected_output", + [ + ( + "Resampling only: Downsample x, Upsample y", + {'x': 5, 'y': 20}, + [], + [], + ['x', 'y'], + {'x': 50, 'y': 200} + ), + ( + "New dimensions only: Add a 'channel' dimension", + {'channel': 3}, + ['channel'], + [], + [], + {'channel': 3} + ), + ( + "Mixed: Resample x, add new channel dimension and keep t as core", + {'x': 30, 'channel': 12}, + ['channel'], + ['t'], + ['x'], + {'x': 300, 'channel': 12} + ), + ( + "Identity resampling (ratio=1)", + {'x': 10, 'y': 10}, + [], + [], + ['x', 'y'], + {'x': 100, 'y': 100} + ), + ( + "Core dims only: 't' is a core dim", + {'t': 10}, + [], + ['t'], + [], + {'t': 10} + ), + ] +) +def test_get_output_array_size_scenarios( + bgen_fixture, # The fixture is passed as an argument + case_description, + output_tensor_dim, + new_dim, + core_dim, + resample_dim, + expected_output +): + """ + Tests various valid scenarios for calculating the output array size. + The `case_description` parameter is not used in the code but helps make + test results more readable. + """ + # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture + result = _get_output_array_size( + bgen=bgen_fixture, + output_tensor_dim=output_tensor_dim, + new_dim=new_dim, + core_dim=core_dim, + resample_dim=resample_dim + ) + + assert result == expected_output, f"Failed on case: {case_description}" + +def test_get_output_array_size_raises_error_on_mismatched_core_dim(bgen_fixture): + """Tests ValueError when a core_dim size doesn't match the source.""" + with pytest.raises(ValueError, match="does not equal the source data array size"): + _get_output_array_size( + bgen_fixture, output_tensor_dim={'t': 99}, new_dim=[], core_dim=['t'], resample_dim=[] + ) + +def test_get_output_array_size_raises_error_on_unspecified_dim(bgen_fixture): + """Tests ValueError when a dimension is not specified in any category.""" + with pytest.raises(ValueError, match="must be specified in one of"): + _get_output_array_size( + bgen_fixture, output_tensor_dim={'x': 10}, new_dim=[], core_dim=[], resample_dim=[] + ) + +def test_get_resample_factor_raises_error_on_invalid_ratio(bgen_fixture): + """Tests AssertionError when the resample ratio is not an integer or its inverse.""" + with pytest.raises(AssertionError, match="must be an integer or its inverse"): + # 15 / 10 = 1.5, which is not a valid ratio + _get_resample_factor(bgen_fixture, output_tensor_dim={'x': 15}, resample_dim=['x']) + +@pytest.fixture +def bgen_fixture() -> xbatcher.BatchGenerator: + data = xr.DataArray( + data=np.random.rand(100, 100, 10), + dims=("x", "y", "t"), + coords={ + "x": np.arange(100), + "y": np.arange(100), + "t": np.arange(10), + } + ) + + bgen = xbatcher.BatchGenerator( + data, + input_dims=dict(x=10, y=10), + input_overlap=dict(x=5, y=5), + ) + return bgen + +@pytest.mark.parametrize( + "case_description, output_tensor_dim, new_dim, core_dim, resample_dim, expected_output", + [ + ( + "Resampling only: Downsample x, Upsample y", + {'x': 5, 'y': 20}, + [], + [], + ['x', 'y'], + {'x': 50, 'y': 200} + ), + ( + "New dimensions only: Add a 'channel' dimension", + {'channel': 3}, + ['channel'], + [], + [], + {'channel': 3} + ), + ( + "Mixed: Resample x, add new channel dimension and keep t as core", + {'x': 30, 'channel': 12}, + ['channel'], + ['t'], + ['x'], + {'x': 300, 'channel': 12} + ), + ( + "Identity resampling (ratio=1)", + {'x': 10, 'y': 10}, + [], + [], + ['x', 'y'], + {'x': 100, 'y': 100} + ), + ( + "Core dims only: 't' is a core dim", + {'t': 10}, + [], + ['t'], + [], + {'t': 10} + ), + ] +) +def test_get_output_array_size_scenarios( + bgen_fixture, # The fixture is passed as an argument + case_description, + output_tensor_dim, + new_dim, + core_dim, + resample_dim, + expected_output +): + """ + Tests various valid scenarios for calculating the output array size. + The `case_description` parameter is not used in the code but helps make + test results more readable. + """ + # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture + result = _get_output_array_size( + bgen=bgen_fixture, + output_tensor_dim=output_tensor_dim, + new_dim=new_dim, + core_dim=core_dim, + resample_dim=resample_dim + ) + + assert result == expected_output, f"Failed on case: {case_description}" + +def test_get_output_array_size_raises_error_on_mismatched_core_dim(bgen_fixture): + """Tests ValueError when a core_dim size doesn't match the source.""" + with pytest.raises(ValueError, match="does not equal the source data array size"): + _get_output_array_size( + bgen_fixture, output_tensor_dim={'t': 99}, new_dim=[], core_dim=['t'], resample_dim=[] + ) + +def test_get_output_array_size_raises_error_on_unspecified_dim(bgen_fixture): + """Tests ValueError when a dimension is not specified in any category.""" + with pytest.raises(ValueError, match="must be specified in one of"): + _get_output_array_size( + bgen_fixture, output_tensor_dim={'x': 10}, new_dim=[], core_dim=[], resample_dim=[] + ) + +def test_get_resample_factor_raises_error_on_invalid_ratio(bgen_fixture): + """Tests AssertionError when the resample ratio is not an integer or its inverse.""" + with pytest.raises(AssertionError, match="must be an integer or its inverse"): + # 15 / 10 = 1.5, which is not a valid ratio + _get_resample_factor(bgen_fixture, output_tensor_dim={'x': 15}, resample_dim=['x']) diff --git a/notebooks/test_predict_on_array.py b/notebooks/test_predict_on_array.py new file mode 100644 index 00000000..4a41e18f --- /dev/null +++ b/notebooks/test_predict_on_array.py @@ -0,0 +1,31 @@ +import xarray as xr +import numpy as np +import torch +import xbatcher +import pytest +from xbatcher.loaders.torch import MapDataset + +from functions import _get_output_array_size, predict_on_array +from dummy_models import * + +@pytest.fixture +def map_dataset_fixture() -> MapDataset: + """ + Creates a MapDataset with a predictable BatchGenerator for testing. + - Data is an xarray DataArray with dimensions x=20, y=10 + - Values are a simple np.arange sequence for easy verification. + - Batches are size x=10, y=5 with overlap x=2, y=2 + """ + # Using a smaller, more manageable dataset for testing + data = xr.DataArray( + data=np.arange(20 * 10).reshape(20, 10), + dims=("x", "y"), + coords={"x": np.arange(20), "y": np.arange(10)} + ).astype(float) + + bgen = xbatcher.BatchGenerator( + data, + input_dims=dict(x=10, y=5), + input_overlap=dict(x=2, y=2), + ) + return MapDataset(bgen)