Skip to content

Commit 145762d

Browse files
committed
fix bad variable names in _get_output_array_size
1 parent 1773398 commit 145762d

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

notebooks/functions.ipynb

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
{
1212
"cell_type": "code",
1313
"execution_count": null,
14-
"id": "405f62a3-f187-436b-a94c-1658014c9aae",
14+
"id": "0d976e10-f285-4682-a532-ecfe9106a9f8",
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
@@ -36,16 +36,23 @@
3636
" else:\n",
3737
" # This is a resampled axis, determine the new size\n",
3838
" # by the ratio of the batchgen window to the tensor size.\n",
39-
" window_size = ds.X_generator.input_dims[key]\n",
39+
" window_size = bgen.input_dims[key]\n",
4040
" tensor_size = output_tensor_dim[key]\n",
4141
" resample_ratio = tensor_size / window_size\n",
4242
" \n",
43-
" temp_output_size = ds.X_generator.ds.sizes[key] * resample_ratio\n",
43+
" temp_output_size = bgen.ds.sizes[key] * resample_ratio\n",
4444
" assert temp_output_size.is_integer()\n",
4545
" output_size[key] = int(temp_output_size)\n",
46-
" return output_size\n",
47-
" \n",
48-
"\n",
46+
" return output_size"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": null,
52+
"id": "eb717cb1-1715-4441-8b67-5593d1dbd6b8",
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
4956
"def predict_on_array(\n",
5057
" dataset: MapDataset,\n",
5158
" model: torch.nn.Module,\n",
@@ -62,24 +69,29 @@
6269
" dims=tuple(output_size.keys())\n",
6370
" )\n",
6471
" output_n = xr.full_like(output_da, 0)\n",
65-
"\n",
66-
" '''\n",
72+
" \n",
6773
" # Prepare data laoder\n",
6874
" loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
75+
"\n",
76+
" '''\n",
6977
" for i, batch in enumerate(loader):\n",
7078
" out_batch = model(batch).detach().numpy()\n",
71-
" # TODO write each example to the output array\n",
79+
"\n",
80+
" # Iterate over each example in the batch\n",
7281
" for ib in range(out_batch.shape[0]):\n",
7382
" # Get the slice object associated with this example\n",
7483
" old_indexer = dataset.X_generator._batch_selectors.selectors[(i*batch_size)+ib][0]\n",
7584
" # Only index into axes that are resampled, rescaling the bounds\n",
76-
" new_indexer = dict(\n",
77-
" dim: slice()\n",
78-
" for dim in old_indexer if dim in resample_dim\n",
79-
" )\n",
80-
" output_da.loc[indexer] += model_output\n",
81-
" output_n.loc[indexer] += 1\n",
85+
" new_indexer = dict()\n",
86+
" for key in old_indexer:\n",
87+
" if key in resample_dim:\n",
88+
" resample_ratio = output_tensor_dim[key] / dataset.X_generator.input_dims[key]\n",
89+
"\n",
90+
" \n",
91+
" output_da.loc[new_indexer] += out_batch[ib, ...]\n",
92+
" output_n.loc[new_indexer] += 1\n",
8293
" '''\n",
94+
" \n",
8395
"\n",
8496
" # TODO aggregate output\n",
8597
" return output_da"

0 commit comments

Comments
 (0)