Skip to content

Commit d785a01

Browse files
PyTorch tutorial (#417)
* Added tutorial for PyTorch tensor datatype * Added dill to environment * Removed dill dependency * Added dill again * Removed unnecessary modules in PyTorch environment
1 parent 4357815 commit d785a01

File tree

10 files changed

+165
-17
lines changed

10 files changed

+165
-17
lines changed

.github/workflows/ci_pipeline.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
strategy:
4444
fail-fast: false
4545
matrix:
46-
env: ['base', 'fenics', 'mpi4py', 'petsc']
46+
env: ['base', 'fenics', 'mpi4py', 'petsc', 'pytorch']
4747
python: ['3.8', '3.9', '3.10', '3.11', '3.12']
4848

4949
defaults:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Full code: `pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py <https://github.com/Parallel-in-Time/pySDC/blob/master/pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py>`_
2+
3+
.. literalinclude:: ../../../pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py
4+
5+
Results:
6+
7+
.. literalinclude:: ../../../data/step_7_D_out.txt

etc/environment-pytorch.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: pySDC
2+
channels:
3+
- conda-forge
4+
- defaults
5+
dependencies:
6+
- numpy
7+
- scipy>=0.17.1
8+
- sympy>=1.0
9+
- pytorch
10+
- matplotlib>=3.0
11+
- dill

pySDC/playgrounds/ML_initial_guess/ml_heat.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class HeatEquationModel(nn.Module):
7979
def __init__(self, problem, hidden_size=64):
8080
self.input_size = problem.nvars * 3
8181
self.output_size = problem.nvars
82+
self.problem = problem
8283

8384
super().__init__()
8485

@@ -93,8 +94,8 @@ def __init__(self, problem, hidden_size=64):
9394
def forward(self, x, t, dt):
9495
# prepare individual tensors
9596
x = x.float()
96-
_t = torch.ones_like(x) * t
97-
_dt = torch.ones_like(x) * dt
97+
_t = torch.ones(x.shape) * dt
98+
_dt = torch.ones(x.shape) * dt
9899

99100
# Concatenate t and dt with the input x
100101
_x = torch.cat((x, _t, _dt), dim=0)
@@ -104,6 +105,11 @@ def forward(self, x, t, dt):
104105
_x = self.fc2(_x)
105106
return _x
106107

108+
def __call__(self, *args, **kwargs):
109+
me = self.problem.u_init
110+
me[:] = super().__call__(*args, **kwargs)
111+
return me
112+
107113

108114
def train_at_collocation_nodes():
109115
"""

pySDC/playgrounds/ML_initial_guess/tensor.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
from pySDC.core.Errors import DataError
55

66
try:
7-
# TODO : mpi4py cannot be imported before dolfin when using fenics mesh
8-
# see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590
9-
# This should be dealt with at some point
107
from mpi4py import MPI
118
except ImportError:
129
MPI = None
@@ -26,7 +23,7 @@ class Tensor(torch.Tensor):
2623
@staticmethod
2724
def __new__(cls, init, val=0.0, *args, **kwargs):
2825
"""
29-
Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.
26+
Instantiates new datatype. This ensures that even when manipulating data, the result is still a tensor.
3027
3128
Args:
3229
init: either another mesh or a tuple containing the dimensions, the communicator and the dtype
@@ -52,21 +49,38 @@ def __new__(cls, init, val=0.0, *args, **kwargs):
5249
raise NotImplementedError(type(init))
5350
return obj
5451

52+
def __add__(self, *args, **kwargs):
53+
res = super().__add__(*args, **kwargs)
54+
res._comm = self.comm
55+
return res
56+
57+
def __sub__(self, *args, **kwargs):
58+
res = super().__sub__(*args, **kwargs)
59+
res._comm = self.comm
60+
return res
61+
62+
def __lmul__(self, *args, **kwargs):
63+
res = super().__lmul__(*args, **kwargs)
64+
res._comm = self.comm
65+
return res
66+
67+
def __rmul__(self, *args, **kwargs):
68+
res = super().__rmul__(*args, **kwargs)
69+
res._comm = self.comm
70+
return res
71+
72+
def __mul__(self, *args, **kwargs):
73+
res = super().__mul__(*args, **kwargs)
74+
res._comm = self.comm
75+
return res
76+
5577
@property
5678
def comm(self):
5779
"""
5880
Getter for the communicator
5981
"""
6082
return self._comm
6183

62-
def __array_finalize__(self, obj):
63-
"""
64-
Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.
65-
"""
66-
if obj is None:
67-
return
68-
self._comm = getattr(obj, '_comm', None)
69-
7084
def __abs__(self):
7185
"""
7286
Overloading the abs operator

pySDC/projects/TOMS/tests/test_AllenCahn_contracting_circle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import pytest
2-
import dill
3-
import os
42

53
results = {}
64

@@ -21,6 +19,8 @@ def test_AllenCahn_contracting_circle(variant, inexact):
2119
@pytest.mark.base
2220
@pytest.mark.order(2)
2321
def test_show_results():
22+
import dill
23+
import os
2424
from pySDC.projects.TOMS.AllenCahn_contracting_circle import show_results
2525

2626
# dump result

pySDC/tests/test_tutorials/test_step_7.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,10 @@ def test_C_2x2():
120120
for line in p.stderr:
121121
print(line)
122122
assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (p.returncode, num_procs)
123+
124+
125+
@pytest.mark.pytorch
126+
def test_D():
127+
from pySDC.tutorial.step_7.D_pySDC_with_PyTorch import train_at_collocation_nodes
128+
129+
train_at_collocation_nodes()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
import torch.optim as optim
5+
from pySDC.playgrounds.ML_initial_guess.ml_heat import HeatEquationModel, Train_pySDC
6+
from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor
7+
8+
9+
def train_at_collocation_nodes():
10+
"""
11+
For the first proof of concept, we want to train the model specifically to the collocation nodes we use in SDC.
12+
If successful, the initial guess would already be the exact solution and we would need no SDC iterations.
13+
14+
What we find is that we can train the network to predict the solution to one very specific problem rather well.
15+
See the error during training for what happens when we ask the network to solve for exactly what it just trained.
16+
However, if we train for something else, i.e. solving to a different step size in this case, we can only use the
17+
model to predict the solution of what it's been trained for last and it loses the ability to solve for previously
18+
learned things. This is solely because we chose an overly simple model that is unsuitable to the task at hand and
19+
is likely easily solved with a bit of patience. This is just a demonstration of the interface between pySDC and
20+
PyTorch. If you want to do a project with this, feel free to take this as a starting point and do things that
21+
actually do something!
22+
23+
The output shows the training loss during training and, after each of three training sessions is complete, the error
24+
of the prediction with the current state of the network. To demonstrate the forgetfulness, we finally print the
25+
error of all learned predictions after training is complete.
26+
"""
27+
out = ''
28+
errors_mid_training = []
29+
errors_post_training = []
30+
31+
# instantiate the pySDC problem and a model for PyTorch
32+
problem = Heat1DFDTensor()
33+
model = HeatEquationModel(problem)
34+
35+
# setup neural network
36+
lr = 0.001
37+
num_epochs = 250
38+
criterion = nn.MSELoss()
39+
optimizer = optim.Adam(model.parameters(), lr=lr)
40+
41+
# setup initial conditions
42+
t = 0
43+
initial_condition = problem.u_exact(t)
44+
45+
# train the model to predict the solution at certain collocation nodes
46+
collocation_nodes = np.array([0.15505102572168285, 0.6449489742783183, 1]) * 1e-2
47+
for dt in collocation_nodes:
48+
49+
# get target condition from implicit Euler step
50+
target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)
51+
52+
# do the training
53+
for epoch in range(num_epochs):
54+
predicted_state = model(initial_condition, t, dt)
55+
loss = criterion(predicted_state.float(), target_condition.float())
56+
57+
optimizer.zero_grad()
58+
loss.backward()
59+
optimizer.step()
60+
61+
if (epoch + 1) % 50 == 0:
62+
out += f'Training for {dt=:.2e}: Epoch [{epoch+1:4d}/{num_epochs:4d}], Loss: {loss.item():.4e}\n'
63+
64+
# evaluate model to compute error
65+
model_prediction = model(initial_condition, t, dt)
66+
errors_mid_training += [abs(target_condition - model_prediction)]
67+
out += f'Error of prediction at {dt:.2e} during training: {abs(target_condition-model_prediction):.2e}\n'
68+
69+
# compare model and problem
70+
for dt in collocation_nodes:
71+
target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)
72+
model_prediction = model(initial_condition, t, dt)
73+
errors_post_training += [abs(target_condition - model_prediction)]
74+
out += f'Error of prediction at {dt:.2e} after training: {abs(target_condition-model_prediction):.2e}\n'
75+
76+
print(out)
77+
with open('data/step_7_D_out.txt', 'w') as file:
78+
file.write(out)
79+
80+
# test that the training went as expected
81+
assert np.greater([1e-2, 1e-4, 1e-5], errors_mid_training).all(), 'Errors during training are larger than expected'
82+
assert np.greater([1e0, 1e0, 1e-5], errors_post_training).all(), 'Errors after training are larger than expected'
83+
84+
# save the model to use it throughout pySDC
85+
torch.save(model.state_dict(), 'data/heat_equation_model.pth')
86+
87+
88+
if __name__ == '__main__':
89+
train_at_collocation_nodes()

pySDC/tutorial/step_7/README.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,16 @@ Important things to note:
5151
- Below, we run the code 3 times: with 1 and 2 processors in space as well as 4 processors (2 in time and 2 in space). Do not expect scaling due to the CI environment.
5252

5353
.. include:: doc_step_7_C.rst
54+
55+
56+
Part D: pySDC and PyTorch
57+
-------------------------
58+
59+
PyTorch is a library for machine learning. The data structure is called tensor and allows to run on CPUs as well as GPUs in addition to access to various machine learning methods.
60+
Since the potential for use in pySDC is very large, we have started on a datatype that allows to use PyTorch tensors throughout pySDC.
61+
62+
This example trains a network to predict the results of implicit Euler solves for the heat equation. It is too simple to do anything useful, but demonstrates how to use tensors in pySDC and then apply the enormous PyTorch infrastructure.
63+
This is work in progress in very early stages! The tensor datatype is the simplest possible implementation, rather than an efficient one.
64+
If you want to work on this, your input is appreciated!
65+
66+
.. include:: doc_step_7_D.rst

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ markers = [
5959
'cupy: tests for cupy on GPUs',
6060
'libpressio: tests using the libpressio library',
6161
'monodomain: tests the monodomain project, which requires previous compilation of c++ code',
62+
'pytorch: tests for PyTorch related things in pySDC'
6263
]
6364
timeout = 300
6465

0 commit comments

Comments
 (0)