"\n# Training an FNO with incremental meta-learning\nIn this example, we demonstrate how to use the small Darcy-Flow \nexample we ship with the package to demonstrate the Incremental FNO\nmeta-learning algorithm\n"
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
"Set up the incremental FNO model\nWe start with 2 modes in each dimension\nWe choose to update the modes by the incremental gradient explained algorithm\n\n"
"optimizer = AdamW(model.parameters(), lr=8e-3, weight_decay=1e-4)\nscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)\n\n\n# If one wants to use Incremental Resolution, one should use the IncrementalDataProcessor - When passed to the trainer, the trainer will automatically update the resolution\n# Incremental_resolution : bool, default is False\n# if True, increase the resolution of the input incrementally\n# uses the incremental_res_gap parameter\n# uses the subsampling_rates parameter - a list of resolutions to use\n# uses the dataset_indices parameter - a list of indices of the dataset to slice to regularize the input resolution\n# uses the dataset_resolution parameter - the resolution of the input\n# uses the epoch_gap parameter - the number of epochs to wait before increasing the resolution\n# uses the verbose parameter - if True, print the resolution and the number of modes\ndata_transform = IncrementalDataProcessor(\n in_normalizer=None,\n out_normalizer=None,\n device=device,\n subsampling_rates=[2, 1],\n dataset_resolution=16,\n dataset_indices=[2, 3],\n epoch_gap=10,\n verbose=True,\n)\n\ndata_transform = data_transform.to(device)"
"Set up the IncrementalTrainer\nother options include setting incremental_loss_gap = True\nIf one wants to use incremental resolution set it to True\nIn this example we only update the modes and not the resolution\nWhen using the incremental resolution one should keep in mind that the numnber of modes initially set should be strictly less than the resolution\nAgain these are the various paramaters for the various incremental settings\nincremental_grad : bool, default is False\n if True, use the base incremental algorithm which is based on gradient variance\n uses the incremental_grad_eps parameter - set the threshold for gradient variance\n uses the incremental_buffer paramater - sets the number of buffer modes to calculate the gradient variance\n uses the incremental_max_iter parameter - sets the initial number of iterations\n uses the incremental_grad_max_iter parameter - sets the maximum number of iterations to accumulate the gradients\nincremental_loss_gap : bool, default is False\n if True, use the incremental algorithm based on loss gap\n uses the incremental_loss_eps parameter\n\n"
"# Finally pass all of these to the Trainer\ntrainer = IncrementalFNOTrainer(\n model=model,\n n_epochs=20,\n data_processor=data_transform,\n device=device,\n verbose=True,\n incremental_loss_gap=False,\n incremental_grad=True,\n incremental_grad_eps=0.9999,\n incremental_loss_eps = 0.001,\n incremental_buffer=5,\n incremental_max_iter=1,\n incremental_grad_max_iter=2,\n)"
"Plot the prediction, and compare with the ground-truth\nNote that we trained on a very small resolution for\na very small number of epochs\nIn practice, we would train at larger resolution, on many more samples.\n\nHowever, for practicity, we created a minimal example that\ni) fits in just a few Mb of memory\nii) can be trained quickly on CPU\n\nIn practice we would train a Neural Operator on one or multiple GPUs\n\n"
"test_samples = test_loaders[32].dataset\n\nfig = plt.figure(figsize=(7, 7))\nfor index in range(3):\n data = test_samples[index]\n # Input x\n x = data[\"x\"].to(device)\n # Ground-truth\n y = data[\"y\"].to(device)\n # Model prediction\n out = model(x.unsqueeze(0))\n ax = fig.add_subplot(3, 3, index * 3 + 1)\n x = x.cpu().squeeze().detach().numpy()\n y = y.cpu().squeeze().detach().numpy()\n ax.imshow(x, cmap=\"gray\")\n if index == 0:\n ax.set_title(\"Input x\")\n plt.xticks([], [])\n plt.yticks([], [])\n\n ax = fig.add_subplot(3, 3, index * 3 + 2)\n ax.imshow(y.squeeze())\n if index == 0:\n ax.set_title(\"Ground-truth y\")\n plt.xticks([], [])\n plt.yticks([], [])\n\n ax = fig.add_subplot(3, 3, index * 3 + 3)\n ax.imshow(out.cpu().squeeze().detach().numpy())\n if index == 0:\n ax.set_title(\"Model prediction\")\n plt.xticks([], [])\n plt.yticks([], [])\n\nfig.suptitle(\"Inputs, ground-truth output and prediction.\", y=0.98)\nplt.tight_layout()\nfig.show()"
"for res, test_loader in test_loaders.items():\n print(res)\n # Get first batch\n batch = next(iter(test_loader))\n x = batch['x']\n y = batch['y']\n\n print(f'Testing samples for res {res} have shape {x.shape[1:]}')\n\n\ndata = train_dataset[0]\nx = data['x']\ny = data['y']\n\nprint(f'Training samples have shape {x.shape[1:]}')\n\n\n# Which sample to view\nindex = 0\n\ndata = train_dataset[index]\ndata = data_processor.preprocess(data, batched=False)\n\n# The first step of the default FNO model is a grid-based\n# positional embedding. We will add it manually here to\n# visualize the channels appended by this embedding.\npositional_embedding = GridEmbedding2D(in_channels=1)\n# at train time, data will be collated with a batch dim.\n# we create a batch dim to pass into the embedding, then re-squeeze\nx = positional_embedding(data['x'].unsqueeze(0)).squeeze(0)\ny = data['y']\nfig = plt.figure(figsize=(7, 7))\nax = fig.add_subplot(2, 2, 1)\nax.imshow(x[0], cmap='gray')\nax.set_title('input x')\nax = fig.add_subplot(2, 2, 2)\nax.imshow(y.squeeze())\nax.set_title('input y')\nax = fig.add_subplot(2, 2, 3)\nax.imshow(x[1])\nax.set_title('x: 1st pos embedding')\nax = fig.add_subplot(2, 2, 4)\nax.imshow(x[2])\nax.set_title('x: 2nd pos embedding')\nfig.suptitle('Visualizing one input sample', y=0.98)\nplt.tight_layout()\nfig.show()"