Skip to content

Commit fed64ac

Browse files
committed
Merge branch 'feature/pinns' of github.com:gjbex/Python-for-machine-learning into feature/pinns
2 parents 2a7fe4e + 46be7b2 commit fed64ac

File tree

1 file changed

+370
-0
lines changed

1 file changed

+370
-0
lines changed

source-code/pinns/pinns.ipynb

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "289d5b78-6162-4bac-af94-b52d5c194158",
6+
"metadata": {},
7+
"source": [
8+
"# Requirments"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": 37,
14+
"id": "420984ea-3335-42a0-858e-0a1d4435e157",
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"from collections import OrderedDict\n",
19+
"import torch\n",
20+
"from torch.func import functional_call, grad, vmap\n",
21+
"import torchopt"
22+
]
23+
},
24+
{
25+
"cell_type": "markdown",
26+
"id": "3d72e721-148e-4348-b11e-4fa7898c8c6a",
27+
"metadata": {},
28+
"source": [
29+
"# Differential equation"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"id": "eef29346-dc44-4c53-9fe8-1bf5ce444010",
35+
"metadata": {},
36+
"source": [
37+
"The differential equation to solve is:\n",
38+
"$$\n",
39+
" \\frac{d f}{dt} = R f(t)\\left(1 - f(t)\\right)\n",
40+
"$$\n",
41+
"with initial condition $f(0) = 0.5$.\n",
42+
"\n",
43+
"This equation can be used to model population growth."
44+
]
45+
},
46+
{
47+
"cell_type": "markdown",
48+
"id": "7c67e465-a82d-42b4-a832-bd80abbd5b94",
49+
"metadata": {},
50+
"source": [
51+
"# Loss function"
52+
]
53+
},
54+
{
55+
"cell_type": "markdown",
56+
"id": "f8454b80-e172-4bf8-aaeb-bd02b5957bbd",
57+
"metadata": {},
58+
"source": [
59+
"The loss function to train the neural network will be the sum of two terms, the first evaluates the differential equation in $M$ time points, the second enforces the initial conditions.\n",
60+
"$$\n",
61+
" \\begin{array}{lcl}\n",
62+
" L_{\\mathrm{DE}} & = & \\frac{1}{M} \\sum_{j=1}^{M} \\left( \\frac{df_{\\mathrm{NN}}}{dt}(t_j) - R f_{\\mathrm{NN}}(t_j) \\left( 1 - f_{\\mathrm{NN}}(t_j) \\right) \\right)^2 \\\\\n",
63+
" L_{\\mathrm{BC}} & = & \\left( f_{\\mathrm{NN}}(0) - 0.5 \\right)^2 \\\\\n",
64+
" L & = & L_{\\mathrm{DE}} + L_{\\mathrm{BC}}\n",
65+
" \\end{array}\n",
66+
"$$"
67+
]
68+
},
69+
{
70+
"cell_type": "code",
71+
"execution_count": 35,
72+
"id": "15acb49f-c285-4b57-be15-98d0f97556ac",
73+
"metadata": {},
74+
"outputs": [],
75+
"source": [
76+
"def tuple_to_dict_parameters(model, params):\n",
77+
" keys = list(dict(model.named_parameters()).keys())\n",
78+
" values = list(params)\n",
79+
" return OrderedDict(({k:v for k, v in zip(keys, values)}))"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": 36,
85+
"id": "40ec2dc8-457e-4eff-92f1-8dd5b726f21f",
86+
"metadata": {},
87+
"outputs": [],
88+
"source": [
89+
"def make_forward_fn(model, derivative_order=1):\n",
90+
"\n",
91+
" def f(x: torch.Tensor, params: dict[str, torch.nn.Parameter] | tuple[torch.nn.Parameter, ...]) -> torch.Tensor:\n",
92+
" if isinstance(params, tuple):\n",
93+
" params_dict = tuple_to_dict_parameters(model, params)\n",
94+
" else:\n",
95+
" params_dict = params\n",
96+
" return functional_call(model, params_dict, (x, ))\n",
97+
"\n",
98+
" fns = [f]\n",
99+
" dfunc = f\n",
100+
" for _ in range(derivative_order):\n",
101+
" dfunc = grad(dfunc)\n",
102+
" fns.append(vmap(dfunc, in_dims=(0, None)))\n",
103+
" return fns"
104+
]
105+
},
106+
{
107+
"cell_type": "markdown",
108+
"id": "68efecd1-85c3-4143-b4db-f89f71992ec3",
109+
"metadata": {},
110+
"source": [
111+
"# Neural network"
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": 14,
117+
"id": "e234203f-9d86-4462-8670-aa7ba4d6953e",
118+
"metadata": {},
119+
"outputs": [],
120+
"source": [
121+
"class PINN(torch.nn.Module):\n",
122+
"\n",
123+
" def __init__(self, nr_inputs, nr_layers, nr_neurons, activation=torch.nn.Tanh()):\n",
124+
" super().__init__()\n",
125+
" self.num_inputs = nr_inputs\n",
126+
" self.num_layers = nr_layers\n",
127+
" self.num_neurons = nr_neurons\n",
128+
" layers = []\n",
129+
" layers.append(torch.nn.Linear(self.num_inputs, self.num_neurons))\n",
130+
" for _ in range(self.num_layers):\n",
131+
" layers.append(torch.nn.Linear(self.num_neurons, self.num_neurons))\n",
132+
" layers.append(activation)\n",
133+
" layers.append(torch.nn.Linear(self.num_neurons, 1))\n",
134+
" self.network = torch.nn.Sequential(*layers)\n",
135+
"\n",
136+
" def forward(self, x):\n",
137+
" return self.network(x.reshape(-1, 1)).squeeze()"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": 15,
143+
"id": "f879a606-3b85-488b-a2f6-de3f45774f07",
144+
"metadata": {},
145+
"outputs": [],
146+
"source": [
147+
"model = PINN(nr_inputs=1, nr_neurons=20, nr_layers=3)"
148+
]
149+
},
150+
{
151+
"cell_type": "code",
152+
"execution_count": 38,
153+
"id": "d0897090-7f4c-48c5-9e98-1a20ed846d8b",
154+
"metadata": {},
155+
"outputs": [],
156+
"source": [
157+
"f, dfdx = make_forward_fn(model, derivative_order=1)"
158+
]
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": 39,
163+
"id": "08440717-e30d-4f97-bf77-39b03130acb5",
164+
"metadata": {},
165+
"outputs": [],
166+
"source": [
167+
"R, x_boundary, f_boundary = 1.0, 0.0, 0.5"
168+
]
169+
},
170+
{
171+
"cell_type": "code",
172+
"execution_count": 40,
173+
"id": "1cfe8788-5a39-4134-920f-fdc3d7b37d2b",
174+
"metadata": {},
175+
"outputs": [],
176+
"source": [
177+
"def loss_function(params, x):\n",
178+
" f_value = f(x, params)\n",
179+
" interior = dfdx(x, params) - R*f_value*(1.0 - f_value)\n",
180+
" boundaries = f(torch.tensor([x_boundary]), params) - torch.tensor([f_boundary])\n",
181+
" loss = torch.nn.MSELoss()\n",
182+
" return (loss(interior, torch.zeros_like(interior)) +\n",
183+
" loss(boundaries, torch.zeros_like(boundaries)))"
184+
]
185+
},
186+
{
187+
"cell_type": "code",
188+
"execution_count": 41,
189+
"id": "2c661202-c375-4c75-9ae8-1018a90ced63",
190+
"metadata": {},
191+
"outputs": [],
192+
"source": [
193+
"batch_size = 30\n",
194+
"nr_iters = 100\n",
195+
"learning_rate = 1.0e-1\n",
196+
"domain = (-5.0, 5.0)"
197+
]
198+
},
199+
{
200+
"cell_type": "code",
201+
"execution_count": 42,
202+
"id": "1f62a320-d18d-4a7c-9673-cf0db5f4fa9c",
203+
"metadata": {},
204+
"outputs": [],
205+
"source": [
206+
"optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=learning_rate))"
207+
]
208+
},
209+
{
210+
"cell_type": "code",
211+
"execution_count": 44,
212+
"id": "c1846827-0111-4533-91f4-317245f10cf3",
213+
"metadata": {},
214+
"outputs": [],
215+
"source": [
216+
"params = tuple(model.parameters())"
217+
]
218+
},
219+
{
220+
"cell_type": "code",
221+
"execution_count": 46,
222+
"id": "73d5134a-20ad-4247-b6a1-5180b3ee785b",
223+
"metadata": {},
224+
"outputs": [
225+
{
226+
"name": "stdout",
227+
"output_type": "stream",
228+
"text": [
229+
"iteration 1 with loss 0.11463413387537003\n",
230+
"iteration 2 with loss 27.16790771484375\n",
231+
"iteration 3 with loss 0.6420497894287109\n",
232+
"iteration 4 with loss 2.8561654090881348\n",
233+
"iteration 5 with loss 3.5104665756225586\n",
234+
"iteration 6 with loss 2.035282611846924\n",
235+
"iteration 7 with loss 0.14833369851112366\n",
236+
"iteration 8 with loss 0.17722176015377045\n",
237+
"iteration 9 with loss 0.29466158151626587\n",
238+
"iteration 10 with loss 0.13089382648468018\n",
239+
"iteration 11 with loss 0.3749958872795105\n",
240+
"iteration 12 with loss 0.1261375993490219\n",
241+
"iteration 13 with loss 0.17460887134075165\n",
242+
"iteration 14 with loss 0.5559177994728088\n",
243+
"iteration 15 with loss 0.41323214769363403\n",
244+
"iteration 16 with loss 0.40801599621772766\n",
245+
"iteration 17 with loss 0.33771443367004395\n",
246+
"iteration 18 with loss 0.20494796335697174\n",
247+
"iteration 19 with loss 0.10330705344676971\n",
248+
"iteration 20 with loss 0.10414065420627594\n",
249+
"iteration 21 with loss 0.18002170324325562\n",
250+
"iteration 22 with loss 0.19086192548274994\n",
251+
"iteration 23 with loss 0.1662447154521942\n",
252+
"iteration 24 with loss 0.12068456411361694\n",
253+
"iteration 25 with loss 0.20130692422389984\n",
254+
"iteration 26 with loss 0.09801061451435089\n",
255+
"iteration 27 with loss 0.13131101429462433\n",
256+
"iteration 28 with loss 0.09928253293037415\n",
257+
"iteration 29 with loss 0.08115098625421524\n",
258+
"iteration 30 with loss 0.08818767964839935\n",
259+
"iteration 31 with loss 0.08493972569704056\n",
260+
"iteration 32 with loss 0.07110773026943207\n",
261+
"iteration 33 with loss 0.05105520784854889\n",
262+
"iteration 34 with loss 0.05446767061948776\n",
263+
"iteration 35 with loss 0.062104418873786926\n",
264+
"iteration 36 with loss 0.05202716961503029\n",
265+
"iteration 37 with loss 0.055305611342191696\n",
266+
"iteration 38 with loss 0.044861141592264175\n",
267+
"iteration 39 with loss 0.046682748943567276\n",
268+
"iteration 40 with loss 0.042420465499162674\n",
269+
"iteration 41 with loss 0.041048914194107056\n",
270+
"iteration 42 with loss 0.03341205418109894\n",
271+
"iteration 43 with loss 0.0331701822578907\n",
272+
"iteration 44 with loss 0.03715697303414345\n",
273+
"iteration 45 with loss 0.04585421085357666\n",
274+
"iteration 46 with loss 0.03548423945903778\n",
275+
"iteration 47 with loss 0.03360460698604584\n",
276+
"iteration 48 with loss 0.03336722403764725\n",
277+
"iteration 49 with loss 0.035427533090114594\n",
278+
"iteration 50 with loss 0.040354594588279724\n",
279+
"iteration 51 with loss 0.030506618320941925\n",
280+
"iteration 52 with loss 0.022171396762132645\n",
281+
"iteration 53 with loss 0.026050299406051636\n",
282+
"iteration 54 with loss 0.03255322948098183\n",
283+
"iteration 55 with loss 0.026948392391204834\n",
284+
"iteration 56 with loss 0.023532869294285774\n",
285+
"iteration 57 with loss 0.02279188483953476\n",
286+
"iteration 58 with loss 0.024768201634287834\n",
287+
"iteration 59 with loss 0.02384166046977043\n",
288+
"iteration 60 with loss 0.025534283369779587\n",
289+
"iteration 61 with loss 0.023878537118434906\n",
290+
"iteration 62 with loss 0.01744457334280014\n",
291+
"iteration 63 with loss 0.01068549882620573\n",
292+
"iteration 64 with loss 0.018854187801480293\n",
293+
"iteration 65 with loss 0.01397640723735094\n",
294+
"iteration 66 with loss 0.014322957023978233\n",
295+
"iteration 67 with loss 0.00784828420728445\n",
296+
"iteration 68 with loss 0.015125863254070282\n",
297+
"iteration 69 with loss 0.011706520803272724\n",
298+
"iteration 70 with loss 0.010356402024626732\n",
299+
"iteration 71 with loss 0.010507567785680294\n",
300+
"iteration 72 with loss 0.0085351737216115\n",
301+
"iteration 73 with loss 0.008149412460625172\n",
302+
"iteration 74 with loss 0.007394777145236731\n",
303+
"iteration 75 with loss 0.0041252137161791325\n",
304+
"iteration 76 with loss 0.007325059734284878\n",
305+
"iteration 77 with loss 0.004225192591547966\n",
306+
"iteration 78 with loss 0.00552705954760313\n",
307+
"iteration 79 with loss 0.005880812648683786\n",
308+
"iteration 80 with loss 0.006272132974117994\n",
309+
"iteration 81 with loss 0.003222860163077712\n",
310+
"iteration 82 with loss 0.005884816870093346\n",
311+
"iteration 83 with loss 0.0035097438376396894\n",
312+
"iteration 84 with loss 0.005251692607998848\n",
313+
"iteration 85 with loss 0.005707685835659504\n",
314+
"iteration 86 with loss 0.003127788659185171\n",
315+
"iteration 87 with loss 0.0027884359005838633\n",
316+
"iteration 88 with loss 0.004568313714116812\n",
317+
"iteration 89 with loss 0.0025422151666134596\n",
318+
"iteration 90 with loss 0.0016729753697291017\n",
319+
"iteration 91 with loss 0.0035289982333779335\n",
320+
"iteration 92 with loss 0.0016110112192109227\n",
321+
"iteration 93 with loss 0.0029011431615799665\n",
322+
"iteration 94 with loss 0.0015002115396782756\n",
323+
"iteration 95 with loss 0.0027219068724662066\n",
324+
"iteration 96 with loss 0.004007971379905939\n",
325+
"iteration 97 with loss 0.0012513422407209873\n",
326+
"iteration 98 with loss 0.00218919082544744\n",
327+
"iteration 99 with loss 0.002706077415496111\n",
328+
"iteration 100 with loss 0.002446003956720233\n"
329+
]
330+
}
331+
],
332+
"source": [
333+
"for iteration in range(nr_iters):\n",
334+
" x = torch.FloatTensor(batch_size).uniform_(*domain)\n",
335+
" loss = loss_function(params, x)\n",
336+
" params = optimizer.step(loss, params)\n",
337+
" print(f'iteration {iteration + 1} with loss {float(loss)}')"
338+
]
339+
},
340+
{
341+
"cell_type": "code",
342+
"execution_count": null,
343+
"id": "3884aaa0-b27b-4773-97a3-499eb2f7b929",
344+
"metadata": {},
345+
"outputs": [],
346+
"source": []
347+
}
348+
],
349+
"metadata": {
350+
"kernelspec": {
351+
"display_name": "Python 3 (ipykernel)",
352+
"language": "python",
353+
"name": "python3"
354+
},
355+
"language_info": {
356+
"codemirror_mode": {
357+
"name": "ipython",
358+
"version": 3
359+
},
360+
"file_extension": ".py",
361+
"mimetype": "text/x-python",
362+
"name": "python",
363+
"nbconvert_exporter": "python",
364+
"pygments_lexer": "ipython3",
365+
"version": "3.12.1"
366+
}
367+
},
368+
"nbformat": 4,
369+
"nbformat_minor": 5
370+
}

0 commit comments

Comments
 (0)