|
11 | 11 | { |
12 | 12 | "cell_type": "code", |
13 | 13 | "execution_count": null, |
14 | | - "id": "405f62a3-f187-436b-a94c-1658014c9aae", |
| 14 | + "id": "0d976e10-f285-4682-a532-ecfe9106a9f8", |
15 | 15 | "metadata": {}, |
16 | 16 | "outputs": [], |
17 | 17 | "source": [ |
|
36 | 36 | " else:\n", |
37 | 37 | " # This is a resampled axis, determine the new size\n", |
38 | 38 | " # by the ratio of the batchgen window to the tensor size.\n", |
39 | | - " window_size = ds.X_generator.input_dims[key]\n", |
| 39 | + " window_size = bgen.input_dims[key]\n", |
40 | 40 | " tensor_size = output_tensor_dim[key]\n", |
41 | 41 | " resample_ratio = tensor_size / window_size\n", |
42 | 42 | " \n", |
43 | | - " temp_output_size = ds.X_generator.ds.sizes[key] * resample_ratio\n", |
| 43 | + " temp_output_size = bgen.ds.sizes[key] * resample_ratio\n", |
44 | 44 | " assert temp_output_size.is_integer()\n", |
45 | 45 | " output_size[key] = int(temp_output_size)\n", |
46 | | - " return output_size\n", |
47 | | - " \n", |
48 | | - "\n", |
| 46 | + " return output_size" |
| 47 | + ] |
| 48 | + }, |
| 49 | + { |
| 50 | + "cell_type": "code", |
| 51 | + "execution_count": null, |
| 52 | + "id": "eb717cb1-1715-4441-8b67-5593d1dbd6b8", |
| 53 | + "metadata": {}, |
| 54 | + "outputs": [], |
| 55 | + "source": [ |
49 | 56 | "def predict_on_array(\n", |
50 | 57 | " dataset: MapDataset,\n", |
51 | 58 | " model: torch.nn.Module,\n", |
|
62 | 69 | " dims=tuple(output_size.keys())\n", |
63 | 70 | " )\n", |
64 | 71 | " output_n = xr.full_like(output_da, 0)\n", |
65 | | - "\n", |
66 | | - " '''\n", |
| 72 | + " \n", |
67 | 73 | " # Prepare data laoder\n", |
68 | 74 | " loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n", |
| 75 | + "\n", |
| 76 | + " '''\n", |
69 | 77 | " for i, batch in enumerate(loader):\n", |
70 | 78 | " out_batch = model(batch).detach().numpy()\n", |
71 | | - " # TODO write each example to the output array\n", |
| 79 | + "\n", |
| 80 | + " # Iterate over each example in the batch\n", |
72 | 81 | " for ib in range(out_batch.shape[0]):\n", |
73 | 82 | " # Get the slice object associated with this example\n", |
74 | 83 | " old_indexer = dataset.X_generator._batch_selectors.selectors[(i*batch_size)+ib][0]\n", |
75 | 84 | " # Only index into axes that are resampled, rescaling the bounds\n", |
76 | | - " new_indexer = dict(\n", |
77 | | - " dim: slice()\n", |
78 | | - " for dim in old_indexer if dim in resample_dim\n", |
79 | | - " )\n", |
80 | | - " output_da.loc[indexer] += model_output\n", |
81 | | - " output_n.loc[indexer] += 1\n", |
| 85 | + " new_indexer = dict()\n", |
| 86 | + " for key in old_indexer:\n", |
| 87 | + " if key in resample_dim:\n", |
| 88 | + " resample_ratio = output_tensor_dim[key] / dataset.X_generator.input_dims[key]\n", |
| 89 | + "\n", |
| 90 | + " \n", |
| 91 | + " output_da.loc[new_indexer] += out_batch[ib, ...]\n", |
| 92 | + " output_n.loc[new_indexer] += 1\n", |
82 | 93 | " '''\n", |
| 94 | + " \n", |
83 | 95 | "\n", |
84 | 96 | " # TODO aggregate output\n", |
85 | 97 | " return output_da" |
|
0 commit comments