diff --git a/notebooks/dummy_models.py b/notebooks/dummy_models.py index 9bb1533f..8a6da171 100644 --- a/notebooks/dummy_models.py +++ b/notebooks/dummy_models.py @@ -1,5 +1,9 @@ import torch +class Identity(torch.nn.Module): + def __init__(self): super().__init__() + def forward(self, x): return x + class MeanAlongDim(torch.nn.Module): def __init__(self, ax): super(MeanAlongDim, self).__init__() diff --git a/notebooks/functions.py b/notebooks/functions.py index c5704fb4..2b9a2101 100644 --- a/notebooks/functions.py +++ b/notebooks/functions.py @@ -74,7 +74,8 @@ def _resample_coordinate( offset = 0 if mode == "edges" else old_step / 2 new_step = old_step / factor coord = coord - offset - return np.arange(coord.min().item(), coord.max().item()+old_step, step=new_step) + offset + new_coord_end = coord.max().item() + old_step + return np.arange(coord.min().item(), new_coord_end, step=new_step) + offset def _get_output_array_coordinates( @@ -90,7 +91,7 @@ def _get_output_array_coordinates( output_coords[dim] = _resample_coordinate(src_da[dim], resample_factor[dim], resample_mode) elif dim in src_da.coords: # Source array has coordinate but it isn't changing size - output_coords[dim] = src_da[dim].copy() + output_coords[dim] = src_da[dim].copy(deep=True).data else: # Source array doesn't have a coordinate on this dim or # this is a new dim, ignore @@ -168,22 +169,28 @@ def predict_on_array( Overlaps are allowed, in which case the average of all output values is returned. ''' - # TODO input checking - # *_dim args cannot have common axes + s_new = set(new_dim) + s_core = set(core_dim) + s_resample = set(resample_dim) + + if s_new & s_core or s_new & s_resample or s_core & s_resample: + raise ValueError("new_dim, core_dim, and resample_dim must be disjoint sets.") + + bgen = dataset.X_generator # Get resample factors resample_factor = _get_resample_factor( - dataset.X_generator, + bgen, output_tensor_dim, resample_dim ) # Set up output array output_size = _get_output_array_size( - dataset.X_generator, - output_tensor_dim, - new_dim, - core_dim, + bgen, + output_tensor_dim, + new_dim, + core_dim, resample_dim ) @@ -198,12 +205,14 @@ def predict_on_array( # Iterate over each batch for i, batch in enumerate(loader): - out_batch = model(batch).detach().numpy() + input_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch + out_batch = model(input_tensor).detach().numpy() - # Iterate over each example in the batch + # Iterate over each sample in the batch for ib in range(out_batch.shape[0]): - # Get the slice object associated with this example - old_indexer = dataset.X_generator._batch_selectors.selectors[(i*batch_size)+ib][0] + # Get the slice object associated with this sample + global_index = (i * batch_size) + ib + old_indexer = bgen._batch_selectors.selectors[global_index][0] # Only index into axes that are resampled, rescaling the bounds # Perhaps use xbatcher _gen_slices here? new_indexer = {} @@ -213,7 +222,7 @@ def predict_on_array( int(old_indexer[key].start * resample_factor[key]), int(old_indexer[key].stop * resample_factor[key]) ) - + output_da.loc[new_indexer] += out_batch[ib, ...] output_n.loc[new_indexer] += 1 diff --git a/notebooks/inference-testing.ipynb b/notebooks/inference-testing.ipynb index 3e72ec22..56000811 100644 --- a/notebooks/inference-testing.ipynb +++ b/notebooks/inference-testing.ipynb @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -69,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -123,11 +123,11 @@ " ),\n", " (\n", " \"Mixed: Resample x, add new channel dimension and keep t as core\",\n", - " {'x': 30, 'channel': 12}, \n", + " {'x': 30, 'channel': 12, 't': 10}, \n", " ['channel'],\n", " ['t'],\n", " ['x'],\n", - " {'x': 300, 'channel': 12} \n", + " {'x': 300, 'channel': 12, 't': 10} \n", " ),\n", " (\n", " \"Identity resampling (ratio=1)\",\n", @@ -175,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -212,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -235,7 +235,7 @@ "test_get_array_size.py::test_get_output_array_size_raises_error_on_unspecified_dim \u001b[32mPASSED\u001b[0m\u001b[32m [ 87%]\u001b[0m\n", "test_get_array_size.py::test_get_resample_factor_raises_error_on_invalid_ratio \u001b[32mPASSED\u001b[0m\u001b[32m [100%]\u001b[0m\n", "\n", - "\u001b[32m============================== \u001b[32m\u001b[1m8 passed\u001b[0m\u001b[32m in 1.23s\u001b[0m\u001b[32m ===============================\u001b[0m\n" + "\u001b[32m============================== \u001b[32m\u001b[1m8 passed\u001b[0m\u001b[32m in 1.00s\u001b[0m\u001b[32m ===============================\u001b[0m\n" ] } ], @@ -252,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -272,13 +272,14 @@ "import pytest\n", "from xbatcher.loaders.torch import MapDataset\n", "\n", - "from functions import _get_output_array_size, predict_on_array\n", - "from dummy_models import *" + "from functions import _get_output_array_size, _resample_coordinate\n", + "from functions import predict_on_array, _get_resample_factor\n", + "from dummy_models import Identity, MeanAlongDim, SubsetAlongAxis, ExpandAlongAxis, AddAxis" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -289,13 +290,13 @@ "import pytest\n", "from xbatcher.loaders.torch import MapDataset\n", "\n", - "from functions import _get_output_array_size, predict_on_array\n", + "from functions import *\n", "from dummy_models import *" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -304,7 +305,7 @@ "tensor([0., 1., 2., 3., 4.])" ] }, - "execution_count": 21, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -316,32 +317,28 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[ 2., 7., 12., 17., 22.],\n", - " [ 27., 32., 37., 42., 47.],\n", - " [ 52., 57., 62., 67., 72.],\n", - " [ 77., 82., 87., 92., 97.],\n", - " [102., 107., 112., 117., 122.]])" + "torch.Size([5, 10, 5])" ] }, - "execution_count": 22, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "model = MeanAlongDim(-1)\n", - "model(input_tensor)" + "model = ExpandAlongAxis(1, 2)\n", + "model(input_tensor).shape" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -357,49 +354,705 @@ "\n", "@pytest.fixture\n", "def map_dataset_fixture() -> MapDataset:\n", - " \"\"\"\n", - " Creates a MapDataset with a predictable BatchGenerator for testing.\n", - " - Data is an xarray DataArray with dimensions x=20, y=10\n", - " - Values are a simple np.arange sequence for easy verification.\n", - " - Batches are size x=10, y=5 with overlap x=2, y=2\n", - " \"\"\"\n", - " # Using a smaller, more manageable dataset for testing\n", " data = xr.DataArray(\n", - " data=np.arange(20 * 10).reshape(20, 10),\n", + " data=np.arange(20 * 10).reshape(20, 10).astype(np.float32),\n", " dims=(\"x\", \"y\"),\n", - " coords={\"x\": np.arange(20), \"y\": np.arange(10)}\n", - " ).astype(float)\n", - " \n", - " bgen = xbatcher.BatchGenerator(\n", - " data,\n", - " input_dims=dict(x=10, y=5),\n", - " input_overlap=dict(x=2, y=2),\n", + " coords={\"x\": np.arange(20, dtype=float), \"y\": np.arange(10, dtype=float)},\n", " )\n", - " return MapDataset(bgen)" + " bgen = xbatcher.BatchGenerator(data, input_dims=dict(x=10, y=5), input_overlap=dict(x=2, y=2))\n", + " return MapDataset(bgen)\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - " data = xr.DataArray(\n", - " data=np.arange(20 * 10).reshape(20, 10),\n", - " dims=(\"x\", \"y\"),\n", - " coords={\"x\": np.arange(20), \"y\": np.arange(10)}\n", - " ).astype(float)\n", - " \n", - " bgen = xbatcher.BatchGenerator(\n", - " data,\n", - " input_dims=dict(x=10, y=5),\n", - " input_overlap=dict(x=2, y=2),\n", - " )" + "data = xr.DataArray(\n", + " data=np.arange(20 * 10).reshape(20, 10),\n", + " dims=(\"x\", \"y\"),\n", + " coords={\"x\": np.arange(20), \"y\": np.arange(10)}\n", + ").astype(float)\n", + "\n", + "bgen = xbatcher.BatchGenerator(\n", + " data,\n", + " input_dims=dict(x=10, y=5),\n", + " input_overlap=dict(x=2, y=2)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "ds = MapDataset(bgen)" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray (x: 20, y: 10)> Size: 2kB\n",
+       "array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.],\n",
+       "       [ 10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.],\n",
+       "       [ 20.,  21.,  22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.],\n",
+       "       [ 30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.],\n",
+       "       [ 40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,  48.,  49.],\n",
+       "       [ 50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.],\n",
+       "       [ 60.,  61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.],\n",
+       "       [ 70.,  71.,  72.,  73.,  74.,  75.,  76.,  77.,  78.,  79.],\n",
+       "       [ 80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,  88.,  89.],\n",
+       "       [ 90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,  99.],\n",
+       "       [100., 101., 102., 103., 104., 105., 106., 107., 108., 109.],\n",
+       "       [110., 111., 112., 113., 114., 115., 116., 117., 118., 119.],\n",
+       "       [120., 121., 122., 123., 124., 125., 126., 127., 128., 129.],\n",
+       "       [130., 131., 132., 133., 134., 135., 136., 137., 138., 139.],\n",
+       "       [140., 141., 142., 143., 144., 145., 146., 147., 148., 149.],\n",
+       "       [150., 151., 152., 153., 154., 155., 156., 157., 158., 159.],\n",
+       "       [160., 161., 162., 163., 164., 165., 166., 167., 168., 169.],\n",
+       "       [170., 171., 172., 173., 174., 175., 176., 177., 178., 179.],\n",
+       "       [180., 181., 182., 183., 184., 185., 186., 187., 188., 189.],\n",
+       "       [190., 191., 192., 193., 194., 195., 196., 197., 198., 199.]])\n",
+       "Coordinates:\n",
+       "  * x        (x) int64 160B 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19\n",
+       "  * y        (y) int64 80B 0 1 2 3 4 5 6 7 8 9
" + ], + "text/plain": [ + " Size: 2kB\n", + "array([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],\n", + " [ 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],\n", + " [ 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],\n", + " [ 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],\n", + " [ 40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],\n", + " [ 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],\n", + " [ 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],\n", + " [ 70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],\n", + " [ 80., 81., 82., 83., 84., 85., 86., 87., 88., 89.],\n", + " [ 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.],\n", + " [100., 101., 102., 103., 104., 105., 106., 107., 108., 109.],\n", + " [110., 111., 112., 113., 114., 115., 116., 117., 118., 119.],\n", + " [120., 121., 122., 123., 124., 125., 126., 127., 128., 129.],\n", + " [130., 131., 132., 133., 134., 135., 136., 137., 138., 139.],\n", + " [140., 141., 142., 143., 144., 145., 146., 147., 148., 149.],\n", + " [150., 151., 152., 153., 154., 155., 156., 157., 158., 159.],\n", + " [160., 161., 162., 163., 164., 165., 166., 167., 168., 169.],\n", + " [170., 171., 172., 173., 174., 175., 176., 177., 178., 179.],\n", + " [180., 181., 182., 183., 184., 185., 186., 187., 188., 189.],\n", + " [190., 191., 192., 193., 194., 195., 196., 197., 198., 199.]])\n", + "Coordinates:\n", + " * x (x) int64 160B 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19\n", + " * y (y) int64 80B 0 1 2 3 4 5 6 7 8 9" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 3., 4., 5., 6., 7.],\n", + " [13., 14., 15., 16., 17.],\n", + " [23., 24., 25., 26., 27.],\n", + " [33., 34., 35., 36., 37.],\n", + " [43., 44., 45., 46., 47.],\n", + " [53., 54., 55., 56., 57.],\n", + " [63., 64., 65., 66., 67.],\n", + " [73., 74., 75., 76., 77.],\n", + " [83., 84., 85., 86., 87.],\n", + " [93., 94., 95., 96., 97.]], dtype=torch.float64)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "output_tensor_dim = {'x': 20, 'y': 5}\n", + "resample_dim = ['x', 'y']\n", + "core_dim = []\n", + "new_dim = []" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([10, 5])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([10, 10])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(ds[0]).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "import functions\n", + "from importlib import reload\n", + "reload(functions)\n", + "result = functions.predict_on_array(\n", + " ds,\n", + " model,\n", + " output_tensor_dim=output_tensor_dim,\n", + " new_dim=new_dim,\n", + " core_dim=core_dim,\n", + " resample_dim=resample_dim,\n", + " batch_size=4\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Appending to test_predict_on_array.py\n" + ] + } + ], + "source": [ + "%%writefile -a test_predict_on_array.py\n", + "\n", + "@pytest.mark.parametrize(\"factor, mode, expected\", [\n", + " (2.0, \"edges\", np.arange(0, 10, 0.5)),\n", + " (0.5, \"edges\", np.arange(0, 10, 2.0)),\n", + "])\n", + "def test_resample_coordinate(factor, mode, expected):\n", + " coord = xr.DataArray(np.arange(10, dtype=float), dims=\"x\")\n", + " resampled = _resample_coordinate(coord, factor, mode)\n", + " np.testing.assert_allclose(resampled, expected)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -414,86 +1067,91 @@ "%%writefile -a test_predict_on_array.py\n", "\n", "@pytest.mark.parametrize(\n", - " \"model, output_tensor_dim, new_dim, resample_dim, expected_transform\",\n", + " \"model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform\",\n", " [\n", - " # Case 1: Resampling - Downsampling with a subset model\n", + " # Case 1: Identity - No change\n", " (\n", - " SubsetAlongAxis(ax=1, n=5), # Corresponds to 'x' dim in batch\n", + " Identity(),\n", + " {'x': 10, 'y': 5},\n", + " [], [], ['x', 'y'],\n", + " lambda da: da.data\n", + " ),\n", + " # Case 2: ExpandAlongAxis - Upsampling\n", + " (\n", + " ExpandAlongAxis(ax=1, n_repeats=2), # ax=1 is 'x'\n", + " {'x': 20, 'y': 5},\n", + " [], [], ['x', 'y'],\n", + " lambda da: da.data.repeat(2, axis=0) # axis=0 in the 2D numpy array\n", + " ),\n", + " # Case 3: SubsetAlongAxis - Coarsening\n", + " (\n", + " SubsetAlongAxis(ax=1, n=5), # ax=1 is 'x'\n", " {'x': 5, 'y': 5},\n", - " [],\n", - " ['x'],\n", - " lambda da: da.isel(x=slice(0, 5)) # Expected: take first 5 elements of original 'x'\n", + " [], [], ['x', 'y'],\n", + " lambda da: da.isel(x=slice(0, 5)).data\n", " ),\n", - " # Case 2: Dimension reduction with a mean model\n", + " # Case 4: MeanAlongDim - Dimension reduction\n", " (\n", - " MeanAlongDim(ax=2), # Corresponds to 'y' dim in batch\n", + " MeanAlongDim(ax=2), # ax=2 is 'y'\n", " {'x': 10},\n", - " [],\n", - " ['x'],\n", - " lambda da: da.mean(dim='y') # Expected: mean along original 'y'\n", + " [], [], ['x'],\n", + " lambda da: da.mean(dim='y').data\n", + " ),\n", + " # Case 5: AddAxis - Add a new dimension\n", + " (\n", + " AddAxis(ax=1), # Add new dim at axis 1\n", + " {'channel': 1, 'x': 10, 'y': 5},\n", + " ['channel'], [], ['x', 'y'],\n", + " lambda da: np.expand_dims(da.data, axis=0)\n", " ),\n", " ]\n", ")\n", - "def test_predict_on_array_reassembly(\n", - " map_dataset_fixture,\n", - " model,\n", - " output_tensor_dim,\n", - " new_dim,\n", - " resample_dim,\n", - " expected_transform\n", + "def test_predict_on_array_all_models(\n", + " map_dataset_fixture, model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform\n", "):\n", " \"\"\"\n", - " Tests that predict_on_array correctly reassembles batches from different models.\n", + " Tests reassembly, averaging, and coordinate assignment using a variety of models.\n", " \"\"\"\n", + " dataset = map_dataset_fixture\n", + " bgen = dataset.X_generator\n", + " resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)\n", + "\n", " # --- Run the function under test ---\n", - " # Using a small batch_size to ensure multiple iterations\n", - " predicted_da, predicted_n = predict_on_array(\n", - " dataset=map_dataset_fixture,\n", - " model=model,\n", - " output_tensor_dim=output_tensor_dim,\n", - " new_dim=new_dim,\n", - " resample_dim=resample_dim,\n", - " batch_size=4 \n", + " result_da = predict_on_array(\n", + " dataset=dataset, model=model, output_tensor_dim=output_tensor_dim,\n", + " new_dim=new_dim, core_dim=core_dim, resample_dim=resample_dim, batch_size=4\n", " )\n", "\n", " # --- Manually calculate the expected result ---\n", - " bgen = map_dataset_fixture.generator\n", - " # 1. Create the expected output array structure\n", - " expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, resample_dim)\n", - " expected_da = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))\n", - " expected_n = xr.full_like(expected_da, 0)\n", + " expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, core_dim, resample_dim)\n", + " expected_sum = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))\n", + " expected_count = xr.full_like(expected_sum, 0, dtype=int)\n", "\n", - " # 2. Manually iterate through batches and apply the same logic as the function\n", - " for i in range(len(map_dataset_fixture)):\n", + " for i in range(len(dataset)):\n", " batch_da = bgen[i]\n", - " \n", - " # Apply the same transformation the model would\n", - " transformed_batch = expected_transform(batch_da)\n", - " \n", - " # Get the rescaled indexer\n", - " old_indexer = bgen.batch_selectors[i]\n", + " old_indexer = bgen._batch_selectors.selectors[i][0]\n", " new_indexer = {}\n", " for key in old_indexer:\n", " if key in resample_dim:\n", - " resample_ratio = output_tensor_dim[key] / bgen.input_dims[key]\n", - " new_indexer[key] = slice(\n", - " int(old_indexer[key].start * resample_ratio),\n", - " int(old_indexer[key].stop * resample_ratio)\n", - " )\n", - " \n", - " # Add the result to our manually calculated array\n", - " expected_da.loc[new_indexer] += transformed_batch.values\n", - " expected_n.loc[new_indexer] += 1\n", + " new_indexer[key] = slice(int(old_indexer[key].start * resample_factor.get(key, 1)), int(old_indexer[key].stop * resample_factor.get(key, 1)))\n", + " elif key in core_dim:\n", + " new_indexer[key] = old_indexer[key]\n", "\n", - " # --- Assert that the results are identical ---\n", - " # We test the raw summed output and the overlap counter array\n", - " xr.testing.assert_allclose(predicted_da, expected_da)\n", - " xr.testing.assert_allclose(predicted_n, expected_n)" + " model_output_on_batch = manual_transform(batch_da)\n", + " print(f\"Batch {i}: {new_indexer} -> {model_output_on_batch.shape}\")\n", + " print(f\"Expected sum shape: {expected_sum.loc[new_indexer].shape}\")\n", + " expected_sum.loc[new_indexer] += model_output_on_batch\n", + " expected_count.loc[new_indexer] += 1\n", + " \n", + " expected_avg_data = expected_sum.data / expected_count.data\n", + " \n", + " # --- Assert correctness ---\n", + " np.testing.assert_allclose(result_da.values, expected_avg_data, equal_nan=True)" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -501,149 +1159,31 @@ "output_type": "stream", "text": [ "\u001b[1m============================= test session starts ==============================\u001b[0m\n", - "platform linux -- Python 3.10.16, pytest-8.4.1, pluggy-1.6.0 -- /srv/conda/envs/notebook/bin/python3.10\n", + "platform darwin -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0 -- /Users/nkalauni/miniconda3/envs/cookbook-dev/bin/python3.13\n", "cachedir: .pytest_cache\n", - "rootdir: /home/jovyan/xbatcher-deep-learning/notebooks\n", - "plugins: anyio-4.8.0\n", - "collected 2 items \u001b[0m\u001b[1m\n", - "\n", - "test_predict_on_array.py::test_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-] \u001b[31mFAILED\u001b[0m\u001b[31m [ 50%]\u001b[0m\n", - "\u001b[31mFAILED\u001b[0m\u001b[31m [100%]\u001b[0mpredict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-] \n", - "\n", - "=================================== FAILURES ===================================\n", - "\u001b[31m\u001b[1m_ test_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-] _\u001b[0m\n", - "\n", - "map_dataset_fixture = \n", - "model = SubsetAlongAxis(), output_tensor_dim = {'x': 5, 'y': 5}, new_dim = []\n", - "resample_dim = ['x'], expected_transform = at 0x7f4d4a136cb0>\n", - "\n", - " \u001b[0m\u001b[37m@pytest\u001b[39;49;00m.mark.parametrize(\u001b[90m\u001b[39;49;00m\n", - " \u001b[33m\"\u001b[39;49;00m\u001b[33mmodel, output_tensor_dim, new_dim, resample_dim, expected_transform\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m,\u001b[90m\u001b[39;49;00m\n", - " [\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# Case 1: Resampling - Downsampling with a subset model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " (\u001b[90m\u001b[39;49;00m\n", - " SubsetAlongAxis(ax=\u001b[94m1\u001b[39;49;00m, n=\u001b[94m5\u001b[39;49;00m), \u001b[90m# Corresponds to 'x' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m, \u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n", - " [],\u001b[90m\u001b[39;49;00m\n", - " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", - " \u001b[94mlambda\u001b[39;49;00m da: da.isel(x=\u001b[96mslice\u001b[39;49;00m(\u001b[94m0\u001b[39;49;00m, \u001b[94m5\u001b[39;49;00m)) \u001b[90m# Expected: take first 5 elements of original 'x'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " ),\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# Case 2: Dimension reduction with a mean model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " (\u001b[90m\u001b[39;49;00m\n", - " MeanAlongDim(ax=\u001b[94m2\u001b[39;49;00m), \u001b[90m# Corresponds to 'y' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m10\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n", - " [],\u001b[90m\u001b[39;49;00m\n", - " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", - " \u001b[94mlambda\u001b[39;49;00m da: da.mean(dim=\u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m) \u001b[90m# Expected: mean along original 'y'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " ),\u001b[90m\u001b[39;49;00m\n", - " ]\u001b[90m\u001b[39;49;00m\n", - " )\u001b[90m\u001b[39;49;00m\n", - " \u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92mtest_predict_on_array_reassembly\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n", - " map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n", - " model,\u001b[90m\u001b[39;49;00m\n", - " output_tensor_dim,\u001b[90m\u001b[39;49;00m\n", - " new_dim,\u001b[90m\u001b[39;49;00m\n", - " resample_dim,\u001b[90m\u001b[39;49;00m\n", - " expected_transform\u001b[90m\u001b[39;49;00m\n", - " ):\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m \u001b[39;49;00m\u001b[33m\"\"\"\u001b[39;49;00m\n", - " \u001b[33m Tests that predict_on_array correctly reassembles batches from different models.\u001b[39;49;00m\n", - " \u001b[33m \"\"\"\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# --- Run the function under test ---\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# Using a small batch_size to ensure multiple iterations\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - "> predicted_da, predicted_n = predict_on_array(\u001b[90m\u001b[39;49;00m\n", - " dataset=map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n", - " model=model,\u001b[90m\u001b[39;49;00m\n", - " output_tensor_dim=output_tensor_dim,\u001b[90m\u001b[39;49;00m\n", - " new_dim=new_dim,\u001b[90m\u001b[39;49;00m\n", - " resample_dim=resample_dim,\u001b[90m\u001b[39;49;00m\n", - " batch_size=\u001b[94m4\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " )\u001b[90m\u001b[39;49;00m\n", - "\n", - "\u001b[1m\u001b[31mtest_predict_on_array.py\u001b[0m:67: \n", - "_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n", - "\u001b[1m\u001b[31mfunctions.py\u001b[0m:55: in predict_on_array\n", - " \u001b[0moutput_size = _get_output_array_size(dataset.X_generator, output_tensor_dim, new_dim, resample_dim)\u001b[90m\u001b[39;49;00m\n", - "_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n", - "\n", - "bgen = \n", - "output_tensor_dim = {'x': 5, 'y': 5}, new_dim = [], resample_dim = ['x']\n", - "\n", - " \u001b[0m\u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92m_get_output_array_size\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n", - " bgen: xbatcher.BatchGenerator,\u001b[90m\u001b[39;49;00m\n", - " output_tensor_dim: \u001b[96mdict\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m, \u001b[96mint\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", - " new_dim: \u001b[96mlist\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", - " resample_dim: \u001b[96mlist\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m]\u001b[90m\u001b[39;49;00m\n", - " ):\u001b[90m\u001b[39;49;00m\n", - " resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)\u001b[90m\u001b[39;49;00m\n", - " output_size = {}\u001b[90m\u001b[39;49;00m\n", - " \u001b[94mfor\u001b[39;49;00m key, size \u001b[95min\u001b[39;49;00m output_tensor_dim.items():\u001b[90m\u001b[39;49;00m\n", - " \u001b[94mif\u001b[39;49;00m key \u001b[95min\u001b[39;49;00m new_dim:\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# This is a new axis, size is determined\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# by the tensor size.\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " output_size[key] = output_tensor_dim[key]\u001b[90m\u001b[39;49;00m\n", - " \u001b[94melse\u001b[39;49;00m:\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# This is a resampled axis, determine the new size\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# by the ratio of the batchgen window to the tensor size.\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - "> temp_output_size = bgen.ds.sizes[key] * resample_factor[key]\u001b[90m\u001b[39;49;00m\n", - "\u001b[1m\u001b[31mE KeyError: 'y'\u001b[0m\n", - "\n", - "\u001b[1m\u001b[31mfunctions.py\u001b[0m:36: KeyError\n", - "\u001b[31m\u001b[1m_ test_predict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-] _\u001b[0m\n", + "rootdir: /Users/nkalauni/Documents/Cline/xbatcher-deep-learning/notebooks\n", + "plugins: anyio-4.10.0\n", + "collected 7 items \u001b[0m\u001b[1m\n", "\n", - "map_dataset_fixture = \n", - "model = MeanAlongDim(), output_tensor_dim = {'x': 10}, new_dim = []\n", - "resample_dim = ['x'], expected_transform = at 0x7f4d4a1371c0>\n", + "test_predict_on_array.py::test_resample_coordinate[2.0-edges-expected0] \u001b[32mPASSED\u001b[0m\u001b[32m [ 14%]\u001b[0m\n", + "test_predict_on_array.py::test_resample_coordinate[0.5-edges-expected1] \u001b[32mPASSED\u001b[0m\u001b[32m [ 28%]\u001b[0m\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model0-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-] \u001b[32mPASSED\u001b[0m\u001b[32m [ 42%]\u001b[0m\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model1-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-] \u001b[32mPASSED\u001b[0m\u001b[33m [ 57%]\u001b[0m\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model2-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-] \u001b[32mPASSED\u001b[0m\u001b[33m [ 71%]\u001b[0m\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model3-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-] \u001b[32mPASSED\u001b[0m\u001b[33m [ 85%]\u001b[0m\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model4-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-] \u001b[32mPASSED\u001b[0m\u001b[33m [100%]\u001b[0m\n", "\n", - " \u001b[0m\u001b[37m@pytest\u001b[39;49;00m.mark.parametrize(\u001b[90m\u001b[39;49;00m\n", - " \u001b[33m\"\u001b[39;49;00m\u001b[33mmodel, output_tensor_dim, new_dim, resample_dim, expected_transform\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m,\u001b[90m\u001b[39;49;00m\n", - " [\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# Case 1: Resampling - Downsampling with a subset model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " (\u001b[90m\u001b[39;49;00m\n", - " SubsetAlongAxis(ax=\u001b[94m1\u001b[39;49;00m, n=\u001b[94m5\u001b[39;49;00m), \u001b[90m# Corresponds to 'x' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m, \u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n", - " [],\u001b[90m\u001b[39;49;00m\n", - " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", - " \u001b[94mlambda\u001b[39;49;00m da: da.isel(x=\u001b[96mslice\u001b[39;49;00m(\u001b[94m0\u001b[39;49;00m, \u001b[94m5\u001b[39;49;00m)) \u001b[90m# Expected: take first 5 elements of original 'x'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " ),\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# Case 2: Dimension reduction with a mean model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " (\u001b[90m\u001b[39;49;00m\n", - " MeanAlongDim(ax=\u001b[94m2\u001b[39;49;00m), \u001b[90m# Corresponds to 'y' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m10\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n", - " [],\u001b[90m\u001b[39;49;00m\n", - " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n", - " \u001b[94mlambda\u001b[39;49;00m da: da.mean(dim=\u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m) \u001b[90m# Expected: mean along original 'y'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " ),\u001b[90m\u001b[39;49;00m\n", - " ]\u001b[90m\u001b[39;49;00m\n", - " )\u001b[90m\u001b[39;49;00m\n", - " \u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92mtest_predict_on_array_reassembly\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n", - " map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n", - " model,\u001b[90m\u001b[39;49;00m\n", - " output_tensor_dim,\u001b[90m\u001b[39;49;00m\n", - " new_dim,\u001b[90m\u001b[39;49;00m\n", - " resample_dim,\u001b[90m\u001b[39;49;00m\n", - " expected_transform\u001b[90m\u001b[39;49;00m\n", - " ):\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m \u001b[39;49;00m\u001b[33m\"\"\"\u001b[39;49;00m\n", - " \u001b[33m Tests that predict_on_array correctly reassembles batches from different models.\u001b[39;49;00m\n", - " \u001b[33m \"\"\"\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# --- Run the function under test ---\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " \u001b[90m# Using a small batch_size to ensure multiple iterations\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - "> predicted_da, predicted_n = predict_on_array(\u001b[90m\u001b[39;49;00m\n", - " dataset=map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n", - " model=model,\u001b[90m\u001b[39;49;00m\n", - " output_tensor_dim=output_tensor_dim,\u001b[90m\u001b[39;49;00m\n", - " new_dim=new_dim,\u001b[90m\u001b[39;49;00m\n", - " resample_dim=resample_dim,\u001b[90m\u001b[39;49;00m\n", - " batch_size=\u001b[94m4\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", - " )\u001b[90m\u001b[39;49;00m\n", - "\u001b[1m\u001b[31mE ValueError: too many values to unpack (expected 2)\u001b[0m\n", + "\u001b[33m=============================== warnings summary ===============================\u001b[0m\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model0-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-]\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model1-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-]\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model2-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-]\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model3-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-]\n", + "test_predict_on_array.py::test_predict_on_array_all_models[model4-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-]\n", + " /Users/nkalauni/Documents/Cline/xbatcher-deep-learning/notebooks/test_predict_on_array.py:108: RuntimeWarning: invalid value encountered in divide\n", + " expected_avg_data = expected_sum.data / expected_count.data\n", "\n", - "\u001b[1m\u001b[31mtest_predict_on_array.py\u001b[0m:67: ValueError\n", - "\u001b[36m\u001b[1m=========================== short test summary info ============================\u001b[0m\n", - "\u001b[31mFAILED\u001b[0m test_predict_on_array.py::\u001b[1mtest_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-]\u001b[0m - KeyError: 'y'\n", - "\u001b[31mFAILED\u001b[0m test_predict_on_array.py::\u001b[1mtest_predict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-]\u001b[0m - ValueError: too many values to unpack (expected 2)\n", - "\u001b[31m============================== \u001b[31m\u001b[1m2 failed\u001b[0m\u001b[31m in 1.98s\u001b[0m\u001b[31m ===============================\u001b[0m\n" + "-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html\n", + "\u001b[33m======================== \u001b[32m7 passed\u001b[0m, \u001b[33m\u001b[1m5 warnings\u001b[0m\u001b[33m in 1.25s\u001b[0m\u001b[33m =========================\u001b[0m\n" ] } ], diff --git a/notebooks/test_get_array_size.py b/notebooks/test_get_array_size.py index d54aea67..868eade7 100644 --- a/notebooks/test_get_array_size.py +++ b/notebooks/test_get_array_size.py @@ -46,120 +46,11 @@ def bgen_fixture() -> xbatcher.BatchGenerator: ), ( "Mixed: Resample x, add new channel dimension and keep t as core", - {'x': 30, 'channel': 12}, + {'x': 30, 'channel': 12, 't': 10}, ['channel'], ['t'], ['x'], - {'x': 300, 'channel': 12} - ), - ( - "Identity resampling (ratio=1)", - {'x': 10, 'y': 10}, - [], - [], - ['x', 'y'], - {'x': 100, 'y': 100} - ), - ( - "Core dims only: 't' is a core dim", - {'t': 10}, - [], - ['t'], - [], - {'t': 10} - ), - ] -) -def test_get_output_array_size_scenarios( - bgen_fixture, # The fixture is passed as an argument - case_description, - output_tensor_dim, - new_dim, - core_dim, - resample_dim, - expected_output -): - """ - Tests various valid scenarios for calculating the output array size. - The `case_description` parameter is not used in the code but helps make - test results more readable. - """ - # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture - result = _get_output_array_size( - bgen=bgen_fixture, - output_tensor_dim=output_tensor_dim, - new_dim=new_dim, - core_dim=core_dim, - resample_dim=resample_dim - ) - - assert result == expected_output, f"Failed on case: {case_description}" - -def test_get_output_array_size_raises_error_on_mismatched_core_dim(bgen_fixture): - """Tests ValueError when a core_dim size doesn't match the source.""" - with pytest.raises(ValueError, match="does not equal the source data array size"): - _get_output_array_size( - bgen_fixture, output_tensor_dim={'t': 99}, new_dim=[], core_dim=['t'], resample_dim=[] - ) - -def test_get_output_array_size_raises_error_on_unspecified_dim(bgen_fixture): - """Tests ValueError when a dimension is not specified in any category.""" - with pytest.raises(ValueError, match="must be specified in one of"): - _get_output_array_size( - bgen_fixture, output_tensor_dim={'x': 10}, new_dim=[], core_dim=[], resample_dim=[] - ) - -def test_get_resample_factor_raises_error_on_invalid_ratio(bgen_fixture): - """Tests AssertionError when the resample ratio is not an integer or its inverse.""" - with pytest.raises(AssertionError, match="must be an integer or its inverse"): - # 15 / 10 = 1.5, which is not a valid ratio - _get_resample_factor(bgen_fixture, output_tensor_dim={'x': 15}, resample_dim=['x']) - -@pytest.fixture -def bgen_fixture() -> xbatcher.BatchGenerator: - data = xr.DataArray( - data=np.random.rand(100, 100, 10), - dims=("x", "y", "t"), - coords={ - "x": np.arange(100), - "y": np.arange(100), - "t": np.arange(10), - } - ) - - bgen = xbatcher.BatchGenerator( - data, - input_dims=dict(x=10, y=10), - input_overlap=dict(x=5, y=5), - ) - return bgen - -@pytest.mark.parametrize( - "case_description, output_tensor_dim, new_dim, core_dim, resample_dim, expected_output", - [ - ( - "Resampling only: Downsample x, Upsample y", - {'x': 5, 'y': 20}, - [], - [], - ['x', 'y'], - {'x': 50, 'y': 200} - ), - ( - "New dimensions only: Add a 'channel' dimension", - {'channel': 3}, - ['channel'], - [], - [], - {'channel': 3} - ), - ( - "Mixed: Resample x, add new channel dimension and keep t as core", - {'x': 30, 'channel': 12}, - ['channel'], - ['t'], - ['x'], - {'x': 300, 'channel': 12} + {'x': 300, 'channel': 12, 't': 10} ), ( "Identity resampling (ratio=1)", diff --git a/notebooks/test_predict_on_array.py b/notebooks/test_predict_on_array.py index 4a41e18f..8f45aa4f 100644 --- a/notebooks/test_predict_on_array.py +++ b/notebooks/test_predict_on_array.py @@ -5,27 +5,107 @@ import pytest from xbatcher.loaders.torch import MapDataset -from functions import _get_output_array_size, predict_on_array -from dummy_models import * +from functions import _get_output_array_size, _resample_coordinate +from functions import predict_on_array, _get_resample_factor +from dummy_models import Identity, MeanAlongDim, SubsetAlongAxis, ExpandAlongAxis, AddAxis @pytest.fixture def map_dataset_fixture() -> MapDataset: - """ - Creates a MapDataset with a predictable BatchGenerator for testing. - - Data is an xarray DataArray with dimensions x=20, y=10 - - Values are a simple np.arange sequence for easy verification. - - Batches are size x=10, y=5 with overlap x=2, y=2 - """ - # Using a smaller, more manageable dataset for testing data = xr.DataArray( - data=np.arange(20 * 10).reshape(20, 10), + data=np.arange(20 * 10).reshape(20, 10).astype(np.float32), dims=("x", "y"), - coords={"x": np.arange(20), "y": np.arange(10)} - ).astype(float) - - bgen = xbatcher.BatchGenerator( - data, - input_dims=dict(x=10, y=5), - input_overlap=dict(x=2, y=2), + coords={"x": np.arange(20, dtype=float), "y": np.arange(10, dtype=float)}, ) + bgen = xbatcher.BatchGenerator(data, input_dims=dict(x=10, y=5), input_overlap=dict(x=2, y=2)) return MapDataset(bgen) + +@pytest.mark.parametrize("factor, mode, expected", [ + (2.0, "edges", np.arange(0, 10, 0.5)), + (0.5, "edges", np.arange(0, 10, 2.0)), +]) +def test_resample_coordinate(factor, mode, expected): + coord = xr.DataArray(np.arange(10, dtype=float), dims="x") + resampled = _resample_coordinate(coord, factor, mode) + np.testing.assert_allclose(resampled, expected) + +@pytest.mark.parametrize( + "model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform", + [ + # Case 1: Identity - No change + ( + Identity(), + {'x': 10, 'y': 5}, + [], [], ['x', 'y'], + lambda da: da.data + ), + # Case 2: ExpandAlongAxis - Upsampling + ( + ExpandAlongAxis(ax=1, n_repeats=2), # ax=1 is 'x' + {'x': 20, 'y': 5}, + [], [], ['x', 'y'], + lambda da: da.data.repeat(2, axis=0) # axis=0 in the 2D numpy array + ), + # Case 3: SubsetAlongAxis - Coarsening + ( + SubsetAlongAxis(ax=1, n=5), # ax=1 is 'x' + {'x': 5, 'y': 5}, + [], [], ['x', 'y'], + lambda da: da.isel(x=slice(0, 5)).data + ), + # Case 4: MeanAlongDim - Dimension reduction + ( + MeanAlongDim(ax=2), # ax=2 is 'y' + {'x': 10}, + [], [], ['x'], + lambda da: da.mean(dim='y').data + ), + # Case 5: AddAxis - Add a new dimension + ( + AddAxis(ax=1), # Add new dim at axis 1 + {'channel': 1, 'x': 10, 'y': 5}, + ['channel'], [], ['x', 'y'], + lambda da: np.expand_dims(da.data, axis=0) + ), + ] +) +def test_predict_on_array_all_models( + map_dataset_fixture, model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform +): + """ + Tests reassembly, averaging, and coordinate assignment using a variety of models. + """ + dataset = map_dataset_fixture + bgen = dataset.X_generator + resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim) + + # --- Run the function under test --- + result_da = predict_on_array( + dataset=dataset, model=model, output_tensor_dim=output_tensor_dim, + new_dim=new_dim, core_dim=core_dim, resample_dim=resample_dim, batch_size=4 + ) + + # --- Manually calculate the expected result --- + expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, core_dim, resample_dim) + expected_sum = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys())) + expected_count = xr.full_like(expected_sum, 0, dtype=int) + + for i in range(len(dataset)): + batch_da = bgen[i] + old_indexer = bgen._batch_selectors.selectors[i][0] + new_indexer = {} + for key in old_indexer: + if key in resample_dim: + new_indexer[key] = slice(int(old_indexer[key].start * resample_factor.get(key, 1)), int(old_indexer[key].stop * resample_factor.get(key, 1))) + elif key in core_dim: + new_indexer[key] = old_indexer[key] + + model_output_on_batch = manual_transform(batch_da) + print(f"Batch {i}: {new_indexer} -> {model_output_on_batch.shape}") + print(f"Expected sum shape: {expected_sum.loc[new_indexer].shape}") + expected_sum.loc[new_indexer] += model_output_on_batch + expected_count.loc[new_indexer] += 1 + + expected_avg_data = expected_sum.data / expected_count.data + + # --- Assert correctness --- + np.testing.assert_allclose(result_da.values, expected_avg_data, equal_nan=True)