Skip to content

Commit 78a9083

Browse files
Github action: auto-update.
1 parent 157786d commit 78a9083

File tree

89 files changed

+534
-519
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+534
-519
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.

dev/_downloads/1e09e18baf3b11b9dac15e73967dfd23/plot_SFNO_swe.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@
4848

4949
train_loader, test_loaders = load_spherical_swe(
5050
n_train=200,
51-
batch_size=8,
51+
batch_size=32,
5252
train_resolution=(32, 64),
5353
test_resolutions=[(32, 64), (64, 128)],
54-
n_tests=[50, 50],
55-
test_batch_sizes=[10, 10],
54+
n_tests=[40, 40],
55+
test_batch_sizes=[40, 40],
5656
)
5757

5858

@@ -71,6 +71,7 @@
7171
out_channels=3,
7272
hidden_channels=64,
7373
domain_padding=[0.05, 0.05],
74+
n_layers=2,
7475
)
7576
model = model.to(device)
7677

@@ -175,7 +176,7 @@
175176
#
176177
# In practice we would train a Neural Operator on one or multiple GPUs
177178

178-
fig = plt.figure(figsize=(7, 7))
179+
fig = plt.figure(figsize=(14, 7))
179180
for index, resolution in enumerate([(32, 64), (64, 128)]):
180181
test_samples = test_loaders[resolution].dataset
181182
data = test_samples[0]
@@ -195,20 +196,24 @@
195196
plt.xticks([], [])
196197
plt.yticks([], [])
197198

199+
# Compute the min and max to use consistent color mapping
200+
vmin = y.min()
201+
vmax = y.max()
202+
198203
# Plot ground-truth fields
199204
ax = fig.add_subplot(2, 3, index * 3 + 2)
200-
ax.imshow(y)
205+
im_gt = ax.imshow(y, vmin=vmin, vmax=vmax)
201206
ax.set_title("Ground-truth y")
202207
plt.xticks([], [])
203208
plt.yticks([], [])
204209

205210
# Plot model prediction
206211
ax = fig.add_subplot(2, 3, index * 3 + 3)
207-
ax.imshow(out)
212+
im_pred = ax.imshow(out, vmin=vmin, vmax=vmax)
208213
ax.set_title("SFNO prediction")
209214
plt.xticks([], [])
210215
plt.yticks([], [])
211216

212-
fig.suptitle("SFNO predictions on spherical shallow water equations", y=0.98)
217+
fig.suptitle("SFNO predictions on spherical shallow water equations", y=0.98, fontsize=24)
213218
plt.tight_layout()
214219
fig.show()
Binary file not shown.
Binary file not shown.
Binary file not shown.

dev/_downloads/513ad9091cd6d9b455b9ddd3c807b97f/plot_SFNO_swe.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
},
4141
"outputs": [],
4242
"source": [
43-
"train_loader, test_loaders = load_spherical_swe(\n n_train=200,\n batch_size=8,\n train_resolution=(32, 64),\n test_resolutions=[(32, 64), (64, 128)],\n n_tests=[50, 50],\n test_batch_sizes=[10, 10],\n)"
43+
"train_loader, test_loaders = load_spherical_swe(\n n_train=200,\n batch_size=32,\n train_resolution=(32, 64),\n test_resolutions=[(32, 64), (64, 128)],\n n_tests=[40, 40],\n test_batch_sizes=[40, 40],\n)"
4444
]
4545
},
4646
{
@@ -58,7 +58,7 @@
5858
},
5959
"outputs": [],
6060
"source": [
61-
"model = SFNO(\n n_modes=(16, 32),\n in_channels=3,\n out_channels=3,\n hidden_channels=64,\n domain_padding=[0.05, 0.05],\n)\nmodel = model.to(device)\n\n# Count and display the number of parameters\nn_params = count_model_params(model)\nprint(f\"\\nOur model has {n_params} parameters.\")\nsys.stdout.flush()"
61+
"model = SFNO(\n n_modes=(16, 32),\n in_channels=3,\n out_channels=3,\n hidden_channels=64,\n domain_padding=[0.05, 0.05],\n n_layers=2,\n)\nmodel = model.to(device)\n\n# Count and display the number of parameters\nn_params = count_model_params(model)\nprint(f\"\\nOur model has {n_params} parameters.\")\nsys.stdout.flush()"
6262
]
6363
},
6464
{
@@ -159,7 +159,7 @@
159159
},
160160
"outputs": [],
161161
"source": [
162-
"fig = plt.figure(figsize=(7, 7))\nfor index, resolution in enumerate([(32, 64), (64, 128)]):\n test_samples = test_loaders[resolution].dataset\n data = test_samples[0]\n # Input x\n x = data[\"x\"]\n # Ground-truth\n y = data[\"y\"][0, ...].numpy()\n # Model prediction: SFNO output\n x_in = x.unsqueeze(0).to(device)\n out = model(x_in).squeeze()[0, ...].detach().cpu().numpy()\n x = x[0, ...].detach().numpy()\n\n # Plot input fields\n ax = fig.add_subplot(2, 3, index * 3 + 1)\n ax.imshow(x)\n ax.set_title(f\"Input x {resolution}\")\n plt.xticks([], [])\n plt.yticks([], [])\n\n # Plot ground-truth fields\n ax = fig.add_subplot(2, 3, index * 3 + 2)\n ax.imshow(y)\n ax.set_title(\"Ground-truth y\")\n plt.xticks([], [])\n plt.yticks([], [])\n\n # Plot model prediction\n ax = fig.add_subplot(2, 3, index * 3 + 3)\n ax.imshow(out)\n ax.set_title(\"SFNO prediction\")\n plt.xticks([], [])\n plt.yticks([], [])\n\nfig.suptitle(\"SFNO predictions on spherical shallow water equations\", y=0.98)\nplt.tight_layout()\nfig.show()"
162+
"fig = plt.figure(figsize=(14, 7))\nfor index, resolution in enumerate([(32, 64), (64, 128)]):\n test_samples = test_loaders[resolution].dataset\n data = test_samples[0]\n # Input x\n x = data[\"x\"]\n # Ground-truth\n y = data[\"y\"][0, ...].numpy()\n # Model prediction: SFNO output\n x_in = x.unsqueeze(0).to(device)\n out = model(x_in).squeeze()[0, ...].detach().cpu().numpy()\n x = x[0, ...].detach().numpy()\n\n # Plot input fields\n ax = fig.add_subplot(2, 3, index * 3 + 1)\n ax.imshow(x)\n ax.set_title(f\"Input x {resolution}\")\n plt.xticks([], [])\n plt.yticks([], [])\n\n # Compute the min and max to use consistent color mapping\n vmin = y.min()\n vmax = y.max()\n\n # Plot ground-truth fields\n ax = fig.add_subplot(2, 3, index * 3 + 2)\n im_gt = ax.imshow(y, vmin=vmin, vmax=vmax)\n ax.set_title(\"Ground-truth y\")\n plt.xticks([], [])\n plt.yticks([], [])\n\n # Plot model prediction\n ax = fig.add_subplot(2, 3, index * 3 + 3)\n im_pred = ax.imshow(out, vmin=vmin, vmax=vmax)\n ax.set_title(\"SFNO prediction\")\n plt.xticks([], [])\n plt.yticks([], [])\n\nfig.suptitle(\"SFNO predictions on spherical shallow water equations\", y=0.98, fontsize=24)\nplt.tight_layout()\nfig.show()"
163163
]
164164
}
165165
],
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)