Skip to content

Commit 716d43f

Browse files
AleDinvedario-coscia
authored and committed
Neural Tangent Kernel integration + typo fix (#505)
* NTK weighting + typo fixing * black code formatter + .rst docs --------- Co-authored-by: Dario Coscia <[email protected]>
1 parent 2b09cb9 commit 716d43f

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
NeuralTangentKernelWeighting
2+
=============================
3+
.. currentmodule:: pina.loss.ntk_weighting
4+
5+
.. automodule:: pina.loss.ntk_weighting
6+
7+
.. autoclass:: NeuralTangentKernelWeighting
8+
:members:
9+
:show-inheritance:

pina/loss/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
"PowerLoss",
77
"WeightingInterface",
88
"ScalarWeighting",
9+
"NeuralTangentKernelWeighting",
910
]
1011

1112
from .loss_interface import LossInterface
1213
from .power_loss import PowerLoss
1314
from .lp_loss import LpLoss
1415
from .weighting_interface import WeightingInterface
1516
from .scalar_weighting import ScalarWeighting
17+
from .ntk_weighting import NeuralTangentKernelWeighting

pina/loss/ntk_weighting.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Module for Neural Tangent Kernel Class"""
2+
3+
import torch
4+
from torch.nn import Module
5+
from .weighting_interface import WeightingInterface
6+
from ..utils import check_consistency
7+
8+
9+
class NeuralTangentKernelWeighting(WeightingInterface):
    """
    A neural tangent kernel scheme for weighting different losses to
    boost the convergence.

    On every call to :meth:`aggregate` the L2 norm of the gradient of each
    condition's loss with respect to the model parameters is computed, and
    the per-condition weight is updated as a moving average::

        w[c] <- alpha * w[c] + (1 - alpha) * norm[c] / sum(norm.values())

    .. seealso::

        **Original reference**: Wang, Sifan, Xinling Yu, and
        Paris Perdikaris. *When and why PINNs fail to train:
        A neural tangent kernel perspective*. Journal of
        Computational Physics 449 (2022): 110768.
        DOI: `10.1016/j.jcp.2021.110768 <https://doi.org/10.1016/j.jcp.2021.110768>`_.

    """

    def __init__(self, model, alpha=0.5):
        """
        Initialization of the :class:`NeuralTangentKernelWeighting` class.

        :param torch.nn.Module model: The neural network model whose
            parameters are used to measure the gradient norms.
        :param float alpha: The moving-average coefficient applied to the
            previous weights. It must lie in ``[0, 1]``. Default is ``0.5``.
        :raises ValueError: If ``alpha`` is outside ``[0, 1]`` (NaN included).
        """
        super().__init__()
        # Type validation is delegated to the package helper.
        check_consistency(alpha, float)
        check_consistency(model, Module)
        # Chained comparison also rejects NaN, which would otherwise
        # silently poison every weight.
        if not 0 <= alpha <= 1:
            raise ValueError("alpha should be a value between 0 and 1")
        self.alpha = alpha
        self.model = model
        # Running weights, keyed by condition name; filled lazily on the
        # first call to ``aggregate``.
        self.weights = {}
        # Weight assumed for a condition that has not been seen yet.
        self.default_value_weights = 1

    def aggregate(self, losses):
        """
        Weights the losses according to the Neural Tangent Kernel
        algorithm.

        :param dict(torch.Tensor) losses: The dictionary of losses, keyed
            by condition name.
        :return: The losses aggregation. It should be a scalar Tensor.
        :rtype: torch.Tensor
        """
        # Only trainable parameters can receive a gradient.
        params = [p for p in self.model.parameters() if p.requires_grad]

        losses_norm = {}
        for condition, loss in losses.items():
            # ``torch.autograd.grad`` returns this loss's gradient without
            # touching ``param.grad``. Calling ``loss.backward()`` here
            # would accumulate into ``param.grad``, so each norm would
            # include the gradients of the previously processed conditions
            # (and any stale gradient already stored), and the buffers
            # would stay polluted for the caller's own backward pass.
            grads = torch.autograd.grad(
                loss,
                params,
                retain_graph=True,  # the graph is reused by later passes
                allow_unused=True,  # a loss may not involve every parameter
            )
            flat_grads = [g.reshape(-1) for g in grads if g is not None]
            if flat_grads:
                losses_norm[condition] = torch.norm(torch.cat(flat_grads))
            else:
                # No parameter contributed to this loss.
                losses_norm[condition] = loss.new_zeros(())

        total_norm = sum(losses_norm.values())
        self.weights = {
            condition: self.alpha
            * self.weights.get(condition, self.default_value_weights)
            + (1 - self.alpha) * losses_norm[condition] / total_norm
            for condition in losses
        }
        return sum(
            self.weights[condition] * loss
            for condition, loss in losses.items()
        )
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import pytest
2+
from pina import Trainer
3+
from pina.solver import PINN
4+
from pina.model import FeedForward
5+
from pina.problem.zoo import Poisson2DSquareProblem
6+
from pina.loss import NeuralTangentKernelWeighting
7+
8+
# Problem instance shared by the tests below: its input/output variable
# counts size the FeedForward models, and the training test discretises it.
problem = Poisson2DSquareProblem()
9+
# Keys of the problem's ``conditions`` mapping.
# NOTE(review): not referenced by any test visible here - confirm it is needed.
condition_names = problem.conditions.keys()
10+
11+
12+
@pytest.mark.parametrize(
    "model,alpha",
    [
        (
            FeedForward(
                len(problem.input_variables), len(problem.output_variables)
            ),
            0.5,
        )
    ],
)
def test_constructor(model, alpha):
    """A valid model/alpha pair builds the weighting and is stored on it."""
    weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
    # Check the stored state, not only that construction did not raise.
    assert weighting.alpha == alpha
    assert weighting.model is model
25+
26+
27+
@pytest.mark.parametrize("model", [0.5])
def test_wrong_constructor1(model):
    """A ``model`` argument that is not a module must raise ``ValueError``."""
    expected_failure = pytest.raises(ValueError)
    with expected_failure:
        NeuralTangentKernelWeighting(model)
31+
32+
33+
@pytest.mark.parametrize(
    "model,alpha",
    [
        # Above the admissible range.
        (
            FeedForward(
                len(problem.input_variables), len(problem.output_variables)
            ),
            1.2,
        ),
        # Below the admissible range: exercises the ``alpha < 0`` branch.
        (
            FeedForward(
                len(problem.input_variables), len(problem.output_variables)
            ),
            -0.1,
        ),
    ],
)
def test_wrong_constructor2(model, alpha):
    """An ``alpha`` outside ``[0, 1]`` must raise ``ValueError``."""
    with pytest.raises(ValueError):
        NeuralTangentKernelWeighting(model, alpha)
47+
48+
49+
@pytest.mark.parametrize(
    "model,alpha",
    [
        (
            FeedForward(
                len(problem.input_variables),
                len(problem.output_variables),
            ),
            0.5,
        ),
    ],
)
def test_train_aggregation(model, alpha):
    """Smoke test: a PINN using the NTK weighting trains for a few epochs."""
    ntk_weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
    # Sample the domain so the solver has points to train on.
    problem.discretise_domain(50)
    pinn = PINN(problem=problem, model=model, weighting=ntk_weighting)
    Trainer(solver=pinn, max_epochs=5, accelerator="cpu").train()

0 commit comments

Comments
 (0)