Commit ef35424 (parent 96402ba)
add linear weighting

File tree: 7 files changed (+176 / -9 lines)

docs/source/_rst/_code.rst

Lines changed: 2 additions & 2 deletions
@@ -253,7 +253,6 @@ Callbacks
    Optimizer callback <callback/optimizer_callback.rst>
    R3 Refinment callback <callback/refinement/r3_refinement.rst>
    Refinment Interface callback <callback/refinement/refinement_interface.rst>
-   Weighting callback <callback/linear_weight_update_callback.rst>
 
 Losses and Weightings
 ---------------------
@@ -267,4 +266,5 @@ Losses and Weightings
    WeightingInterface <loss/weighting_interface.rst>
    ScalarWeighting <loss/scalar_weighting.rst>
    NeuralTangentKernelWeighting <loss/ntk_weighting.rst>
-   SelfAdaptiveWeighting <loss/self_adaptive_weighting.rst>
+   SelfAdaptiveWeighting <loss/self_adaptive_weighting.rst>
+   LinearWeighting <loss/linear_weighting.rst>

(new file)

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+LinearWeighting
+=============================
+.. currentmodule:: pina.loss.linear_weighting
+
+.. automodule:: pina.loss.linear_weighting
+
+.. autoclass:: LinearWeighting
+   :members:
+   :show-inheritance:

pina/loss/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
     "ScalarWeighting",
     "NeuralTangentKernelWeighting",
     "SelfAdaptiveWeighting",
+    "LinearWeighting",
 ]
 
 from .loss_interface import LossInterface
@@ -17,3 +18,4 @@
 from .scalar_weighting import ScalarWeighting
 from .ntk_weighting import NeuralTangentKernelWeighting
 from .self_adaptive_weighting import SelfAdaptiveWeighting
+from .linear_weighting import LinearWeighting

pina/loss/linear_weighting.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+"""Module for the LinearWeighting class."""
+
+from ..loss import WeightingInterface
+from ..utils import check_consistency, check_positive_integer
+
+
+class LinearWeighting(WeightingInterface):
+    """
+    A weighting scheme that linearly scales weights from initial values to final
+    values over a specified number of epochs.
+    """
+
+    def __init__(self, initial_weights, final_weights, target_epoch):
+        """
+        :param dict initial_weights: The weights to be assigned to each loss
+            term at the beginning of training. The keys are the conditions and
+            the values are the corresponding weights. If a condition is not
+            present in the dictionary, the default value (1) is used.
+        :param dict final_weights: The weights to be assigned to each loss term
+            once the target epoch is reached. The keys are the conditions and
+            the values are the corresponding weights. If a condition is not
+            present in the dictionary, the default value (1) is used.
+        :param int target_epoch: The epoch at which the weights reach their
+            final values.
+        :raises ValueError: If the keys of the two dictionaries are not
+            consistent.
+        """
+        super().__init__(update_every_n_epochs=1, aggregator="sum")
+
+        # Check consistency
+        check_consistency([initial_weights, final_weights], dict)
+        check_positive_integer(value=target_epoch, strict=True)
+
+        # Check that the keys of the two dictionaries are the same
+        if initial_weights.keys() != final_weights.keys():
+            raise ValueError(
+                "The keys of the initial_weights and final_weights "
+                "dictionaries must be the same."
+            )
+
+        # Initialization
+        self.initial_weights = initial_weights
+        self.final_weights = final_weights
+        self.target_epoch = target_epoch
+
+    def weights_update(self, losses):
+        """
+        Update the weighting scheme based on the given losses.
+
+        :param dict losses: The dictionary of losses.
+        :return: The updated weights.
+        :rtype: dict
+        """
+        return {
+            condition: self.last_saved_weights().get(
+                condition, self.initial_weights.get(condition, 1)
+            )
+            + (
+                self.final_weights.get(condition, 1)
+                - self.initial_weights.get(condition, 1)
+            )
+            / (self.target_epoch)
+            for condition in losses.keys()
+        }
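
A minimal usage sketch of the new scheme, assembled from the API exercised by the test file at the end of this commit; the problem, network size, weight values, and epoch counts are illustrative choices, not part of this change. Each weights_update call adds (final - initial) / target_epoch to the previously saved weight, so the weights move from their initial to their final values in target_epoch equal steps.

from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.loss import LinearWeighting
from pina.problem.zoo import Poisson2DSquareProblem

# Toy problem and model, sized only for the example
problem = Poisson2DSquareProblem()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))

# Ramp every condition's weight from 3 down to 1 over the first 5 epochs;
# the scheme updates once per epoch (update_every_n_epochs=1).
weighting = LinearWeighting(
    initial_weights={cond: 3 for cond in problem.conditions.keys()},
    final_weights={cond: 1 for cond in problem.conditions.keys()},
    target_epoch=5,
)
solver = PINN(problem=problem, model=model, weighting=weighting)
Trainer(solver=solver, max_epochs=5, accelerator="cpu").train()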

pina/loss/ntk_weighting.py

Lines changed: 2 additions & 3 deletions
@@ -61,11 +61,10 @@ def weights_update(self, losses):
             losses_norm[condition] = grads.norm()
 
         # Update the weights
-        self.weights = {
-            condition: self.alpha * self.weights.get(condition, 1)
+        return {
+            condition: self.alpha * self.last_saved_weights().get(condition, 1)
             + (1 - self.alpha)
             * losses_norm[condition]
             / sum(losses_norm.values())
             for condition in losses
         }
-        return self.weights
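
For readers skimming the hunk above, here is a standalone sketch of the moving-average rule it now returns; the helper name ntk_update and the alpha value are illustrative, not PINA API.

def ntk_update(old_weights, losses_norm, alpha=0.5):
    # new_w = alpha * old_w + (1 - alpha) * ||grad_i|| / sum_j ||grad_j||
    total = sum(losses_norm.values())
    return {
        cond: alpha * old_weights.get(cond, 1)
        + (1 - alpha) * losses_norm[cond] / total
        for cond in losses_norm
    }

# Two conditions with gradient norms 3 and 1, both starting from weight 1:
print(ntk_update({"physics": 1, "boundary": 1}, {"physics": 3.0, "boundary": 1.0}))
# {'physics': 0.875, 'boundary': 0.625}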

pina/loss/scalar_weighting.py

Lines changed: 2 additions & 4 deletions
@@ -17,7 +17,7 @@ def __init__(self, weights):
             If a single scalar value is provided, it is assigned to all loss
             terms. If a dictionary is provided, the keys are the conditions and
             the values are the weights. If a condition is not present in the
-            dictionary, the default value is used.
+            dictionary, the default value (1) is used.
         :type weights: float | int | dict
         """
         super().__init__(update_every_n_epochs=1, aggregator="sum")
@@ -29,11 +29,9 @@
         if isinstance(weights, dict):
             self.values = weights
             self.default_value_weights = 1
-        elif isinstance(weights, (float, int)):
+        else:
             self.values = {}
             self.default_value_weights = weights
-        else:
-            raise ValueError
 
     def weights_update(self, losses):
         """

(new file)

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import math
+import pytest
+from pina import Trainer
+from pina.solver import PINN
+from pina.model import FeedForward
+from pina.loss import LinearWeighting
+from pina.problem.zoo import Poisson2DSquareProblem
+
+
+# Initialize problem and model
+problem = Poisson2DSquareProblem()
+problem.discretise_domain(10)
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+# Weights for testing
+init_weight_1 = {cond: 3 for cond in problem.conditions.keys()}
+init_weight_2 = {cond: 4 for cond in problem.conditions.keys()}
+final_weight_1 = {cond: 1 for cond in problem.conditions.keys()}
+final_weight_2 = {cond: 5 for cond in problem.conditions.keys()}
+
+
+@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2])
+@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2])
+@pytest.mark.parametrize("target_epoch", [5, 10])
+def test_constructor(initial_weights, final_weights, target_epoch):
+    LinearWeighting(
+        initial_weights=initial_weights,
+        final_weights=final_weights,
+        target_epoch=target_epoch,
+    )
+
+    # Should fail if initial_weights is not a dictionary
+    with pytest.raises(ValueError):
+        LinearWeighting(
+            initial_weights=[1, 1, 1],
+            final_weights=final_weights,
+            target_epoch=target_epoch,
+        )
+
+    # Should fail if final_weights is not a dictionary
+    with pytest.raises(ValueError):
+        LinearWeighting(
+            initial_weights=initial_weights,
+            final_weights=[1, 1, 1],
+            target_epoch=target_epoch,
+        )
+
+    # Should fail if target_epoch is not an integer
+    with pytest.raises(AssertionError):
+        LinearWeighting(
+            initial_weights=initial_weights,
+            final_weights=final_weights,
+            target_epoch=1.5,
+        )
+
+    # Should fail if target_epoch is not positive
+    with pytest.raises(AssertionError):
+        LinearWeighting(
+            initial_weights=initial_weights,
+            final_weights=final_weights,
+            target_epoch=0,
+        )
+
+    # Should fail if dictionary keys do not match
+    with pytest.raises(ValueError):
+        LinearWeighting(
+            initial_weights={list(initial_weights.keys())[0]: 1},
+            final_weights=final_weights,
+            target_epoch=target_epoch,
+        )
+
+
+@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2])
+@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2])
+@pytest.mark.parametrize("target_epoch", [5, 10])
+def test_train_aggregation(initial_weights, final_weights, target_epoch):
+    weighting = LinearWeighting(
+        initial_weights=initial_weights,
+        final_weights=final_weights,
+        target_epoch=target_epoch,
+    )
+    solver = PINN(problem=problem, model=model, weighting=weighting)
+    trainer = Trainer(solver=solver, max_epochs=target_epoch, accelerator="cpu")
+    trainer.train()
+
+    # Check that weights are updated correctly
+    assert all(
+        math.isclose(
+            weighting.last_saved_weights()[cond],
+            final_weights[cond],
+            rel_tol=1e-5,
+            abs_tol=1e-8,
+        )
+        for cond in final_weights.keys()
+    )
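
Why the closing assertion holds, as worked arithmetic for one parametrized combination (initial weight 3, final weight 1, target_epoch 5), assuming one weights_update call per epoch as set in the constructor:

import math

initial, final, target_epoch = 3, 1, 5
w = initial
for _ in range(target_epoch):
    w += (final - initial) / target_epoch  # each epoch adds (1 - 3) / 5 = -0.4
assert math.isclose(w, final)              # after 5 epochs the weight reaches 1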
