diff --git a/environment.yml b/environment.yml index 2acbf791..27b43db1 100644 --- a/environment.yml +++ b/environment.yml @@ -11,3 +11,4 @@ dependencies: - matplotlib - pytorch-cpu - xbatcher + - pytest diff --git a/notebooks/functions.py b/notebooks/functions.py index b7f01c35..c5704fb4 100644 --- a/notebooks/functions.py +++ b/notebooks/functions.py @@ -14,8 +14,10 @@ def _get_resample_factor( resample_factor = {} for dim in resample_dim: r = output_tensor_dim[dim] / bgen.input_dims[dim] - assert r.is_integer() or (r ** -1).is_integer() - resample_factor[dim] = output_tensor_dim[dim] / bgen.input_dims[dim] + is_int = (r == int(r)) + is_inv_int = (1/r == int(1/r)) if r != 0 else False + assert is_int or is_inv_int, f"Resample ratio for dim '{dim}' must be an integer or its inverse." + resample_factor[dim] = r return resample_factor @@ -39,8 +41,8 @@ def _get_output_array_size( # determined by the source array if output_tensor_dim[key] != bgen.ds.sizes[key]: raise ValueError( - f"Axis {key} is a core dim, but the tensor size " - f"({output_tensor_dim[key]}) does not equal the" + f"Axis {key} is a core dim, but the tensor size" + f"({output_tensor_dim[key]}) does not equal the " f"source data array size ({bgen.ds.sizes[key]})." ) output_size[key] = bgen.ds.sizes[key] @@ -48,10 +50,10 @@ def _get_output_array_size( # This is a resampled axis, determine the new size # by the resample factor. temp_output_size = bgen.ds.sizes[key] * resample_factor[key] - assert temp_output_size.is_integer() + assert temp_output_size.is_integer(), f"Resampling for dim '{key}' results in non-integer size." output_size[key] = int(temp_output_size) else: - raise ValueError(f"Axis {dim} must be specified in one of new_dim, core_dim, or resample_dim") + raise ValueError(f"Axis {key} must be specified in one of new_dim, core_dim, or resample_dim") return output_size diff --git a/notebooks/inference-testing.ipynb b/notebooks/inference-testing.ipynb index 3784d6e4..3e72ec22 100644 --- a/notebooks/inference-testing.ipynb +++ b/notebooks/inference-testing.ipynb @@ -23,794 +23,632 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import xbatcher\n", "import xarray as xr\n", - "import numpy as np" + "import numpy as np\n", + "import pytest\n", + "\n", + "from functions import _get_output_array_size, predict_on_array" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Toy data" + "## Testing the array size function" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 24, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
<xarray.DataArray (x: 100, y: 100, t: 10)> Size: 800kB\n", - "array([[[0.80102402, 0.31982866, 0.25465188, ..., 0.31021283,\n", - " 0.03268302, 0.7941252 ],\n", - " [0.4382327 , 0.60519493, 0.86055512, ..., 0.91560426,\n", - " 0.75151732, 0.045112 ],\n", - " [0.56324046, 0.16672896, 0.95892792, ..., 0.75294859,\n", - " 0.62661621, 0.9507573 ],\n", - " ...,\n", - " [0.8942051 , 0.67152092, 0.5936959 , ..., 0.88955481,\n", - " 0.00104593, 0.25334781],\n", - " [0.20207265, 0.89166716, 0.11930909, ..., 0.44454162,\n", - " 0.16769459, 0.13463857],\n", - " [0.52366905, 0.77434724, 0.18922243, ..., 0.60982095,\n", - " 0.81261007, 0.35591865]],\n", - "\n", - " [[0.36699177, 0.29720945, 0.47093257, ..., 0.08379803,\n", - " 0.52233317, 0.62779805],\n", - " [0.78408776, 0.29027363, 0.06707094, ..., 0.83199475,\n", - " 0.26807477, 0.93323152],\n", - " [0.70058356, 0.83639276, 0.82536327, ..., 0.94036962,\n", - " 0.79252184, 0.83713594],\n", - "...\n", - " [0.39766555, 0.35994714, 0.50271116, ..., 0.46796155,\n", - " 0.97140383, 0.51062791],\n", - " [0.70055487, 0.81950773, 0.76781762, ..., 0.05735074,\n", - " 0.73039909, 0.17101624],\n", - " [0.8547392 , 0.21515083, 0.40187337, ..., 0.4583594 ,\n", - " 0.53213994, 0.15382577]],\n", - "\n", - " [[0.6465284 , 0.88463142, 0.33543765, ..., 0.49702278,\n", - " 0.40990036, 0.42300881],\n", - " [0.20363819, 0.02142565, 0.66128113, ..., 0.56253861,\n", - " 0.33799959, 0.16641441],\n", - " [0.67934574, 0.38965227, 0.68431764, ..., 0.56820253,\n", - " 0.18640022, 0.111049 ],\n", - " ...,\n", - " [0.91296116, 0.6995538 , 0.50325961, ..., 0.40427649,\n", - " 0.14176453, 0.26158606],\n", - " [0.5348746 , 0.67937558, 0.22366613, ..., 0.48889003,\n", - " 0.50030053, 0.1855764 ],\n", - " [0.8714776 , 0.9552773 , 0.06885801, ..., 0.95366137,\n", - " 0.73931089, 0.90909682]]], shape=(100, 100, 10))\n", - "Dimensions without coordinates: x, y, t