|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "289d5b78-6162-4bac-af94-b52d5c194158", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Requirments" |
| 9 | + ] |
| 10 | + }, |
| 11 | + { |
| 12 | + "cell_type": "code", |
| 13 | + "execution_count": 37, |
| 14 | + "id": "420984ea-3335-42a0-858e-0a1d4435e157", |
| 15 | + "metadata": {}, |
| 16 | + "outputs": [], |
| 17 | + "source": [ |
| 18 | + "from collections import OrderedDict\n", |
| 19 | + "import torch\n", |
| 20 | + "from torch.func import functional_call, grad, vmap\n", |
| 21 | + "import torchopt" |
| 22 | + ] |
| 23 | + }, |
| 24 | + { |
| 25 | + "cell_type": "markdown", |
| 26 | + "id": "3d72e721-148e-4348-b11e-4fa7898c8c6a", |
| 27 | + "metadata": {}, |
| 28 | + "source": [ |
| 29 | + "# Differential equation" |
| 30 | + ] |
| 31 | + }, |
| 32 | + { |
| 33 | + "cell_type": "markdown", |
| 34 | + "id": "eef29346-dc44-4c53-9fe8-1bf5ce444010", |
| 35 | + "metadata": {}, |
| 36 | + "source": [ |
| 37 | + "The differential equation to solve is:\n", |
| 38 | + "$$\n", |
| 39 | + " \\frac{d f}{dt} = R f(t)\\left(1 - f(t)\\right)\n", |
| 40 | + "$$\n", |
| 41 | + "with initial condition $f(0) = 0.5$.\n", |
| 42 | + "\n", |
| 43 | + "This equation can be used to model population growth." |
| 44 | + ] |
| 45 | + }, |
| 46 | + { |
| 47 | + "cell_type": "markdown", |
| 48 | + "id": "7c67e465-a82d-42b4-a832-bd80abbd5b94", |
| 49 | + "metadata": {}, |
| 50 | + "source": [ |
| 51 | + "# Loss function" |
| 52 | + ] |
| 53 | + }, |
| 54 | + { |
| 55 | + "cell_type": "markdown", |
| 56 | + "id": "f8454b80-e172-4bf8-aaeb-bd02b5957bbd", |
| 57 | + "metadata": {}, |
| 58 | + "source": [ |
| 59 | + "The loss function to train the neural network will be the sum of two terms, the first evaluates the differential equation in $M$ time points, the second enforces the initial conditions.\n", |
| 60 | + "$$\n", |
| 61 | + " \\begin{array}{lcl}\n", |
| 62 | + " L_{\\mathrm{DE}} & = & \\frac{1}{M} \\sum_{j=1}^{M} \\left( \\frac{df_{\\mathrm{NN}}}{dt}(t_j) - R f_{\\mathrm{NN}}(t_j) \\left( 1 - f_{\\mathrm{NN}}(t_j) \\right) \\right)^2 \\\\\n", |
| 63 | + " L_{\\mathrm{BC}} & = & \\left( f_{\\mathrm{NN}}(0) - 0.5 \\right)^2 \\\\\n", |
| 64 | + " L & = & L_{\\mathrm{DE}} + L_{\\mathrm{BC}}\n", |
| 65 | + " \\end{array}\n", |
| 66 | + "$$" |
| 67 | + ] |
| 68 | + }, |
| 69 | + { |
| 70 | + "cell_type": "code", |
| 71 | + "execution_count": 35, |
| 72 | + "id": "15acb49f-c285-4b57-be15-98d0f97556ac", |
| 73 | + "metadata": {}, |
| 74 | + "outputs": [], |
| 75 | + "source": [ |
| 76 | + "def tuple_to_dict_parameters(model, params):\n", |
| 77 | + " keys = list(dict(model.named_parameters()).keys())\n", |
| 78 | + " values = list(params)\n", |
| 79 | + " return OrderedDict(({k:v for k, v in zip(keys, values)}))" |
| 80 | + ] |
| 81 | + }, |
| 82 | + { |
| 83 | + "cell_type": "code", |
| 84 | + "execution_count": 36, |
| 85 | + "id": "40ec2dc8-457e-4eff-92f1-8dd5b726f21f", |
| 86 | + "metadata": {}, |
| 87 | + "outputs": [], |
| 88 | + "source": [ |
| 89 | + "def make_forward_fn(model, derivative_order=1):\n", |
| 90 | + "\n", |
| 91 | + " def f(x: torch.Tensor, params: dict[str, torch.nn.Parameter] | tuple[torch.nn.Parameter, ...]) -> torch.Tensor:\n", |
| 92 | + " if isinstance(params, tuple):\n", |
| 93 | + " params_dict = tuple_to_dict_parameters(model, params)\n", |
| 94 | + " else:\n", |
| 95 | + " params_dict = params\n", |
| 96 | + " return functional_call(model, params_dict, (x, ))\n", |
| 97 | + "\n", |
| 98 | + " fns = [f]\n", |
| 99 | + " dfunc = f\n", |
| 100 | + " for _ in range(derivative_order):\n", |
| 101 | + " dfunc = grad(dfunc)\n", |
| 102 | + " fns.append(vmap(dfunc, in_dims=(0, None)))\n", |
| 103 | + " return fns" |
| 104 | + ] |
| 105 | + }, |
| 106 | + { |
| 107 | + "cell_type": "markdown", |
| 108 | + "id": "68efecd1-85c3-4143-b4db-f89f71992ec3", |
| 109 | + "metadata": {}, |
| 110 | + "source": [ |
| 111 | + "# Neural network" |
| 112 | + ] |
| 113 | + }, |
| 114 | + { |
| 115 | + "cell_type": "code", |
| 116 | + "execution_count": 14, |
| 117 | + "id": "e234203f-9d86-4462-8670-aa7ba4d6953e", |
| 118 | + "metadata": {}, |
| 119 | + "outputs": [], |
| 120 | + "source": [ |
| 121 | + "class PINN(torch.nn.Module):\n", |
| 122 | + "\n", |
| 123 | + " def __init__(self, nr_inputs, nr_layers, nr_neurons, activation=torch.nn.Tanh()):\n", |
| 124 | + " super().__init__()\n", |
| 125 | + " self.num_inputs = nr_inputs\n", |
| 126 | + " self.num_layers = nr_layers\n", |
| 127 | + " self.num_neurons = nr_neurons\n", |
| 128 | + " layers = []\n", |
| 129 | + " layers.append(torch.nn.Linear(self.num_inputs, self.num_neurons))\n", |
| 130 | + " for _ in range(self.num_layers):\n", |
| 131 | + " layers.append(torch.nn.Linear(self.num_neurons, self.num_neurons))\n", |
| 132 | + " layers.append(activation)\n", |
| 133 | + " layers.append(torch.nn.Linear(self.num_neurons, 1))\n", |
| 134 | + " self.network = torch.nn.Sequential(*layers)\n", |
| 135 | + "\n", |
| 136 | + " def forward(self, x):\n", |
| 137 | + " return self.network(x.reshape(-1, 1)).squeeze()" |
| 138 | + ] |
| 139 | + }, |
| 140 | + { |
| 141 | + "cell_type": "code", |
| 142 | + "execution_count": 15, |
| 143 | + "id": "f879a606-3b85-488b-a2f6-de3f45774f07", |
| 144 | + "metadata": {}, |
| 145 | + "outputs": [], |
| 146 | + "source": [ |
| 147 | + "model = PINN(nr_inputs=1, nr_neurons=20, nr_layers=3)" |
| 148 | + ] |
| 149 | + }, |
| 150 | + { |
| 151 | + "cell_type": "code", |
| 152 | + "execution_count": 38, |
| 153 | + "id": "d0897090-7f4c-48c5-9e98-1a20ed846d8b", |
| 154 | + "metadata": {}, |
| 155 | + "outputs": [], |
| 156 | + "source": [ |
| 157 | + "f, dfdx = make_forward_fn(model, derivative_order=1)" |
| 158 | + ] |
| 159 | + }, |
| 160 | + { |
| 161 | + "cell_type": "code", |
| 162 | + "execution_count": 39, |
| 163 | + "id": "08440717-e30d-4f97-bf77-39b03130acb5", |
| 164 | + "metadata": {}, |
| 165 | + "outputs": [], |
| 166 | + "source": [ |
| 167 | + "R, x_boundary, f_boundary = 1.0, 0.0, 0.5" |
| 168 | + ] |
| 169 | + }, |
| 170 | + { |
| 171 | + "cell_type": "code", |
| 172 | + "execution_count": 40, |
| 173 | + "id": "1cfe8788-5a39-4134-920f-fdc3d7b37d2b", |
| 174 | + "metadata": {}, |
| 175 | + "outputs": [], |
| 176 | + "source": [ |
| 177 | + "def loss_function(params, x):\n", |
| 178 | + " f_value = f(x, params)\n", |
| 179 | + " interior = dfdx(x, params) - R*f_value*(1.0 - f_value)\n", |
| 180 | + " boundaries = f(torch.tensor([x_boundary]), params) - torch.tensor([f_boundary])\n", |
| 181 | + " loss = torch.nn.MSELoss()\n", |
| 182 | + " return (loss(interior, torch.zeros_like(interior)) +\n", |
| 183 | + " loss(boundaries, torch.zeros_like(boundaries)))" |
| 184 | + ] |
| 185 | + }, |
| 186 | + { |
| 187 | + "cell_type": "code", |
| 188 | + "execution_count": 41, |
| 189 | + "id": "2c661202-c375-4c75-9ae8-1018a90ced63", |
| 190 | + "metadata": {}, |
| 191 | + "outputs": [], |
| 192 | + "source": [ |
| 193 | + "batch_size = 30\n", |
| 194 | + "nr_iters = 100\n", |
| 195 | + "learning_rate = 1.0e-1\n", |
| 196 | + "domain = (-5.0, 5.0)" |
| 197 | + ] |
| 198 | + }, |
| 199 | + { |
| 200 | + "cell_type": "code", |
| 201 | + "execution_count": 42, |
| 202 | + "id": "1f62a320-d18d-4a7c-9673-cf0db5f4fa9c", |
| 203 | + "metadata": {}, |
| 204 | + "outputs": [], |
| 205 | + "source": [ |
| 206 | + "optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=learning_rate))" |
| 207 | + ] |
| 208 | + }, |
| 209 | + { |
| 210 | + "cell_type": "code", |
| 211 | + "execution_count": 44, |
| 212 | + "id": "c1846827-0111-4533-91f4-317245f10cf3", |
| 213 | + "metadata": {}, |
| 214 | + "outputs": [], |
| 215 | + "source": [ |
| 216 | + "params = tuple(model.parameters())" |
| 217 | + ] |
| 218 | + }, |
| 219 | + { |
| 220 | + "cell_type": "code", |
| 221 | + "execution_count": 46, |
| 222 | + "id": "73d5134a-20ad-4247-b6a1-5180b3ee785b", |
| 223 | + "metadata": {}, |
| 224 | + "outputs": [ |
| 225 | + { |
| 226 | + "name": "stdout", |
| 227 | + "output_type": "stream", |
| 228 | + "text": [ |
| 229 | + "iteration 1 with loss 0.11463413387537003\n", |
| 230 | + "iteration 2 with loss 27.16790771484375\n", |
| 231 | + "iteration 3 with loss 0.6420497894287109\n", |
| 232 | + "iteration 4 with loss 2.8561654090881348\n", |
| 233 | + "iteration 5 with loss 3.5104665756225586\n", |
| 234 | + "iteration 6 with loss 2.035282611846924\n", |
| 235 | + "iteration 7 with loss 0.14833369851112366\n", |
| 236 | + "iteration 8 with loss 0.17722176015377045\n", |
| 237 | + "iteration 9 with loss 0.29466158151626587\n", |
| 238 | + "iteration 10 with loss 0.13089382648468018\n", |
| 239 | + "iteration 11 with loss 0.3749958872795105\n", |
| 240 | + "iteration 12 with loss 0.1261375993490219\n", |
| 241 | + "iteration 13 with loss 0.17460887134075165\n", |
| 242 | + "iteration 14 with loss 0.5559177994728088\n", |
| 243 | + "iteration 15 with loss 0.41323214769363403\n", |
| 244 | + "iteration 16 with loss 0.40801599621772766\n", |
| 245 | + "iteration 17 with loss 0.33771443367004395\n", |
| 246 | + "iteration 18 with loss 0.20494796335697174\n", |
| 247 | + "iteration 19 with loss 0.10330705344676971\n", |
| 248 | + "iteration 20 with loss 0.10414065420627594\n", |
| 249 | + "iteration 21 with loss 0.18002170324325562\n", |
| 250 | + "iteration 22 with loss 0.19086192548274994\n", |
| 251 | + "iteration 23 with loss 0.1662447154521942\n", |
| 252 | + "iteration 24 with loss 0.12068456411361694\n", |
| 253 | + "iteration 25 with loss 0.20130692422389984\n", |
| 254 | + "iteration 26 with loss 0.09801061451435089\n", |
| 255 | + "iteration 27 with loss 0.13131101429462433\n", |
| 256 | + "iteration 28 with loss 0.09928253293037415\n", |
| 257 | + "iteration 29 with loss 0.08115098625421524\n", |
| 258 | + "iteration 30 with loss 0.08818767964839935\n", |
| 259 | + "iteration 31 with loss 0.08493972569704056\n", |
| 260 | + "iteration 32 with loss 0.07110773026943207\n", |
| 261 | + "iteration 33 with loss 0.05105520784854889\n", |
| 262 | + "iteration 34 with loss 0.05446767061948776\n", |
| 263 | + "iteration 35 with loss 0.062104418873786926\n", |
| 264 | + "iteration 36 with loss 0.05202716961503029\n", |
| 265 | + "iteration 37 with loss 0.055305611342191696\n", |
| 266 | + "iteration 38 with loss 0.044861141592264175\n", |
| 267 | + "iteration 39 with loss 0.046682748943567276\n", |
| 268 | + "iteration 40 with loss 0.042420465499162674\n", |
| 269 | + "iteration 41 with loss 0.041048914194107056\n", |
| 270 | + "iteration 42 with loss 0.03341205418109894\n", |
| 271 | + "iteration 43 with loss 0.0331701822578907\n", |
| 272 | + "iteration 44 with loss 0.03715697303414345\n", |
| 273 | + "iteration 45 with loss 0.04585421085357666\n", |
| 274 | + "iteration 46 with loss 0.03548423945903778\n", |
| 275 | + "iteration 47 with loss 0.03360460698604584\n", |
| 276 | + "iteration 48 with loss 0.03336722403764725\n", |
| 277 | + "iteration 49 with loss 0.035427533090114594\n", |
| 278 | + "iteration 50 with loss 0.040354594588279724\n", |
| 279 | + "iteration 51 with loss 0.030506618320941925\n", |
| 280 | + "iteration 52 with loss 0.022171396762132645\n", |
| 281 | + "iteration 53 with loss 0.026050299406051636\n", |
| 282 | + "iteration 54 with loss 0.03255322948098183\n", |
| 283 | + "iteration 55 with loss 0.026948392391204834\n", |
| 284 | + "iteration 56 with loss 0.023532869294285774\n", |
| 285 | + "iteration 57 with loss 0.02279188483953476\n", |
| 286 | + "iteration 58 with loss 0.024768201634287834\n", |
| 287 | + "iteration 59 with loss 0.02384166046977043\n", |
| 288 | + "iteration 60 with loss 0.025534283369779587\n", |
| 289 | + "iteration 61 with loss 0.023878537118434906\n", |
| 290 | + "iteration 62 with loss 0.01744457334280014\n", |
| 291 | + "iteration 63 with loss 0.01068549882620573\n", |
| 292 | + "iteration 64 with loss 0.018854187801480293\n", |
| 293 | + "iteration 65 with loss 0.01397640723735094\n", |
| 294 | + "iteration 66 with loss 0.014322957023978233\n", |
| 295 | + "iteration 67 with loss 0.00784828420728445\n", |
| 296 | + "iteration 68 with loss 0.015125863254070282\n", |
| 297 | + "iteration 69 with loss 0.011706520803272724\n", |
| 298 | + "iteration 70 with loss 0.010356402024626732\n", |
| 299 | + "iteration 71 with loss 0.010507567785680294\n", |
| 300 | + "iteration 72 with loss 0.0085351737216115\n", |
| 301 | + "iteration 73 with loss 0.008149412460625172\n", |
| 302 | + "iteration 74 with loss 0.007394777145236731\n", |
| 303 | + "iteration 75 with loss 0.0041252137161791325\n", |
| 304 | + "iteration 76 with loss 0.007325059734284878\n", |
| 305 | + "iteration 77 with loss 0.004225192591547966\n", |
| 306 | + "iteration 78 with loss 0.00552705954760313\n", |
| 307 | + "iteration 79 with loss 0.005880812648683786\n", |
| 308 | + "iteration 80 with loss 0.006272132974117994\n", |
| 309 | + "iteration 81 with loss 0.003222860163077712\n", |
| 310 | + "iteration 82 with loss 0.005884816870093346\n", |
| 311 | + "iteration 83 with loss 0.0035097438376396894\n", |
| 312 | + "iteration 84 with loss 0.005251692607998848\n", |
| 313 | + "iteration 85 with loss 0.005707685835659504\n", |
| 314 | + "iteration 86 with loss 0.003127788659185171\n", |
| 315 | + "iteration 87 with loss 0.0027884359005838633\n", |
| 316 | + "iteration 88 with loss 0.004568313714116812\n", |
| 317 | + "iteration 89 with loss 0.0025422151666134596\n", |
| 318 | + "iteration 90 with loss 0.0016729753697291017\n", |
| 319 | + "iteration 91 with loss 0.0035289982333779335\n", |
| 320 | + "iteration 92 with loss 0.0016110112192109227\n", |
| 321 | + "iteration 93 with loss 0.0029011431615799665\n", |
| 322 | + "iteration 94 with loss 0.0015002115396782756\n", |
| 323 | + "iteration 95 with loss 0.0027219068724662066\n", |
| 324 | + "iteration 96 with loss 0.004007971379905939\n", |
| 325 | + "iteration 97 with loss 0.0012513422407209873\n", |
| 326 | + "iteration 98 with loss 0.00218919082544744\n", |
| 327 | + "iteration 99 with loss 0.002706077415496111\n", |
| 328 | + "iteration 100 with loss 0.002446003956720233\n" |
| 329 | + ] |
| 330 | + } |
| 331 | + ], |
| 332 | + "source": [ |
| 333 | + "for iteration in range(nr_iters):\n", |
| 334 | + " x = torch.FloatTensor(batch_size).uniform_(*domain)\n", |
| 335 | + " loss = loss_function(params, x)\n", |
| 336 | + " params = optimizer.step(loss, params)\n", |
| 337 | + " print(f'iteration {iteration + 1} with loss {float(loss)}')" |
| 338 | + ] |
| 339 | + }, |
| 340 | + { |
| 341 | + "cell_type": "code", |
| 342 | + "execution_count": null, |
| 343 | + "id": "3884aaa0-b27b-4773-97a3-499eb2f7b929", |
| 344 | + "metadata": {}, |
| 345 | + "outputs": [], |
| 346 | + "source": [] |
| 347 | + } |
| 348 | + ], |
| 349 | + "metadata": { |
| 350 | + "kernelspec": { |
| 351 | + "display_name": "Python 3 (ipykernel)", |
| 352 | + "language": "python", |
| 353 | + "name": "python3" |
| 354 | + }, |
| 355 | + "language_info": { |
| 356 | + "codemirror_mode": { |
| 357 | + "name": "ipython", |
| 358 | + "version": 3 |
| 359 | + }, |
| 360 | + "file_extension": ".py", |
| 361 | + "mimetype": "text/x-python", |
| 362 | + "name": "python", |
| 363 | + "nbconvert_exporter": "python", |
| 364 | + "pygments_lexer": "ipython3", |
| 365 | + "version": "3.12.1" |
| 366 | + } |
| 367 | + }, |
| 368 | + "nbformat": 4, |
| 369 | + "nbformat_minor": 5 |
| 370 | +} |
0 commit comments