diff --git a/environment.yml b/environment.yml index 2acbf791..27b43db1 100644 --- a/environment.yml +++ b/environment.yml @@ -11,3 +11,4 @@ dependencies: - matplotlib - pytorch-cpu - xbatcher + - pytest diff --git a/notebooks/inference-testing.ipynb b/notebooks/inference-testing.ipynb index 3784d6e4..2604509e 100644 --- a/notebooks/inference-testing.ipynb +++ b/notebooks/inference-testing.ipynb @@ -23,794 +23,220 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import xbatcher\n", "import xarray as xr\n", - "import numpy as np" + "import numpy as np\n", + "import pytest\n", + "\n", + "from functions import _get_output_array_size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Toy data" + "## Testing the array size function" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
<xarray.DataArray (x: 100, y: 100, t: 10)> Size: 800kB\n", - "array([[[0.80102402, 0.31982866, 0.25465188, ..., 0.31021283,\n", - " 0.03268302, 0.7941252 ],\n", - " [0.4382327 , 0.60519493, 0.86055512, ..., 0.91560426,\n", - " 0.75151732, 0.045112 ],\n", - " [0.56324046, 0.16672896, 0.95892792, ..., 0.75294859,\n", - " 0.62661621, 0.9507573 ],\n", - " ...,\n", - " [0.8942051 , 0.67152092, 0.5936959 , ..., 0.88955481,\n", - " 0.00104593, 0.25334781],\n", - " [0.20207265, 0.89166716, 0.11930909, ..., 0.44454162,\n", - " 0.16769459, 0.13463857],\n", - " [0.52366905, 0.77434724, 0.18922243, ..., 0.60982095,\n", - " 0.81261007, 0.35591865]],\n", - "\n", - " [[0.36699177, 0.29720945, 0.47093257, ..., 0.08379803,\n", - " 0.52233317, 0.62779805],\n", - " [0.78408776, 0.29027363, 0.06707094, ..., 0.83199475,\n", - " 0.26807477, 0.93323152],\n", - " [0.70058356, 0.83639276, 0.82536327, ..., 0.94036962,\n", - " 0.79252184, 0.83713594],\n", - "...\n", - " [0.39766555, 0.35994714, 0.50271116, ..., 0.46796155,\n", - " 0.97140383, 0.51062791],\n", - " [0.70055487, 0.81950773, 0.76781762, ..., 0.05735074,\n", - " 0.73039909, 0.17101624],\n", - " [0.8547392 , 0.21515083, 0.40187337, ..., 0.4583594 ,\n", - " 0.53213994, 0.15382577]],\n", - "\n", - " [[0.6465284 , 0.88463142, 0.33543765, ..., 0.49702278,\n", - " 0.40990036, 0.42300881],\n", - " [0.20363819, 0.02142565, 0.66128113, ..., 0.56253861,\n", - " 0.33799959, 0.16641441],\n", - " [0.67934574, 0.38965227, 0.68431764, ..., 0.56820253,\n", - " 0.18640022, 0.111049 ],\n", - " ...,\n", - " [0.91296116, 0.6995538 , 0.50325961, ..., 0.40427649,\n", - " 0.14176453, 0.26158606],\n", - " [0.5348746 , 0.67937558, 0.22366613, ..., 0.48889003,\n", - " 0.50030053, 0.1855764 ],\n", - " [0.8714776 , 0.9552773 , 0.06885801, ..., 0.95366137,\n", - " 0.73931089, 0.90909682]]], shape=(100, 100, 10))\n", - "Dimensions without coordinates: x, y, t