diff --git a/environment.yml b/environment.yml index 2acbf791..27b43db1 100644 --- a/environment.yml +++ b/environment.yml @@ -11,3 +11,4 @@ dependencies: - matplotlib - pytorch-cpu - xbatcher + - pytest diff --git a/notebooks/inference-testing.ipynb b/notebooks/inference-testing.ipynb index 3784d6e4..2604509e 100644 --- a/notebooks/inference-testing.ipynb +++ b/notebooks/inference-testing.ipynb @@ -23,794 +23,220 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import xbatcher\n", "import xarray as xr\n", - "import numpy as np" + "import numpy as np\n", + "import pytest\n", + "\n", + "from functions import _get_output_array_size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Toy data" + "## Testing the array size function" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataArray (x: 100, y: 100, t: 10)> Size: 800kB\n",
-       "array([[[0.80102402, 0.31982866, 0.25465188, ..., 0.31021283,\n",
-       "         0.03268302, 0.7941252 ],\n",
-       "        [0.4382327 , 0.60519493, 0.86055512, ..., 0.91560426,\n",
-       "         0.75151732, 0.045112  ],\n",
-       "        [0.56324046, 0.16672896, 0.95892792, ..., 0.75294859,\n",
-       "         0.62661621, 0.9507573 ],\n",
-       "        ...,\n",
-       "        [0.8942051 , 0.67152092, 0.5936959 , ..., 0.88955481,\n",
-       "         0.00104593, 0.25334781],\n",
-       "        [0.20207265, 0.89166716, 0.11930909, ..., 0.44454162,\n",
-       "         0.16769459, 0.13463857],\n",
-       "        [0.52366905, 0.77434724, 0.18922243, ..., 0.60982095,\n",
-       "         0.81261007, 0.35591865]],\n",
-       "\n",
-       "       [[0.36699177, 0.29720945, 0.47093257, ..., 0.08379803,\n",
-       "         0.52233317, 0.62779805],\n",
-       "        [0.78408776, 0.29027363, 0.06707094, ..., 0.83199475,\n",
-       "         0.26807477, 0.93323152],\n",
-       "        [0.70058356, 0.83639276, 0.82536327, ..., 0.94036962,\n",
-       "         0.79252184, 0.83713594],\n",
-       "...\n",
-       "        [0.39766555, 0.35994714, 0.50271116, ..., 0.46796155,\n",
-       "         0.97140383, 0.51062791],\n",
-       "        [0.70055487, 0.81950773, 0.76781762, ..., 0.05735074,\n",
-       "         0.73039909, 0.17101624],\n",
-       "        [0.8547392 , 0.21515083, 0.40187337, ..., 0.4583594 ,\n",
-       "         0.53213994, 0.15382577]],\n",
-       "\n",
-       "       [[0.6465284 , 0.88463142, 0.33543765, ..., 0.49702278,\n",
-       "         0.40990036, 0.42300881],\n",
-       "        [0.20363819, 0.02142565, 0.66128113, ..., 0.56253861,\n",
-       "         0.33799959, 0.16641441],\n",
-       "        [0.67934574, 0.38965227, 0.68431764, ..., 0.56820253,\n",
-       "         0.18640022, 0.111049  ],\n",
-       "        ...,\n",
-       "        [0.91296116, 0.6995538 , 0.50325961, ..., 0.40427649,\n",
-       "         0.14176453, 0.26158606],\n",
-       "        [0.5348746 , 0.67937558, 0.22366613, ..., 0.48889003,\n",
-       "         0.50030053, 0.1855764 ],\n",
-       "        [0.8714776 , 0.9552773 , 0.06885801, ..., 0.95366137,\n",
-       "         0.73931089, 0.90909682]]], shape=(100, 100, 10))\n",
-       "Dimensions without coordinates: x, y, t
" - ], - "text/plain": [ - " Size: 800kB\n", - "array([[[0.80102402, 0.31982866, 0.25465188, ..., 0.31021283,\n", - " 0.03268302, 0.7941252 ],\n", - " [0.4382327 , 0.60519493, 0.86055512, ..., 0.91560426,\n", - " 0.75151732, 0.045112 ],\n", - " [0.56324046, 0.16672896, 0.95892792, ..., 0.75294859,\n", - " 0.62661621, 0.9507573 ],\n", - " ...,\n", - " [0.8942051 , 0.67152092, 0.5936959 , ..., 0.88955481,\n", - " 0.00104593, 0.25334781],\n", - " [0.20207265, 0.89166716, 0.11930909, ..., 0.44454162,\n", - " 0.16769459, 0.13463857],\n", - " [0.52366905, 0.77434724, 0.18922243, ..., 0.60982095,\n", - " 0.81261007, 0.35591865]],\n", - "\n", - " [[0.36699177, 0.29720945, 0.47093257, ..., 0.08379803,\n", - " 0.52233317, 0.62779805],\n", - " [0.78408776, 0.29027363, 0.06707094, ..., 0.83199475,\n", - " 0.26807477, 0.93323152],\n", - " [0.70058356, 0.83639276, 0.82536327, ..., 0.94036962,\n", - " 0.79252184, 0.83713594],\n", - "...\n", - " [0.39766555, 0.35994714, 0.50271116, ..., 0.46796155,\n", - " 0.97140383, 0.51062791],\n", - " [0.70055487, 0.81950773, 0.76781762, ..., 0.05735074,\n", - " 0.73039909, 0.17101624],\n", - " [0.8547392 , 0.21515083, 0.40187337, ..., 0.4583594 ,\n", - " 0.53213994, 0.15382577]],\n", - "\n", - " [[0.6465284 , 0.88463142, 0.33543765, ..., 0.49702278,\n", - " 0.40990036, 0.42300881],\n", - " [0.20363819, 0.02142565, 0.66128113, ..., 0.56253861,\n", - " 0.33799959, 0.16641441],\n", - " [0.67934574, 0.38965227, 0.68431764, ..., 0.56820253,\n", - " 0.18640022, 0.111049 ],\n", - " ...,\n", - " [0.91296116, 0.6995538 , 0.50325961, ..., 0.40427649,\n", - " 0.14176453, 0.26158606],\n", - " [0.5348746 , 0.67937558, 0.22366613, ..., 0.48889003,\n", - " 0.50030053, 0.1855764 ],\n", - " [0.8714776 , 0.9552773 , 0.06885801, ..., 0.95366137,\n", - " 0.73931089, 0.90909682]]], shape=(100, 100, 10))\n", - "Dimensions without coordinates: x, y, t" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting test_get_array_size.py\n" + ] } ], "source": [ - "data = xr.DataArray(\n", - " data=np.random.rand(100, 100, 10),\n", - " dims=(\"x\", \"y\", \"t\")\n", - ")\n", - "data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Simple model" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "class MeanAlongDim(torch.nn.Module):\n", - " def __init__(self, ax):\n", - " super(MeanAlongDim, self).__init__()\n", - " self.ax = ax\n", + "%%writefile test_get_array_size.py\n", + "import torch\n", + "import xbatcher\n", + "import xarray as xr\n", + "import numpy as np\n", + "import pytest\n", "\n", - " def forward(self, x):\n", - " return torch.mean(x, self.ax)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Batch generator, dataset" + "from functions import _get_output_array_size" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Input shape: torch.Size([10, 10, 10])\n", - "Output shape: torch.Size([10, 10])\n" + "Appending to test_get_array_size.py\n" ] } ], "source": [ - "from xbatcher.loaders.torch import MapDataset\n", + "%%writefile -a test_get_array_size.py\n", "\n", - "bgen = xbatcher.BatchGenerator(\n", - " data,\n", - " input_dims=dict(x=10, y=10),\n", - " input_overlap=dict(x=5, y=5),\n", - ")\n", + "@pytest.fixture\n", + "def bgen_fixture() -> xbatcher.BatchGenerator:\n", + " data = xr.DataArray(\n", + " data=np.random.rand(100, 100, 10),\n", + " dims=(\"x\", \"y\", \"t\"),\n", + " coords={\n", + " \"x\": np.arange(100),\n", + " \"y\": np.arange(100),\n", + " \"t\": np.arange(10),\n", + " }\n", + " )\n", + " \n", + " bgen = xbatcher.BatchGenerator(\n", + " data,\n", + " input_dims=dict(x=10, y=10),\n", + " input_overlap=dict(x=5, y=5),\n", + " )\n", + " return bgen\n", "\n", - "ds = MapDataset(bgen)\n", - "\n", - "inp = next(iter(ds))\n", - "\n", - "# Check the input/output size of the first example\n", - "print(\"Input shape:\", inp.shape)\n", - "\n", - "mad = MeanAlongDim(-1)\n", - "print(\"Output shape:\", mad(inp).shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "assert torch.allclose(mad(inp), torch.mean(inp, -1))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Inference function" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "%run ./functions.ipynb" + "@pytest.mark.parametrize(\n", + " \"case_description, output_tensor_dim, new_dim, resample_dim, expected_output\",\n", + " [\n", + " (\n", + " \"Resampling only: Downsample x, Upsample y\",\n", + " {'x': 5, 'y': 20}, \n", + " [],\n", + " ['x', 'y'],\n", + " {'x': 50, 'y': 200} \n", + " ),\n", + " (\n", + " \"New dimensions only: Add a 'channel' dimension\",\n", + " {'channel': 3},\n", + " ['channel'],\n", + " [],\n", + " {'channel': 3}\n", + " ),\n", + " (\n", + " \"Mixed: Resample x and add new channel dimension\",\n", + " {'x': 30, 'channel': 12}, \n", + " ['channel'],\n", + " ['x'],\n", + " {'x': 300, 'channel': 12} \n", + " ),\n", + " (\n", + " \"Identity resampling (ratio=1)\",\n", + " {'x': 10, 'y': 10},\n", + " [],\n", + " ['x', 'y'],\n", + " {'x': 100, 'y': 100} \n", + " ),\n", + " (\n", + " \"Dimension not in batcher is treated as new\",\n", + " {'t': 5},\n", + " ['t'],\n", + " [],\n", + " {'t': 5}\n", + " )\n", + " \n", + " ]\n", + ")\n", + "def test_get_output_array_size_scenarios(\n", + " bgen_fixture, # The fixture is passed as an argument\n", + " case_description,\n", + " output_tensor_dim,\n", + " new_dim,\n", + " resample_dim,\n", + " expected_output\n", + "):\n", + " \"\"\"\n", + " Tests various valid scenarios for calculating the output array size.\n", + " The `case_description` parameter is not used in the code but helps make\n", + " test results more readable.\n", + " \"\"\"\n", + " # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture\n", + " result = _get_output_array_size(\n", + " bgen=bgen_fixture,\n", + " output_tensor_dim=output_tensor_dim,\n", + " new_dim=new_dim,\n", + " resample_dim=resample_dim\n", + " )\n", + " \n", + " assert result == expected_output, f\"Failed on case: {case_description}\"" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'y': 100, 'x': 100, 't': 5}\n" + "Appending to test_get_array_size.py\n" ] } ], "source": [ - "out_size_dict = _get_output_array_size(\n", - " bgen = ds.X_generator,\n", - " output_tensor_dim = dict(y=10, x=10, t=5),\n", - " new_dim = [\"t\"],\n", - " resample_dim = [\"y\", \"x\", \"t\"]\n", - ")\n", - "print(out_size_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "get_array_size_partial = partial(_get_output_array_size, bgen=ds.X_generator)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "test_cases = [\n", - " {\n", - " \"name\": \"Same dims and same input sizes 1\",\n", - " \"function_inputs\": {\n", - " \"output_tensor_dim\": dict(y=10, x=10),\n", - " \"new_dim\": [],\n", - " \"resample_dim\": [\"x\", \"y\"]\n", - " },\n", - " \"expected_output\": dict(y=100, x=100)\n", - " },\n", - " {\n", - " \"name\": \"New dim added\",\n", - " \"function_inputs\": {\n", - " \"output_tensor_dim\": dict(y=10, x=10, t=5),\n", - " \"new_dim\": [\"t\"],\n", - " \"resample_dim\": [\"y\", \"x\"]\n", - " },\n", - " \"expected_output\": dict(y=100, x=100, t=5)\n", - " }\n", - "] " + "%%writefile -a test_get_array_size.py\n", + "\n", + "def test_get_output_array_size_raises_assertion_error_on_non_integer_size():\n", + " \"\"\"\n", + " Tests that the function raises an AssertionError when the resampling\n", + " calculation results in a non-integer output dimension size.\n", + " \"\"\"\n", + " # DataArray size for 'x' is 101.\n", + " data_for_error = xr.DataArray(\n", + " data=np.random.rand(101, 100, 10),\n", + " dims=(\"x\", \"y\", \"t\")\n", + " )\n", + " \n", + " bgen = xbatcher.BatchGenerator(data_for_error, input_dims={'x': 10})\n", + " \n", + " # The resampling logic will be: 101 * (5 / 10) = 50.5, which is not an integer.\n", + " output_tensor_dim = {'x': 5}\n", + " \n", + " with pytest.raises(AssertionError):\n", + " _get_output_array_size(\n", + " bgen=bgen,\n", + " output_tensor_dim=output_tensor_dim,\n", + " new_dim=[],\n", + " resample_dim=['x']\n", + " )" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Test case 0 passed\n", - "Test case 1 passed\n" + "\u001b[1m============================= test session starts ==============================\u001b[0m\n", + "platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0 -- /srv/conda/envs/notebook/bin/python3.12\n", + "cachedir: .pytest_cache\n", + "rootdir: /home/jovyan/xbatcher-deep-learning/notebooks\n", + "plugins: anyio-4.9.0, hydra-core-1.3.2, jaxtyping-0.3.2\n", + "collected 6 items \u001b[0m\u001b[1m\n", + "\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[Resampling only: Downsample x, Upsample y-output_tensor_dim0-new_dim0-resample_dim0-expected_output0] \u001b[32mPASSED\u001b[0m\u001b[32m [ 16%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[New dimensions only: Add a 'channel' dimension-output_tensor_dim1-new_dim1-resample_dim1-expected_output1] \u001b[32mPASSED\u001b[0m\u001b[32m [ 33%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[Mixed: Resample x and add new channel dimension-output_tensor_dim2-new_dim2-resample_dim2-expected_output2] \u001b[32mPASSED\u001b[0m\u001b[32m [ 50%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[Identity resampling (ratio=1)-output_tensor_dim3-new_dim3-resample_dim3-expected_output3] \u001b[32mPASSED\u001b[0m\u001b[32m [ 66%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_scenarios[Dimension not in batcher is treated as new-output_tensor_dim4-new_dim4-resample_dim4-expected_output4] \u001b[32mPASSED\u001b[0m\u001b[32m [ 83%]\u001b[0m\n", + "test_get_array_size.py::test_get_output_array_size_raises_assertion_error_on_non_integer_size \u001b[32mPASSED\u001b[0m\u001b[32m [100%]\u001b[0m\n", + "\n", + "\u001b[32m============================== \u001b[32m\u001b[1m6 passed\u001b[0m\u001b[32m in 2.52s\u001b[0m\u001b[32m ===============================\u001b[0m\n" ] } ], "source": [ - "for i, test_case in enumerate(test_cases):\n", - " true_output = get_array_size_partial(**test_case[\"function_inputs\"])\n", - " success = true_output == test_case[\"expected_output\"]\n", - " message = \"passed\" if success else \"failed\"\n", - " print(f\"Test case {i} {message}\")" + "!pytest -v" ] }, { @@ -837,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.11" + "version": "3.10.16" }, "nbdime-conflicts": { "local_diff": [ diff --git a/notebooks/test_get_array_size.py b/notebooks/test_get_array_size.py new file mode 100644 index 00000000..e3e42aa9 --- /dev/null +++ b/notebooks/test_get_array_size.py @@ -0,0 +1,114 @@ +import torch +import xbatcher +import xarray as xr +import numpy as np +import pytest + +from functions import _get_output_array_size + +@pytest.fixture +def bgen_fixture() -> xbatcher.BatchGenerator: + data = xr.DataArray( + data=np.random.rand(100, 100, 10), + dims=("x", "y", "t"), + coords={ + "x": np.arange(100), + "y": np.arange(100), + "t": np.arange(10), + } + ) + + bgen = xbatcher.BatchGenerator( + data, + input_dims=dict(x=10, y=10), + input_overlap=dict(x=5, y=5), + ) + return bgen + +@pytest.mark.parametrize( + "case_description, output_tensor_dim, new_dim, resample_dim, expected_output", + [ + ( + "Resampling only: Downsample x, Upsample y", + {'x': 5, 'y': 20}, + [], + ['x', 'y'], + {'x': 50, 'y': 200} + ), + ( + "New dimensions only: Add a 'channel' dimension", + {'channel': 3}, + ['channel'], + [], + {'channel': 3} + ), + ( + "Mixed: Resample x and add new channel dimension", + {'x': 30, 'channel': 12}, + ['channel'], + ['x'], + {'x': 300, 'channel': 12} + ), + ( + "Identity resampling (ratio=1)", + {'x': 10, 'y': 10}, + [], + ['x', 'y'], + {'x': 100, 'y': 100} + ), + ( + "Dimension not in batcher is treated as new", + {'t': 5}, + ['t'], + [], + {'t': 5} + ) + + ] +) +def test_get_output_array_size_scenarios( + bgen_fixture, # The fixture is passed as an argument + case_description, + output_tensor_dim, + new_dim, + resample_dim, + expected_output +): + """ + Tests various valid scenarios for calculating the output array size. + The `case_description` parameter is not used in the code but helps make + test results more readable. + """ + # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture + result = _get_output_array_size( + bgen=bgen_fixture, + output_tensor_dim=output_tensor_dim, + new_dim=new_dim, + resample_dim=resample_dim + ) + + assert result == expected_output, f"Failed on case: {case_description}" + +def test_get_output_array_size_raises_assertion_error_on_non_integer_size(): + """ + Tests that the function raises an AssertionError when the resampling + calculation results in a non-integer output dimension size. + """ + # DataArray size for 'x' is 101. + data_for_error = xr.DataArray( + data=np.random.rand(101, 100, 10), + dims=("x", "y", "t") + ) + + bgen = xbatcher.BatchGenerator(data_for_error, input_dims={'x': 10}) + + # The resampling logic will be: 101 * (5 / 10) = 50.5, which is not an integer. + output_tensor_dim = {'x': 5} + + with pytest.raises(AssertionError): + _get_output_array_size( + bgen=bgen, + output_tensor_dim=output_tensor_dim, + new_dim=[], + resample_dim=['x'] + )