Skip to content

Commit a749d30

Browse files
authored
Merge pull request #6 from s-kganz/main
Create notebooks for function definitions and testing
2 parents f42cf64 + 76d7e5e commit a749d30

File tree

3 files changed

+907
-361
lines changed

3 files changed

+907
-361
lines changed

notebooks/functions.ipynb

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "a4a784f3-7130-4c1d-aa11-eb02d614ee76",
6+
"metadata": {},
7+
"source": [
8+
"# Functions for other notebooks to use"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": null,
14+
"id": "405f62a3-f187-436b-a94c-1658014c9aae",
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"def predict_on_array(\n",
19+
" dataset: MapDataset,\n",
20+
" model: torch.nn.Module,\n",
21+
" output_tensor_dim: dict[str, int],\n",
22+
" new_dim: list[str],\n",
23+
" resample_dim: list[str],\n",
24+
" batch_size: int=16\n",
25+
"):\n",
26+
" # TODO set up output array\n",
27+
" output_size = {}\n",
28+
" for key, size in output_tensor_dim.items():\n",
29+
" if key in new_dim:\n",
30+
" # This is a new axis, size is determined\n",
31+
" # by the tensor size.\n",
32+
" output_size[key] = output_tensor_dim[key]\n",
33+
" else:\n",
34+
" # This is a resampled axis, determine the new size\n",
35+
" # by the ratio of the batchgen window to the tensor size.\n",
36+
" window_size = ds.X_generator.input_dims[key]\n",
37+
" tensor_size = output_tensor_dim[key]\n",
38+
" resample_ratio = tensor_size / window_size\n",
39+
"\n",
40+
" temp_output_size = ds.X_generator.ds.sizes[key] * resample_ratio\n",
41+
" assert temp_output_size.is_integer()\n",
42+
" output_size[key] = int(temp_output_size)\n",
43+
" \n",
44+
" output_array = np.zeros(tuple(output_size.values()))\n",
45+
" output_n = np.zeros(output_array.shape)\n",
46+
"\n",
47+
" '''\n",
48+
" # Prepare data laoder\n",
49+
" loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
50+
" for batch in loader:\n",
51+
" out_batch = model(batch).detach().numpy()\n",
52+
" # TODO write each example to the output array\n",
53+
" '''\n",
54+
"\n",
55+
" # TODO aggregate output\n",
56+
" return output_array"
57+
]
58+
}
59+
],
60+
"metadata": {
61+
"kernelspec": {
62+
"display_name": "Python 3 (ipykernel)",
63+
"language": "python",
64+
"name": "python3"
65+
},
66+
"language_info": {
67+
"codemirror_mode": {
68+
"name": "ipython",
69+
"version": 3
70+
},
71+
"file_extension": ".py",
72+
"mimetype": "text/x-python",
73+
"name": "python",
74+
"nbconvert_exporter": "python",
75+
"pygments_lexer": "ipython3",
76+
"version": "3.10.16"
77+
}
78+
},
79+
"nbformat": 4,
80+
"nbformat_minor": 5
81+
}

0 commit comments

Comments
 (0)