diff --git a/notebooks/dummy_models.py b/notebooks/dummy_models.py
index 9bb1533f..8a6da171 100644
--- a/notebooks/dummy_models.py
+++ b/notebooks/dummy_models.py
@@ -1,5 +1,9 @@
import torch
+class Identity(torch.nn.Module):
+    def forward(self, x):
+        return x
+
class MeanAlongDim(torch.nn.Module):
    def __init__(self, ax):
        super(MeanAlongDim, self).__init__()
diff --git a/notebooks/functions.py b/notebooks/functions.py
index c5704fb4..2b9a2101 100644
--- a/notebooks/functions.py
+++ b/notebooks/functions.py
@@ -74,7 +74,8 @@ def _resample_coordinate(
    offset = 0 if mode == "edges" else old_step / 2
    new_step = old_step / factor
    coord = coord - offset
-    return np.arange(coord.min().item(), coord.max().item()+old_step, step=new_step) + offset
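+    # np.arange excludes its stop value, so extend the end by one original
+    # step so the resampled range covers the full extent of the last cell.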
+    new_coord_end = coord.max().item() + old_step
+    return np.arange(coord.min().item(), new_coord_end, step=new_step) + offset
def _get_output_array_coordinates(
@@ -90,7 +91,7 @@ def _get_output_array_coordinates(
            output_coords[dim] = _resample_coordinate(src_da[dim], resample_factor[dim], resample_mode)
        elif dim in src_da.coords:
            # Source array has coordinate but it isn't changing size
-            output_coords[dim] = src_da[dim].copy()
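+            # Deep-copy the values and unwrap to a bare numpy array so the new
+            # coordinate doesn't alias the source DataArray's data or index.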
+            output_coords[dim] = src_da[dim].copy(deep=True).data
        else:
            # Source array doesn't have a coordinate on this dim or
            # this is a new dim, ignore
@@ -168,22 +169,28 @@ def predict_on_array(
    Overlaps are allowed, in which case the average of all output values is returned.
    '''
-    # TODO input checking
-    # *_dim args cannot have common axes
+    s_new = set(new_dim)
+    s_core = set(core_dim)
+    s_resample = set(resample_dim)
+
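+    # A dimension can only be handled one way, so the three *_dim lists must be disjoint.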
+    if s_new & s_core or s_new & s_resample or s_core & s_resample:
+        raise ValueError("new_dim, core_dim, and resample_dim must be disjoint sets.")
+
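+    # MapDataset exposes its underlying xbatcher BatchGenerator as X_generator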
+    bgen = dataset.X_generator
    # Get resample factors
    resample_factor = _get_resample_factor(
-        dataset.X_generator,
+        bgen,
        output_tensor_dim,
        resample_dim
    )
    # Set up output array
    output_size = _get_output_array_size(
-        dataset.X_generator,
-        output_tensor_dim,
-        new_dim,
-        core_dim,
+        bgen,
+        output_tensor_dim,
+        new_dim,
+        core_dim,
        resample_dim
    )
@@ -198,12 +205,14 @@ def predict_on_array(
    # Iterate over each batch
    for i, batch in enumerate(loader):
-        out_batch = model(batch).detach().numpy()
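+        # Loaders may yield either a bare input tensor or an (X, y) pair;
+        # use the first element when a pair is returned.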
+        input_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
+        out_batch = model(input_tensor).detach().numpy()
-        # Iterate over each example in the batch
+        # Iterate over each sample in the batch
        for ib in range(out_batch.shape[0]):
-            # Get the slice object associated with this example
-            old_indexer = dataset.X_generator._batch_selectors.selectors[(i*batch_size)+ib][0]
+            # Get the slice object associated with this sample
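+            # Map (batch index, position within batch) to the dataset-wide sample index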
+            global_index = (i * batch_size) + ib
+            old_indexer = bgen._batch_selectors.selectors[global_index][0]
            # Only index into axes that are resampled, rescaling the bounds
            # Perhaps use xbatcher _gen_slices here?
            new_indexer = {}
@@ -213,7 +222,7 @@ def predict_on_array(
                        int(old_indexer[key].start * resample_factor[key]),
                        int(old_indexer[key].stop * resample_factor[key])
                    )
-
+
            output_da.loc[new_indexer] += out_batch[ib, ...]
            output_n.loc[new_indexer] += 1
diff --git a/notebooks/inference-testing.ipynb b/notebooks/inference-testing.ipynb
index 3e72ec22..56000811 100644
--- a/notebooks/inference-testing.ipynb
+++ b/notebooks/inference-testing.ipynb
@@ -23,7 +23,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -69,7 +69,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -123,11 +123,11 @@
" ),\n",
" (\n",
" \"Mixed: Resample x, add new channel dimension and keep t as core\",\n",
- " {'x': 30, 'channel': 12}, \n",
+ " {'x': 30, 'channel': 12, 't': 10}, \n",
" ['channel'],\n",
" ['t'],\n",
" ['x'],\n",
- " {'x': 300, 'channel': 12} \n",
+ " {'x': 300, 'channel': 12, 't': 10} \n",
" ),\n",
" (\n",
" \"Identity resampling (ratio=1)\",\n",
@@ -175,7 +175,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -212,7 +212,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -235,7 +235,7 @@
"test_get_array_size.py::test_get_output_array_size_raises_error_on_unspecified_dim \u001b[32mPASSED\u001b[0m\u001b[32m [ 87%]\u001b[0m\n",
"test_get_array_size.py::test_get_resample_factor_raises_error_on_invalid_ratio \u001b[32mPASSED\u001b[0m\u001b[32m [100%]\u001b[0m\n",
"\n",
- "\u001b[32m============================== \u001b[32m\u001b[1m8 passed\u001b[0m\u001b[32m in 1.23s\u001b[0m\u001b[32m ===============================\u001b[0m\n"
+ "\u001b[32m============================== \u001b[32m\u001b[1m8 passed\u001b[0m\u001b[32m in 1.00s\u001b[0m\u001b[32m ===============================\u001b[0m\n"
]
}
],
@@ -252,7 +252,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -272,13 +272,14 @@
"import pytest\n",
"from xbatcher.loaders.torch import MapDataset\n",
"\n",
- "from functions import _get_output_array_size, predict_on_array\n",
- "from dummy_models import *"
+ "from functions import _get_output_array_size, _resample_coordinate\n",
+ "from functions import predict_on_array, _get_resample_factor\n",
+ "from dummy_models import Identity, MeanAlongDim, SubsetAlongAxis, ExpandAlongAxis, AddAxis"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -289,13 +290,13 @@
"import pytest\n",
"from xbatcher.loaders.torch import MapDataset\n",
"\n",
- "from functions import _get_output_array_size, predict_on_array\n",
+ "from functions import *\n",
"from dummy_models import *"
]
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -304,7 +305,7 @@
"tensor([0., 1., 2., 3., 4.])"
]
},
- "execution_count": 21,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -316,32 +317,28 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "tensor([[ 2., 7., 12., 17., 22.],\n",
- " [ 27., 32., 37., 42., 47.],\n",
- " [ 52., 57., 62., 67., 72.],\n",
- " [ 77., 82., 87., 92., 97.],\n",
- " [102., 107., 112., 117., 122.]])"
+ "torch.Size([5, 10, 5])"
]
},
- "execution_count": 22,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "model = MeanAlongDim(-1)\n",
- "model(input_tensor)"
+ "model = ExpandAlongAxis(1, 2)\n",
+ "model(input_tensor).shape"
]
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -357,49 +354,705 @@
"\n",
"@pytest.fixture\n",
"def map_dataset_fixture() -> MapDataset:\n",
- " \"\"\"\n",
- " Creates a MapDataset with a predictable BatchGenerator for testing.\n",
- " - Data is an xarray DataArray with dimensions x=20, y=10\n",
- " - Values are a simple np.arange sequence for easy verification.\n",
- " - Batches are size x=10, y=5 with overlap x=2, y=2\n",
- " \"\"\"\n",
- " # Using a smaller, more manageable dataset for testing\n",
" data = xr.DataArray(\n",
- " data=np.arange(20 * 10).reshape(20, 10),\n",
+ " data=np.arange(20 * 10).reshape(20, 10).astype(np.float32),\n",
" dims=(\"x\", \"y\"),\n",
- " coords={\"x\": np.arange(20), \"y\": np.arange(10)}\n",
- " ).astype(float)\n",
- " \n",
- " bgen = xbatcher.BatchGenerator(\n",
- " data,\n",
- " input_dims=dict(x=10, y=5),\n",
- " input_overlap=dict(x=2, y=2),\n",
+ " coords={\"x\": np.arange(20, dtype=float), \"y\": np.arange(10, dtype=float)},\n",
" )\n",
- " return MapDataset(bgen)"
+ " bgen = xbatcher.BatchGenerator(data, input_dims=dict(x=10, y=5), input_overlap=dict(x=2, y=2))\n",
+ " return MapDataset(bgen)\n"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
- " data = xr.DataArray(\n",
- " data=np.arange(20 * 10).reshape(20, 10),\n",
- " dims=(\"x\", \"y\"),\n",
- " coords={\"x\": np.arange(20), \"y\": np.arange(10)}\n",
- " ).astype(float)\n",
- " \n",
- " bgen = xbatcher.BatchGenerator(\n",
- " data,\n",
- " input_dims=dict(x=10, y=5),\n",
- " input_overlap=dict(x=2, y=2),\n",
- " )"
+ "data = xr.DataArray(\n",
+ " data=np.arange(20 * 10).reshape(20, 10),\n",
+ " dims=(\"x\", \"y\"),\n",
+ " coords={\"x\": np.arange(20), \"y\": np.arange(10)}\n",
+ ").astype(float)\n",
+ "\n",
+ "bgen = xbatcher.BatchGenerator(\n",
+ " data,\n",
+ " input_dims=dict(x=10, y=5),\n",
+ " input_overlap=dict(x=2, y=2)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds = MapDataset(bgen)"
]
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "
<xarray.DataArray (x: 20, y: 10)> Size: 2kB\n",
+ "array([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],\n",
+ " [ 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],\n",
+ " [ 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],\n",
+ " [ 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],\n",
+ " [ 40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],\n",
+ " [ 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],\n",
+ " [ 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],\n",
+ " [ 70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],\n",
+ " [ 80., 81., 82., 83., 84., 85., 86., 87., 88., 89.],\n",
+ " [ 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.],\n",
+ " [100., 101., 102., 103., 104., 105., 106., 107., 108., 109.],\n",
+ " [110., 111., 112., 113., 114., 115., 116., 117., 118., 119.],\n",
+ " [120., 121., 122., 123., 124., 125., 126., 127., 128., 129.],\n",
+ " [130., 131., 132., 133., 134., 135., 136., 137., 138., 139.],\n",
+ " [140., 141., 142., 143., 144., 145., 146., 147., 148., 149.],\n",
+ " [150., 151., 152., 153., 154., 155., 156., 157., 158., 159.],\n",
+ " [160., 161., 162., 163., 164., 165., 166., 167., 168., 169.],\n",
+ " [170., 171., 172., 173., 174., 175., 176., 177., 178., 179.],\n",
+ " [180., 181., 182., 183., 184., 185., 186., 187., 188., 189.],\n",
+ " [190., 191., 192., 193., 194., 195., 196., 197., 198., 199.]])\n",
+ "Coordinates:\n",
+ " * x (x) int64 160B 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19\n",
+ " * y (y) int64 80B 0 1 2 3 4 5 6 7 8 9
0.0 1.0 2.0 3.0 4.0 5.0 6.0 ... 194.0 195.0 196.0 197.0 198.0 199.0
array([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],\n",
+ " [ 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],\n",
+ " [ 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],\n",
+ " [ 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],\n",
+ " [ 40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],\n",
+ " [ 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],\n",
+ " [ 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],\n",
+ " [ 70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],\n",
+ " [ 80., 81., 82., 83., 84., 85., 86., 87., 88., 89.],\n",
+ " [ 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.],\n",
+ " [100., 101., 102., 103., 104., 105., 106., 107., 108., 109.],\n",
+ " [110., 111., 112., 113., 114., 115., 116., 117., 118., 119.],\n",
+ " [120., 121., 122., 123., 124., 125., 126., 127., 128., 129.],\n",
+ " [130., 131., 132., 133., 134., 135., 136., 137., 138., 139.],\n",
+ " [140., 141., 142., 143., 144., 145., 146., 147., 148., 149.],\n",
+ " [150., 151., 152., 153., 154., 155., 156., 157., 158., 159.],\n",
+ " [160., 161., 162., 163., 164., 165., 166., 167., 168., 169.],\n",
+ " [170., 171., 172., 173., 174., 175., 176., 177., 178., 179.],\n",
+ " [180., 181., 182., 183., 184., 185., 186., 187., 188., 189.],\n",
+ " [190., 191., 192., 193., 194., 195., 196., 197., 198., 199.]])
PandasIndex
PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype='int64', name='x'))
PandasIndex
PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int64', name='y'))
"
+ ],
+ "text/plain": [
+ " Size: 2kB\n",
+ "array([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],\n",
+ " [ 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],\n",
+ " [ 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],\n",
+ " [ 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],\n",
+ " [ 40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],\n",
+ " [ 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],\n",
+ " [ 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],\n",
+ " [ 70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],\n",
+ " [ 80., 81., 82., 83., 84., 85., 86., 87., 88., 89.],\n",
+ " [ 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.],\n",
+ " [100., 101., 102., 103., 104., 105., 106., 107., 108., 109.],\n",
+ " [110., 111., 112., 113., 114., 115., 116., 117., 118., 119.],\n",
+ " [120., 121., 122., 123., 124., 125., 126., 127., 128., 129.],\n",
+ " [130., 131., 132., 133., 134., 135., 136., 137., 138., 139.],\n",
+ " [140., 141., 142., 143., 144., 145., 146., 147., 148., 149.],\n",
+ " [150., 151., 152., 153., 154., 155., 156., 157., 158., 159.],\n",
+ " [160., 161., 162., 163., 164., 165., 166., 167., 168., 169.],\n",
+ " [170., 171., 172., 173., 174., 175., 176., 177., 178., 179.],\n",
+ " [180., 181., 182., 183., 184., 185., 186., 187., 188., 189.],\n",
+ " [190., 191., 192., 193., 194., 195., 196., 197., 198., 199.]])\n",
+ "Coordinates:\n",
+ " * x (x) int64 160B 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19\n",
+ " * y (y) int64 80B 0 1 2 3 4 5 6 7 8 9"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 3., 4., 5., 6., 7.],\n",
+ " [13., 14., 15., 16., 17.],\n",
+ " [23., 24., 25., 26., 27.],\n",
+ " [33., 34., 35., 36., 37.],\n",
+ " [43., 44., 45., 46., 47.],\n",
+ " [53., 54., 55., 56., 57.],\n",
+ " [63., 64., 65., 66., 67.],\n",
+ " [73., 74., 75., 76., 77.],\n",
+ " [83., 84., 85., 86., 87.],\n",
+ " [93., 94., 95., 96., 97.]], dtype=torch.float64)"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ds[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "output_tensor_dim = {'x': 20, 'y': 5}\n",
+ "resample_dim = ['x', 'y']\n",
+ "core_dim = []\n",
+ "new_dim = []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([10, 5])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ds[0].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([10, 10])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model(ds[0]).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import functions\n",
+ "from importlib import reload\n",
+ "reload(functions)\n",
+ "result = functions.predict_on_array(\n",
+ " ds,\n",
+ " model,\n",
+ " output_tensor_dim=output_tensor_dim,\n",
+ " new_dim=new_dim,\n",
+ " core_dim=core_dim,\n",
+ " resample_dim=resample_dim,\n",
+ " batch_size=4\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Appending to test_predict_on_array.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%writefile -a test_predict_on_array.py\n",
+ "\n",
+ "@pytest.mark.parametrize(\"factor, mode, expected\", [\n",
+ " (2.0, \"edges\", np.arange(0, 10, 0.5)),\n",
+ " (0.5, \"edges\", np.arange(0, 10, 2.0)),\n",
+ "])\n",
+ "def test_resample_coordinate(factor, mode, expected):\n",
+ " coord = xr.DataArray(np.arange(10, dtype=float), dims=\"x\")\n",
+ " resampled = _resample_coordinate(coord, factor, mode)\n",
+ " np.testing.assert_allclose(resampled, expected)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -414,86 +1067,91 @@
"%%writefile -a test_predict_on_array.py\n",
"\n",
"@pytest.mark.parametrize(\n",
- " \"model, output_tensor_dim, new_dim, resample_dim, expected_transform\",\n",
+ " \"model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform\",\n",
" [\n",
- " # Case 1: Resampling - Downsampling with a subset model\n",
+ " # Case 1: Identity - No change\n",
" (\n",
- " SubsetAlongAxis(ax=1, n=5), # Corresponds to 'x' dim in batch\n",
+ " Identity(),\n",
+ " {'x': 10, 'y': 5},\n",
+ " [], [], ['x', 'y'],\n",
+ " lambda da: da.data\n",
+ " ),\n",
+ " # Case 2: ExpandAlongAxis - Upsampling\n",
+ " (\n",
+ " ExpandAlongAxis(ax=1, n_repeats=2), # ax=1 is 'x'\n",
+ " {'x': 20, 'y': 5},\n",
+ " [], [], ['x', 'y'],\n",
+ " lambda da: da.data.repeat(2, axis=0) # axis=0 in the 2D numpy array\n",
+ " ),\n",
+ " # Case 3: SubsetAlongAxis - Coarsening\n",
+ " (\n",
+ " SubsetAlongAxis(ax=1, n=5), # ax=1 is 'x'\n",
" {'x': 5, 'y': 5},\n",
- " [],\n",
- " ['x'],\n",
- " lambda da: da.isel(x=slice(0, 5)) # Expected: take first 5 elements of original 'x'\n",
+ " [], [], ['x', 'y'],\n",
+ " lambda da: da.isel(x=slice(0, 5)).data\n",
" ),\n",
- " # Case 2: Dimension reduction with a mean model\n",
+ " # Case 4: MeanAlongDim - Dimension reduction\n",
" (\n",
- " MeanAlongDim(ax=2), # Corresponds to 'y' dim in batch\n",
+ " MeanAlongDim(ax=2), # ax=2 is 'y'\n",
" {'x': 10},\n",
- " [],\n",
- " ['x'],\n",
- " lambda da: da.mean(dim='y') # Expected: mean along original 'y'\n",
+ " [], [], ['x'],\n",
+ " lambda da: da.mean(dim='y').data\n",
+ " ),\n",
+ " # Case 5: AddAxis - Add a new dimension\n",
+ " (\n",
+ " AddAxis(ax=1), # Add new dim at axis 1\n",
+ " {'channel': 1, 'x': 10, 'y': 5},\n",
+ " ['channel'], [], ['x', 'y'],\n",
+ " lambda da: np.expand_dims(da.data, axis=0)\n",
" ),\n",
" ]\n",
")\n",
- "def test_predict_on_array_reassembly(\n",
- " map_dataset_fixture,\n",
- " model,\n",
- " output_tensor_dim,\n",
- " new_dim,\n",
- " resample_dim,\n",
- " expected_transform\n",
+ "def test_predict_on_array_all_models(\n",
+ " map_dataset_fixture, model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform\n",
"):\n",
" \"\"\"\n",
- " Tests that predict_on_array correctly reassembles batches from different models.\n",
+ " Tests reassembly, averaging, and coordinate assignment using a variety of models.\n",
" \"\"\"\n",
+ " dataset = map_dataset_fixture\n",
+ " bgen = dataset.X_generator\n",
+ " resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)\n",
+ "\n",
" # --- Run the function under test ---\n",
- " # Using a small batch_size to ensure multiple iterations\n",
- " predicted_da, predicted_n = predict_on_array(\n",
- " dataset=map_dataset_fixture,\n",
- " model=model,\n",
- " output_tensor_dim=output_tensor_dim,\n",
- " new_dim=new_dim,\n",
- " resample_dim=resample_dim,\n",
- " batch_size=4 \n",
+ " result_da = predict_on_array(\n",
+ " dataset=dataset, model=model, output_tensor_dim=output_tensor_dim,\n",
+ " new_dim=new_dim, core_dim=core_dim, resample_dim=resample_dim, batch_size=4\n",
" )\n",
"\n",
" # --- Manually calculate the expected result ---\n",
- " bgen = map_dataset_fixture.generator\n",
- " # 1. Create the expected output array structure\n",
- " expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, resample_dim)\n",
- " expected_da = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))\n",
- " expected_n = xr.full_like(expected_da, 0)\n",
+ " expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, core_dim, resample_dim)\n",
+ " expected_sum = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))\n",
+ " expected_count = xr.full_like(expected_sum, 0, dtype=int)\n",
"\n",
- " # 2. Manually iterate through batches and apply the same logic as the function\n",
- " for i in range(len(map_dataset_fixture)):\n",
+ " for i in range(len(dataset)):\n",
" batch_da = bgen[i]\n",
- " \n",
- " # Apply the same transformation the model would\n",
- " transformed_batch = expected_transform(batch_da)\n",
- " \n",
- " # Get the rescaled indexer\n",
- " old_indexer = bgen.batch_selectors[i]\n",
+ " old_indexer = bgen._batch_selectors.selectors[i][0]\n",
" new_indexer = {}\n",
" for key in old_indexer:\n",
" if key in resample_dim:\n",
- " resample_ratio = output_tensor_dim[key] / bgen.input_dims[key]\n",
- " new_indexer[key] = slice(\n",
- " int(old_indexer[key].start * resample_ratio),\n",
- " int(old_indexer[key].stop * resample_ratio)\n",
- " )\n",
- " \n",
- " # Add the result to our manually calculated array\n",
- " expected_da.loc[new_indexer] += transformed_batch.values\n",
- " expected_n.loc[new_indexer] += 1\n",
+ " new_indexer[key] = slice(int(old_indexer[key].start * resample_factor.get(key, 1)), int(old_indexer[key].stop * resample_factor.get(key, 1)))\n",
+ " elif key in core_dim:\n",
+ " new_indexer[key] = old_indexer[key]\n",
"\n",
- " # --- Assert that the results are identical ---\n",
- " # We test the raw summed output and the overlap counter array\n",
- " xr.testing.assert_allclose(predicted_da, expected_da)\n",
- " xr.testing.assert_allclose(predicted_n, expected_n)"
+ " model_output_on_batch = manual_transform(batch_da)\n",
+ " print(f\"Batch {i}: {new_indexer} -> {model_output_on_batch.shape}\")\n",
+ " print(f\"Expected sum shape: {expected_sum.loc[new_indexer].shape}\")\n",
+ " expected_sum.loc[new_indexer] += model_output_on_batch\n",
+ " expected_count.loc[new_indexer] += 1\n",
+ " \n",
+ " expected_avg_data = expected_sum.data / expected_count.data\n",
+ " \n",
+ " # --- Assert correctness ---\n",
+ " np.testing.assert_allclose(result_da.values, expected_avg_data, equal_nan=True)"
]
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -501,149 +1159,31 @@
"output_type": "stream",
"text": [
"\u001b[1m============================= test session starts ==============================\u001b[0m\n",
- "platform linux -- Python 3.10.16, pytest-8.4.1, pluggy-1.6.0 -- /srv/conda/envs/notebook/bin/python3.10\n",
+ "platform darwin -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0 -- /Users/nkalauni/miniconda3/envs/cookbook-dev/bin/python3.13\n",
"cachedir: .pytest_cache\n",
- "rootdir: /home/jovyan/xbatcher-deep-learning/notebooks\n",
- "plugins: anyio-4.8.0\n",
- "collected 2 items \u001b[0m\u001b[1m\n",
- "\n",
- "test_predict_on_array.py::test_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-] \u001b[31mFAILED\u001b[0m\u001b[31m [ 50%]\u001b[0m\n",
- "\u001b[31mFAILED\u001b[0m\u001b[31m [100%]\u001b[0mpredict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-] \n",
- "\n",
- "=================================== FAILURES ===================================\n",
- "\u001b[31m\u001b[1m_ test_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-] _\u001b[0m\n",
- "\n",
- "map_dataset_fixture = \n",
- "model = SubsetAlongAxis(), output_tensor_dim = {'x': 5, 'y': 5}, new_dim = []\n",
- "resample_dim = ['x'], expected_transform = at 0x7f4d4a136cb0>\n",
- "\n",
- " \u001b[0m\u001b[37m@pytest\u001b[39;49;00m.mark.parametrize(\u001b[90m\u001b[39;49;00m\n",
- " \u001b[33m\"\u001b[39;49;00m\u001b[33mmodel, output_tensor_dim, new_dim, resample_dim, expected_transform\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m,\u001b[90m\u001b[39;49;00m\n",
- " [\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# Case 1: Resampling - Downsampling with a subset model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " (\u001b[90m\u001b[39;49;00m\n",
- " SubsetAlongAxis(ax=\u001b[94m1\u001b[39;49;00m, n=\u001b[94m5\u001b[39;49;00m), \u001b[90m# Corresponds to 'x' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m, \u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n",
- " [],\u001b[90m\u001b[39;49;00m\n",
- " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94mlambda\u001b[39;49;00m da: da.isel(x=\u001b[96mslice\u001b[39;49;00m(\u001b[94m0\u001b[39;49;00m, \u001b[94m5\u001b[39;49;00m)) \u001b[90m# Expected: take first 5 elements of original 'x'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " ),\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# Case 2: Dimension reduction with a mean model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " (\u001b[90m\u001b[39;49;00m\n",
- " MeanAlongDim(ax=\u001b[94m2\u001b[39;49;00m), \u001b[90m# Corresponds to 'y' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m10\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n",
- " [],\u001b[90m\u001b[39;49;00m\n",
- " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94mlambda\u001b[39;49;00m da: da.mean(dim=\u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m) \u001b[90m# Expected: mean along original 'y'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " ),\u001b[90m\u001b[39;49;00m\n",
- " ]\u001b[90m\u001b[39;49;00m\n",
- " )\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92mtest_predict_on_array_reassembly\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n",
- " map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n",
- " model,\u001b[90m\u001b[39;49;00m\n",
- " output_tensor_dim,\u001b[90m\u001b[39;49;00m\n",
- " new_dim,\u001b[90m\u001b[39;49;00m\n",
- " resample_dim,\u001b[90m\u001b[39;49;00m\n",
- " expected_transform\u001b[90m\u001b[39;49;00m\n",
- " ):\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m \u001b[39;49;00m\u001b[33m\"\"\"\u001b[39;49;00m\n",
- " \u001b[33m Tests that predict_on_array correctly reassembles batches from different models.\u001b[39;49;00m\n",
- " \u001b[33m \"\"\"\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# --- Run the function under test ---\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# Using a small batch_size to ensure multiple iterations\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- "> predicted_da, predicted_n = predict_on_array(\u001b[90m\u001b[39;49;00m\n",
- " dataset=map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n",
- " model=model,\u001b[90m\u001b[39;49;00m\n",
- " output_tensor_dim=output_tensor_dim,\u001b[90m\u001b[39;49;00m\n",
- " new_dim=new_dim,\u001b[90m\u001b[39;49;00m\n",
- " resample_dim=resample_dim,\u001b[90m\u001b[39;49;00m\n",
- " batch_size=\u001b[94m4\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " )\u001b[90m\u001b[39;49;00m\n",
- "\n",
- "\u001b[1m\u001b[31mtest_predict_on_array.py\u001b[0m:67: \n",
- "_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n",
- "\u001b[1m\u001b[31mfunctions.py\u001b[0m:55: in predict_on_array\n",
- " \u001b[0moutput_size = _get_output_array_size(dataset.X_generator, output_tensor_dim, new_dim, resample_dim)\u001b[90m\u001b[39;49;00m\n",
- "_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n",
- "\n",
- "bgen = \n",
- "output_tensor_dim = {'x': 5, 'y': 5}, new_dim = [], resample_dim = ['x']\n",
- "\n",
- " \u001b[0m\u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92m_get_output_array_size\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n",
- " bgen: xbatcher.BatchGenerator,\u001b[90m\u001b[39;49;00m\n",
- " output_tensor_dim: \u001b[96mdict\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m, \u001b[96mint\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n",
- " new_dim: \u001b[96mlist\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n",
- " resample_dim: \u001b[96mlist\u001b[39;49;00m[\u001b[96mstr\u001b[39;49;00m]\u001b[90m\u001b[39;49;00m\n",
- " ):\u001b[90m\u001b[39;49;00m\n",
- " resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)\u001b[90m\u001b[39;49;00m\n",
- " output_size = {}\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94mfor\u001b[39;49;00m key, size \u001b[95min\u001b[39;49;00m output_tensor_dim.items():\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94mif\u001b[39;49;00m key \u001b[95min\u001b[39;49;00m new_dim:\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# This is a new axis, size is determined\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# by the tensor size.\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " output_size[key] = output_tensor_dim[key]\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94melse\u001b[39;49;00m:\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# This is a resampled axis, determine the new size\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# by the ratio of the batchgen window to the tensor size.\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- "> temp_output_size = bgen.ds.sizes[key] * resample_factor[key]\u001b[90m\u001b[39;49;00m\n",
- "\u001b[1m\u001b[31mE KeyError: 'y'\u001b[0m\n",
- "\n",
- "\u001b[1m\u001b[31mfunctions.py\u001b[0m:36: KeyError\n",
- "\u001b[31m\u001b[1m_ test_predict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-] _\u001b[0m\n",
+ "rootdir: /Users/nkalauni/Documents/Cline/xbatcher-deep-learning/notebooks\n",
+ "plugins: anyio-4.10.0\n",
+ "collected 7 items \u001b[0m\u001b[1m\n",
"\n",
- "map_dataset_fixture = \n",
- "model = MeanAlongDim(), output_tensor_dim = {'x': 10}, new_dim = []\n",
- "resample_dim = ['x'], expected_transform = at 0x7f4d4a1371c0>\n",
+ "test_predict_on_array.py::test_resample_coordinate[2.0-edges-expected0] \u001b[32mPASSED\u001b[0m\u001b[32m [ 14%]\u001b[0m\n",
+ "test_predict_on_array.py::test_resample_coordinate[0.5-edges-expected1] \u001b[32mPASSED\u001b[0m\u001b[32m [ 28%]\u001b[0m\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model0-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-] \u001b[32mPASSED\u001b[0m\u001b[32m [ 42%]\u001b[0m\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model1-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-] \u001b[32mPASSED\u001b[0m\u001b[33m [ 57%]\u001b[0m\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model2-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-] \u001b[32mPASSED\u001b[0m\u001b[33m [ 71%]\u001b[0m\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model3-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-] \u001b[32mPASSED\u001b[0m\u001b[33m [ 85%]\u001b[0m\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model4-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-] \u001b[32mPASSED\u001b[0m\u001b[33m [100%]\u001b[0m\n",
"\n",
- " \u001b[0m\u001b[37m@pytest\u001b[39;49;00m.mark.parametrize(\u001b[90m\u001b[39;49;00m\n",
- " \u001b[33m\"\u001b[39;49;00m\u001b[33mmodel, output_tensor_dim, new_dim, resample_dim, expected_transform\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m,\u001b[90m\u001b[39;49;00m\n",
- " [\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# Case 1: Resampling - Downsampling with a subset model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " (\u001b[90m\u001b[39;49;00m\n",
- " SubsetAlongAxis(ax=\u001b[94m1\u001b[39;49;00m, n=\u001b[94m5\u001b[39;49;00m), \u001b[90m# Corresponds to 'x' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m, \u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m5\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n",
- " [],\u001b[90m\u001b[39;49;00m\n",
- " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94mlambda\u001b[39;49;00m da: da.isel(x=\u001b[96mslice\u001b[39;49;00m(\u001b[94m0\u001b[39;49;00m, \u001b[94m5\u001b[39;49;00m)) \u001b[90m# Expected: take first 5 elements of original 'x'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " ),\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# Case 2: Dimension reduction with a mean model\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " (\u001b[90m\u001b[39;49;00m\n",
- " MeanAlongDim(ax=\u001b[94m2\u001b[39;49;00m), \u001b[90m# Corresponds to 'y' dim in batch\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: \u001b[94m10\u001b[39;49;00m},\u001b[90m\u001b[39;49;00m\n",
- " [],\u001b[90m\u001b[39;49;00m\n",
- " [\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m],\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94mlambda\u001b[39;49;00m da: da.mean(dim=\u001b[33m'\u001b[39;49;00m\u001b[33my\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m) \u001b[90m# Expected: mean along original 'y'\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " ),\u001b[90m\u001b[39;49;00m\n",
- " ]\u001b[90m\u001b[39;49;00m\n",
- " )\u001b[90m\u001b[39;49;00m\n",
- " \u001b[94mdef\u001b[39;49;00m\u001b[90m \u001b[39;49;00m\u001b[92mtest_predict_on_array_reassembly\u001b[39;49;00m(\u001b[90m\u001b[39;49;00m\n",
- " map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n",
- " model,\u001b[90m\u001b[39;49;00m\n",
- " output_tensor_dim,\u001b[90m\u001b[39;49;00m\n",
- " new_dim,\u001b[90m\u001b[39;49;00m\n",
- " resample_dim,\u001b[90m\u001b[39;49;00m\n",
- " expected_transform\u001b[90m\u001b[39;49;00m\n",
- " ):\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m \u001b[39;49;00m\u001b[33m\"\"\"\u001b[39;49;00m\n",
- " \u001b[33m Tests that predict_on_array correctly reassembles batches from different models.\u001b[39;49;00m\n",
- " \u001b[33m \"\"\"\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# --- Run the function under test ---\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " \u001b[90m# Using a small batch_size to ensure multiple iterations\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- "> predicted_da, predicted_n = predict_on_array(\u001b[90m\u001b[39;49;00m\n",
- " dataset=map_dataset_fixture,\u001b[90m\u001b[39;49;00m\n",
- " model=model,\u001b[90m\u001b[39;49;00m\n",
- " output_tensor_dim=output_tensor_dim,\u001b[90m\u001b[39;49;00m\n",
- " new_dim=new_dim,\u001b[90m\u001b[39;49;00m\n",
- " resample_dim=resample_dim,\u001b[90m\u001b[39;49;00m\n",
- " batch_size=\u001b[94m4\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n",
- " )\u001b[90m\u001b[39;49;00m\n",
- "\u001b[1m\u001b[31mE ValueError: too many values to unpack (expected 2)\u001b[0m\n",
+ "\u001b[33m=============================== warnings summary ===============================\u001b[0m\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model0-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-]\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model1-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-]\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model2-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-]\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model3-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-]\n",
+ "test_predict_on_array.py::test_predict_on_array_all_models[model4-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-]\n",
+ " /Users/nkalauni/Documents/Cline/xbatcher-deep-learning/notebooks/test_predict_on_array.py:108: RuntimeWarning: invalid value encountered in divide\n",
+ " expected_avg_data = expected_sum.data / expected_count.data\n",
"\n",
- "\u001b[1m\u001b[31mtest_predict_on_array.py\u001b[0m:67: ValueError\n",
- "\u001b[36m\u001b[1m=========================== short test summary info ============================\u001b[0m\n",
- "\u001b[31mFAILED\u001b[0m test_predict_on_array.py::\u001b[1mtest_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-]\u001b[0m - KeyError: 'y'\n",
- "\u001b[31mFAILED\u001b[0m test_predict_on_array.py::\u001b[1mtest_predict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-]\u001b[0m - ValueError: too many values to unpack (expected 2)\n",
- "\u001b[31m============================== \u001b[31m\u001b[1m2 failed\u001b[0m\u001b[31m in 1.98s\u001b[0m\u001b[31m ===============================\u001b[0m\n"
+ "-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html\n",
+ "\u001b[33m======================== \u001b[32m7 passed\u001b[0m, \u001b[33m\u001b[1m5 warnings\u001b[0m\u001b[33m in 1.25s\u001b[0m\u001b[33m =========================\u001b[0m\n"
]
}
],
diff --git a/notebooks/test_get_array_size.py b/notebooks/test_get_array_size.py
index d54aea67..868eade7 100644
--- a/notebooks/test_get_array_size.py
+++ b/notebooks/test_get_array_size.py
@@ -46,120 +46,11 @@ def bgen_fixture() -> xbatcher.BatchGenerator:
),
(
"Mixed: Resample x, add new channel dimension and keep t as core",
- {'x': 30, 'channel': 12},
+ {'x': 30, 'channel': 12, 't': 10},
['channel'],
['t'],
['x'],
- {'x': 300, 'channel': 12}
- ),
- (
- "Identity resampling (ratio=1)",
- {'x': 10, 'y': 10},
- [],
- [],
- ['x', 'y'],
- {'x': 100, 'y': 100}
- ),
- (
- "Core dims only: 't' is a core dim",
- {'t': 10},
- [],
- ['t'],
- [],
- {'t': 10}
- ),
- ]
-)
-def test_get_output_array_size_scenarios(
- bgen_fixture, # The fixture is passed as an argument
- case_description,
- output_tensor_dim,
- new_dim,
- core_dim,
- resample_dim,
- expected_output
-):
- """
- Tests various valid scenarios for calculating the output array size.
- The `case_description` parameter is not used in the code but helps make
- test results more readable.
- """
- # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture
- result = _get_output_array_size(
- bgen=bgen_fixture,
- output_tensor_dim=output_tensor_dim,
- new_dim=new_dim,
- core_dim=core_dim,
- resample_dim=resample_dim
- )
-
- assert result == expected_output, f"Failed on case: {case_description}"
-
-def test_get_output_array_size_raises_error_on_mismatched_core_dim(bgen_fixture):
- """Tests ValueError when a core_dim size doesn't match the source."""
- with pytest.raises(ValueError, match="does not equal the source data array size"):
- _get_output_array_size(
- bgen_fixture, output_tensor_dim={'t': 99}, new_dim=[], core_dim=['t'], resample_dim=[]
- )
-
-def test_get_output_array_size_raises_error_on_unspecified_dim(bgen_fixture):
- """Tests ValueError when a dimension is not specified in any category."""
- with pytest.raises(ValueError, match="must be specified in one of"):
- _get_output_array_size(
- bgen_fixture, output_tensor_dim={'x': 10}, new_dim=[], core_dim=[], resample_dim=[]
- )
-
-def test_get_resample_factor_raises_error_on_invalid_ratio(bgen_fixture):
- """Tests AssertionError when the resample ratio is not an integer or its inverse."""
- with pytest.raises(AssertionError, match="must be an integer or its inverse"):
- # 15 / 10 = 1.5, which is not a valid ratio
- _get_resample_factor(bgen_fixture, output_tensor_dim={'x': 15}, resample_dim=['x'])
-
-@pytest.fixture
-def bgen_fixture() -> xbatcher.BatchGenerator:
- data = xr.DataArray(
- data=np.random.rand(100, 100, 10),
- dims=("x", "y", "t"),
- coords={
- "x": np.arange(100),
- "y": np.arange(100),
- "t": np.arange(10),
- }
- )
-
- bgen = xbatcher.BatchGenerator(
- data,
- input_dims=dict(x=10, y=10),
- input_overlap=dict(x=5, y=5),
- )
- return bgen
-
-@pytest.mark.parametrize(
- "case_description, output_tensor_dim, new_dim, core_dim, resample_dim, expected_output",
- [
- (
- "Resampling only: Downsample x, Upsample y",
- {'x': 5, 'y': 20},
- [],
- [],
- ['x', 'y'],
- {'x': 50, 'y': 200}
- ),
- (
- "New dimensions only: Add a 'channel' dimension",
- {'channel': 3},
- ['channel'],
- [],
- [],
- {'channel': 3}
- ),
- (
- "Mixed: Resample x, add new channel dimension and keep t as core",
- {'x': 30, 'channel': 12},
- ['channel'],
- ['t'],
- ['x'],
- {'x': 300, 'channel': 12}
+ {'x': 300, 'channel': 12, 't': 10}
),
(
"Identity resampling (ratio=1)",
diff --git a/notebooks/test_predict_on_array.py b/notebooks/test_predict_on_array.py
index 4a41e18f..8f45aa4f 100644
--- a/notebooks/test_predict_on_array.py
+++ b/notebooks/test_predict_on_array.py
@@ -5,27 +5,107 @@
import pytest
from xbatcher.loaders.torch import MapDataset
-from functions import _get_output_array_size, predict_on_array
-from dummy_models import *
+from functions import _get_output_array_size, _resample_coordinate
+from functions import predict_on_array, _get_resample_factor
+from dummy_models import Identity, MeanAlongDim, SubsetAlongAxis, ExpandAlongAxis, AddAxis
@pytest.fixture
def map_dataset_fixture() -> MapDataset:
- """
- Creates a MapDataset with a predictable BatchGenerator for testing.
- - Data is an xarray DataArray with dimensions x=20, y=10
- - Values are a simple np.arange sequence for easy verification.
- - Batches are size x=10, y=5 with overlap x=2, y=2
- """
- # Using a smaller, more manageable dataset for testing
data = xr.DataArray(
- data=np.arange(20 * 10).reshape(20, 10),
+ data=np.arange(20 * 10).reshape(20, 10).astype(np.float32),
dims=("x", "y"),
- coords={"x": np.arange(20), "y": np.arange(10)}
- ).astype(float)
-
- bgen = xbatcher.BatchGenerator(
- data,
- input_dims=dict(x=10, y=5),
- input_overlap=dict(x=2, y=2),
+ coords={"x": np.arange(20, dtype=float), "y": np.arange(10, dtype=float)},
)
+ bgen = xbatcher.BatchGenerator(data, input_dims=dict(x=10, y=5), input_overlap=dict(x=2, y=2))
return MapDataset(bgen)
+
+@pytest.mark.parametrize("factor, mode, expected", [
+ (2.0, "edges", np.arange(0, 10, 0.5)),
+ (0.5, "edges", np.arange(0, 10, 2.0)),
+])
+def test_resample_coordinate(factor, mode, expected):
+ coord = xr.DataArray(np.arange(10, dtype=float), dims="x")
+ resampled = _resample_coordinate(coord, factor, mode)
+ np.testing.assert_allclose(resampled, expected)
+
+@pytest.mark.parametrize(
+    "model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform",
+    [
+        # Case 1: Identity - No change
+        (
+            Identity(),
+            {'x': 10, 'y': 5},
+            [], [], ['x', 'y'],
+            lambda da: da.data
+        ),
+        # Case 2: ExpandAlongAxis - Upsampling
+        (
+            ExpandAlongAxis(ax=1, n_repeats=2),  # ax=1 is 'x'
+            {'x': 20, 'y': 5},
+            [], [], ['x', 'y'],
+            lambda da: da.data.repeat(2, axis=0)  # axis=0 in the 2D numpy array
+        ),
+        # Case 3: SubsetAlongAxis - Coarsening
+        (
+            SubsetAlongAxis(ax=1, n=5),  # ax=1 is 'x'
+            {'x': 5, 'y': 5},
+            [], [], ['x', 'y'],
+            lambda da: da.isel(x=slice(0, 5)).data
+        ),
+        # Case 4: MeanAlongDim - Dimension reduction
+        (
+            MeanAlongDim(ax=2),  # ax=2 is 'y'
+            {'x': 10},
+            [], [], ['x'],
+            lambda da: da.mean(dim='y').data
+        ),
+        # Case 5: AddAxis - Add a new dimension
+        (
+            AddAxis(ax=1),  # add the new dim at tensor axis 1
+            {'channel': 1, 'x': 10, 'y': 5},
+            ['channel'], [], ['x', 'y'],
+            lambda da: np.expand_dims(da.data, axis=0)
+        ),
+    ]
+)
+def test_predict_on_array_all_models(
+    map_dataset_fixture, model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform
+):
+    """
+    Tests reassembly, averaging, and coordinate assignment using a variety of models.
+    """
+    dataset = map_dataset_fixture
+    bgen = dataset.X_generator
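+    # Recompute the resample factor here so the manual expectation below
+    # rescales slices the same way predict_on_array does.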
+    resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)
+
+    # --- Run the function under test ---
+    result_da = predict_on_array(
+        dataset=dataset, model=model, output_tensor_dim=output_tensor_dim,
+        new_dim=new_dim, core_dim=core_dim, resample_dim=resample_dim, batch_size=4
+    )
+
+    # --- Manually calculate the expected result ---
+    expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, core_dim, resample_dim)
+    expected_sum = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))
+    expected_count = xr.full_like(expected_sum, 0, dtype=int)
+
+    for i in range(len(dataset)):
+        batch_da = bgen[i]
+        old_indexer = bgen._batch_selectors.selectors[i][0]
+        new_indexer = {}
+        for key in old_indexer:
+            if key in resample_dim:
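+                # Rescale slice bounds from input index space to output index space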
+                new_indexer[key] = slice(
+                    int(old_indexer[key].start * resample_factor.get(key, 1)),
+                    int(old_indexer[key].stop * resample_factor.get(key, 1))
+                )
+            elif key in core_dim:
+                new_indexer[key] = old_indexer[key]
+
+        model_output_on_batch = manual_transform(batch_da)
+        print(f"Batch {i}: {new_indexer} -> {model_output_on_batch.shape}")
+        print(f"Expected sum shape: {expected_sum.loc[new_indexer].shape}")
+        expected_sum.loc[new_indexer] += model_output_on_batch
+        expected_count.loc[new_indexer] += 1
+
+    expected_avg_data = expected_sum.data / expected_count.data
+
+    # --- Assert correctness ---
+    np.testing.assert_allclose(result_da.values, expected_avg_data, equal_nan=True)