
Commit eaeab58 (1 parent: 662f297)

black code formatter + .rst docs

File tree: 4 files changed, 45 additions & 35 deletions


4 files changed

+45
-35
lines changed
Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+NeuralTangentKernelWeighting
+===============
+.. currentmodule:: pina.loss
+
+.. automodule:: pina.loss
+
+.. autoclass:: WeightingInterface
+   :members:
+   :show-inheritance:

pina/loss/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -8,12 +8,12 @@
     "PowerLoss",
     "WeightingInterface",
     "ScalarWeighting",
-    "Neural_Tangent_Kernel"
+    "NeuralTangetKernelWeighting",
 ]
 
 from .loss_interface import LossInterface
 from .power_loss import PowerLoss
 from .lp_loss import LpLoss
 from .weighting_interface import WeightingInterface
 from .scalar_weighting import ScalarWeighting
-from .ntk_weighting import Neural_Tangent_Kernel
+from .ntk_weighting import NeuralTangetKernelWeighting
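
With the export above, the renamed class should be importable straight from pina.loss. A minimal sketch of that import (the placeholder torch.nn.Linear model is an assumption for illustration, not part of the commit; the class name follows the committed spelling):

import torch
from pina.loss import NeuralTangetKernelWeighting

# any torch.nn.Module satisfies the constructor's model check;
# a tiny Linear layer is used here only as a stand-in
model = torch.nn.Linear(2, 1)
weighting = NeuralTangetKernelWeighting(model=model, alpha=0.5)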

pina/loss/ntk_weighting.py

Lines changed: 23 additions & 25 deletions
@@ -1,61 +1,59 @@
-""" Module for Neural Tangent Kernel Class """
+"""Module for Neural Tangent Kernel Class"""
 
+import torch
 from torch.nn import Module
-from torch import norm,cat
 from .weighting_interface import WeightingInterface
 from ..operator import grad
 from ..utils import check_consistency
 
-class Neural_Tangent_Kernel(WeightingInterface):
+
+class NeuralTangetKernelWeighting(WeightingInterface):
     """
     TODO
     """
+
     def __init__(self, model, alpha=0.5):
         super().__init__()
         check_consistency(alpha, float)
-
-        if alpha<0 or alpha > 1:
+        check_consistency(model, Module)
+        if alpha < 0 or alpha > 1:
             raise ValueError("alpha should be a value between 0 and 1")
-        if not isinstance(model, Module):
-            raise TypeError("Please pass a valid torch.nn.Module model")
         self.alpha = alpha
         self.model = model
         self.weights = {}
         self.default_value_weights = 1
 
     def aggregate(self, losses):
         """
-        Weights the losses according to the Neural Tangent Kernel 
+        Weights the losses according to the Neural Tangent Kernel
 
         :param dict(torch.Tensor) input: The dictionary of losses.
         :param alpha(float) input: The parameter alpha that regulates the moving average
-        between old and new weights.
+            between old and new weights.
         :return: The losses aggregation. It should be a scalar Tensor.
         :rtype: torch.Tensor
 
-        Reference: TODO
-        Wang, S., Sankaran, S., Wang, H., & Perdikaris, P. (2023).
-        An expert's guide to training physics-informed neural networks.
+        Reference:
+            Wang, S., Sankaran, S., Wang, H., & Perdikaris, P. (2023).
+            An expert's guide to training physics-informed neural networks.
             arXiv preprint arXiv:2308.08468.
-            """
+        """
         losses_norm = {}
         for condition in losses:
             losses[condition].backward(retain_graph=True)
-            grads=[]
+            grads = []
             for param in self.model.parameters():
                 grads.append(param.grad.view(-1))
-            grads = cat(grads)
-
-            print(grads)
-            losses_norm[condition]=norm(grads)
+            grads = torch.cat(grads)
+            losses_norm[condition] = torch.norm(grads)
         self.weights = {
-            condition:
-            self.alpha * self.weights.get(
-                condition, self.default_value_weights) +
-            (1- self.alpha)*losses_norm[condition]/sum(losses_norm.values())
+            condition: self.alpha
+            * self.weights.get(condition, self.default_value_weights)
+            + (1 - self.alpha)
+            * losses_norm[condition]
+            / torch.sum(losses_norm.values())
             for condition in losses
         }
         return sum(
-            self.weights.get(condition, self.default_value_weights) * loss for
-            condition, loss in losses.items()
-        )
+            self.weights[condition] * loss for condition, loss in losses.items()
+        )
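
The aggregate method committed above implements a gradient-norm moving average: for each condition c, the weight is updated as w_c <- alpha * w_c + (1 - alpha) * ||grad L_c|| / sum_k ||grad L_k||, and the returned value is sum_c w_c * L_c. Below is a self-contained sketch of that same update rule in plain PyTorch, not the committed class itself; the toy model and the two condition names are assumptions made only for illustration.

import torch

# stand-in model and two hypothetical loss "conditions"
model = torch.nn.Linear(2, 1)
x = torch.rand(8, 2)
losses = {
    "physics": model(x).pow(2).mean(),
    "boundary": (model(x) - 1).pow(2).mean(),
}

alpha = 0.5
weights = {c: 1.0 for c in losses}  # default weight of 1, as in the commit

# per-condition norm of the gradient over all model parameters
norms = {}
for c, loss in losses.items():
    grads = torch.autograd.grad(loss, list(model.parameters()), retain_graph=True)
    norms[c] = torch.cat([g.reshape(-1) for g in grads]).norm()

# moving-average update of the per-condition weights
total = sum(norms.values())
weights = {c: alpha * weights[c] + (1 - alpha) * norms[c] / total for c in losses}

# weighted aggregation of the losses (a scalar tensor)
aggregated = sum(weights[c] * losses[c] for c in losses)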

tests/test_weighting/test_ntk_weighting.py

Lines changed: 11 additions & 8 deletions
@@ -3,28 +3,31 @@
 from pina.solver import PINN
 from pina.model import FeedForward
 from pina.problem.zoo import Poisson2DSquareProblem
-from pina.loss import Neural_Tangent_Kernel
+from pina.loss import NeuralTangetKernelWeighting
 
 problem = Poisson2DSquareProblem()
 model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 condition_names = problem.conditions.keys()
-#print(problem.conditions.keys())
+# print(problem.conditions.keys())
+
+
+def test_constructor(model, alpha):
+    NeuralTangetKernelWeighting(model=model, alpha=alpha)
 
-def test_constructor(model,alpha):
-    Neural_Tangent_Kernel(model=model,alpha=alpha)
 
 def test_wrong_constructor1(alpha):
     with pytest.raises(ValueError):
-        Neural_Tangent_Kernel(model,alpha)
+        NeuralTangetKernelWeighting(model, alpha)
+
 
 def test_wrong_constructor2(model):
     with pytest.raises(TypeError):
-        Neural_Tangent_Kernel(model)
+        NeuralTangetKernelWeighting(model)
+
 
 def test_train_aggregation(model):
-    weighting = Neural_Tangent_Kernel(model=model,alpha = 0.5)
+    weighting = NeuralTangetKernelWeighting(model=model, alpha=0.5)
     problem.discretise_domain(50)
     solver = PINN(problem=problem, model=model, weighting=weighting)
     trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
     trainer.train()
-
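
The constructor tests above take model and alpha as function arguments; one way to supply those values, sketched here with hypothetical pytest parametrization that is not part of the commit, would be:

import pytest
import torch
from pina.loss import NeuralTangetKernelWeighting


@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_constructor_parametrized(alpha):
    # standalone variant of test_constructor above, with an explicit model
    model = torch.nn.Linear(2, 1)
    NeuralTangetKernelWeighting(model=model, alpha=alpha)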

0 commit comments
