Commit bacd7e2

add mutual solver-weighting link
1 parent 973d0c0 commit bacd7e2

6 files changed (+61 -75 lines changed)

pina/loss/ntk_weighting.py

Lines changed: 20 additions & 18 deletions
@@ -1,7 +1,6 @@
 """Module for Neural Tangent Kernel Class"""
 
 import torch
-from torch.nn import Module
 from .weighting_interface import WeightingInterface
 from ..utils import check_consistency
 
@@ -21,43 +20,45 @@ class NeuralTangentKernelWeighting(WeightingInterface):
 
     """
 
-    def __init__(self, model, alpha=0.5):
+    def __init__(self, alpha=0.5):
         """
         Initialization of the :class:`NeuralTangentKernelWeighting` class.
 
-        :param torch.nn.Module model: The neural network model.
         :param float alpha: The alpha parameter.
-
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
-
         super().__init__()
+
+        # Check consistency
         check_consistency(alpha, float)
-        check_consistency(model, Module)
         if alpha < 0 or alpha > 1:
             raise ValueError("alpha should be a value between 0 and 1")
+
+        # Initialize parameters
         self.alpha = alpha
-        self.model = model
         self.weights = {}
-        self.default_value_weights = 1
+        self.default_value_weights = 1.0
 
     def aggregate(self, losses):
         """
-        Weight the losses according to the Neural Tangent Kernel
-        algorithm.
+        Weight the losses according to the Neural Tangent Kernel algorithm.
 
         :param dict(torch.Tensor) input: The dictionary of losses.
-        :return: The losses aggregation. It should be a scalar Tensor.
+        :return: The aggregation of the losses. It should be a scalar Tensor.
        :rtype: torch.Tensor
        """
+        # Define a dictionary to store the norms of the gradients
        losses_norm = {}
-        for condition in losses:
-            losses[condition].backward(retain_graph=True)
-            grads = []
-            for param in self.model.parameters():
-                grads.append(param.grad.view(-1))
-            grads = torch.cat(grads)
-            losses_norm[condition] = torch.norm(grads)
+
+        # Compute the gradient norms for each loss component
+        for condition, loss in losses.items():
+            loss.backward(retain_graph=True)
+            grads = torch.cat(
+                [p.grad.flatten() for p in self.solver.model.parameters()]
+            )
+            losses_norm[condition] = grads.norm()
+
+        # Update the weights
        self.weights = {
            condition: self.alpha
            * self.weights.get(condition, self.default_value_weights)
@@ -66,6 +67,7 @@ def aggregate(self, losses):
            / sum(losses_norm.values())
            for condition in losses
        }
+
        return sum(
            self.weights[condition] * loss for condition, loss in losses.items()
        )
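
Taken together, the change means the weighting no longer owns the model: aggregate() reads it through the new solver link (self.solver.model), and the solver registers itself on the weighting during construction (see pina/solver/solver.py below). A minimal usage sketch, assembled from the updated tests in this commit:

from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.loss import NeuralTangentKernelWeighting
from pina.problem.zoo import Poisson2DSquareProblem

# Problem, model and weighting are created independently of one another
problem = Poisson2DSquareProblem()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
weighting = NeuralTangentKernelWeighting(alpha=0.5)  # no model argument anymore

# The solver links itself to the weighting, so aggregate() can reach the model
solver = PINN(problem=problem, model=model, weighting=weighting)
assert weighting.solver is solver

trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train()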

pina/loss/scalar_weighting.py

Lines changed: 5 additions & 1 deletion
@@ -37,12 +37,16 @@ def __init__(self, weights):
         :type weights: float | int | dict
         """
         super().__init__()
+
+        # Check consistency
         check_consistency([weights], (float, dict, int))
+
+        # Weights initialization
         if isinstance(weights, (float, int)):
             self.default_value_weights = weights
             self.weights = {}
         else:
-            self.default_value_weights = 1
+            self.default_value_weights = 1.0
             self.weights = weights
 
     def aggregate(self, losses):

pina/loss/weighting_interface.py

Lines changed: 11 additions & 1 deletion
@@ -13,7 +13,7 @@ def __init__(self):
         """
         Initialization of the :class:`WeightingInterface` class.
         """
-        self.condition_names = None
+        self._solver = None
 
     @abstractmethod
     def aggregate(self, losses):
@@ -22,3 +22,13 @@ def aggregate(self, losses):
 
         :param dict losses: The dictionary of losses.
         """
+
+    @property
+    def solver(self):
+        """
+        The solver employing this weighting schema.
+
+        :return: The solver.
+        :rtype: :class:`~pina.solver.SolverInterface`
+        """
+        return self._solver
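
For custom schemes, the new solver property is the hook that replaces the old condition_names attribute. A hypothetical subclass (illustrative only, not part of this commit) showing how a weighting can now reach the model through the solver it is attached to:

from pina.loss.weighting_interface import WeightingInterface


class UniformWeighting(WeightingInterface):
    """Illustrative scheme: equal weights, with access to the linked solver."""

    def aggregate(self, losses):
        # self.solver is populated by the solver constructor (see below),
        # so the weighting does not need to store the model itself
        n_params = sum(p.numel() for p in self.solver.model.parameters())
        print(f"Aggregating {len(losses)} losses; model has {n_params} parameters")
        return sum(losses.values())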

pina/solver/solver.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def __init__(self, problem, weighting, use_lt):
             weighting = _NoWeighting()
         check_consistency(weighting, WeightingInterface)
         self._pina_weighting = weighting
-        weighting.condition_names = list(self._pina_problem.conditions.keys())
+        weighting._solver = self
 
         # check consistency use_lt
         check_consistency(use_lt, bool)
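
In effect, this replaces the old one-way hand-off of condition_names: instead of copying the problem's condition names onto the weighting, the solver now registers itself (weighting._solver = self), and schemes such as the NTK weighting above reach the model through self.solver.model, completing the mutual solver-weighting link named in the commit message.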

tests/test_weighting/test_ntk_weighting.py

Lines changed: 15 additions & 47 deletions
@@ -2,64 +2,32 @@
 from pina import Trainer
 from pina.solver import PINN
 from pina.model import FeedForward
-from pina.problem.zoo import Poisson2DSquareProblem
 from pina.loss import NeuralTangentKernelWeighting
+from pina.problem.zoo import Poisson2DSquareProblem
 
-problem = Poisson2DSquareProblem()
-condition_names = problem.conditions.keys()
 
+# Initialize problem and model
+problem = Poisson2DSquareProblem()
+problem.discretise_domain(10)
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 
-@pytest.mark.parametrize(
-    "model,alpha",
-    [
-        (
-            FeedForward(
-                len(problem.input_variables), len(problem.output_variables)
-            ),
-            0.5,
-        )
-    ],
-)
-def test_constructor(model, alpha):
-    NeuralTangentKernelWeighting(model=model, alpha=alpha)
 
+@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
+def test_constructor(alpha):
+    NeuralTangentKernelWeighting(alpha=alpha)
 
-@pytest.mark.parametrize("model", [0.5])
-def test_wrong_constructor1(model):
+    # Should fail if alpha is not >= 0
     with pytest.raises(ValueError):
-        NeuralTangentKernelWeighting(model)
-
+        NeuralTangentKernelWeighting(alpha=-0.1)
 
-@pytest.mark.parametrize(
-    "model,alpha",
-    [
-        (
-            FeedForward(
-                len(problem.input_variables), len(problem.output_variables)
-            ),
-            1.2,
-        )
-    ],
-)
-def test_wrong_constructor2(model, alpha):
+    # Should fail if alpha is not <= 1
     with pytest.raises(ValueError):
-        NeuralTangentKernelWeighting(model, alpha)
+        NeuralTangentKernelWeighting(alpha=1.1)
 
 
-@pytest.mark.parametrize(
-    "model,alpha",
-    [
-        (
-            FeedForward(
-                len(problem.input_variables), len(problem.output_variables)
-            ),
-            0.5,
-        )
-    ],
-)
-def test_train_aggregation(model, alpha):
-    weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
-    problem.discretise_domain(50)
+@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
+def test_train_aggregation(alpha):
+    weighting = NeuralTangentKernelWeighting(alpha=alpha)
     solver = PINN(problem=problem, model=model, weighting=weighting)
     trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
     trainer.train()

tests/test_weighting/test_standard_weighting.py renamed to tests/test_weighting/test_scalar_weighting.py

Lines changed: 9 additions & 7 deletions
@@ -1,16 +1,17 @@
 import pytest
 import torch
-
 from pina import Trainer
 from pina.solver import PINN
 from pina.model import FeedForward
-from pina.problem.zoo import Poisson2DSquareProblem
 from pina.loss import ScalarWeighting
+from pina.problem.zoo import Poisson2DSquareProblem
 
+
+# Initialize problem and model
 problem = Poisson2DSquareProblem()
+problem.discretise_domain(50)
 model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 condition_names = problem.conditions.keys()
-print(problem.conditions.keys())
 
 
 @pytest.mark.parametrize(
@@ -19,11 +20,13 @@
 def test_constructor(weights):
     ScalarWeighting(weights=weights)
 
+    # Should fail if weights are not a scalar
+    with pytest.raises(ValueError):
+        ScalarWeighting(weights="invalid")
 
-@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
-def test_wrong_constructor(weights):
+    # Should fail if weights are not a dictionary
     with pytest.raises(ValueError):
-        ScalarWeighting(weights=weights)
+        ScalarWeighting(weights=[1, 2, 3])
 
 
 @pytest.mark.parametrize(
@@ -45,7 +48,6 @@ def test_aggregate(weights):
 )
 def test_train_aggregation(weights):
     weighting = ScalarWeighting(weights=weights)
-    problem.discretise_domain(50)
     solver = PINN(problem=problem, model=model, weighting=weighting)
     trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
     trainer.train()
