Skip to content

Commit 9b08a1c

Browse files
add linear weight update callback (#474)
1 parent a251459 commit 9b08a1c

File tree

3 files changed

+251
-0
lines changed

3 files changed

+251
-0
lines changed

pina/callback/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
"R3Refinement",
44
"MetricTracker",
55
"PINAProgressBar",
6+
"LinearWeightUpdate",
67
]
78

89
from .optimizer_callback import SwitchOptimizer
910
from .adaptive_refinement_callback import R3Refinement
1011
from .processing_callback import MetricTracker, PINAProgressBar
12+
from .linear_weight_update_callback import LinearWeightUpdate
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""PINA Callbacks Implementations"""
2+
3+
import warnings
4+
from lightning.pytorch.callbacks import Callback
5+
from ..utils import check_consistency
6+
from ..loss import ScalarWeighting
7+
8+
9+
class LinearWeightUpdate(Callback):
10+
"""
11+
Callback to linearly adjust the weight of a condition from an
12+
initial value to a target value over a specified number of epochs.
13+
"""
14+
15+
def __init__(
16+
self, target_epoch, condition_name, initial_value, target_value
17+
):
18+
"""
19+
Callback initialization.
20+
21+
:param int target_epoch: The epoch at which the weight of the condition
22+
should reach the target value.
23+
:param str condition_name: The name of the condition whose weight
24+
should be adjusted.
25+
:param float initial_value: The initial value of the weight.
26+
:param float target_value: The target value of the weight.
27+
"""
28+
super().__init__()
29+
self.target_epoch = target_epoch
30+
self.condition_name = condition_name
31+
self.initial_value = initial_value
32+
self.target_value = target_value
33+
34+
# Check consistency
35+
check_consistency(self.target_epoch, int, subclass=False)
36+
check_consistency(self.condition_name, str, subclass=False)
37+
check_consistency(self.initial_value, (float, int), subclass=False)
38+
check_consistency(self.target_value, (float, int), subclass=False)
39+
40+
def on_train_start(self, trainer, solver):
41+
"""
42+
Initialize the weight of the condition to the specified `initial_value`.
43+
44+
:param Trainer trainer: a pina:class:`Trainer` instance.
45+
:param SolverInterface solver: a pina:class:`SolverInterface` instance.
46+
"""
47+
# Check that the target epoch is valid
48+
if not 0 < self.target_epoch <= trainer.max_epochs:
49+
raise ValueError(
50+
"`target_epoch` must be greater than 0"
51+
" and less than or equal to `max_epochs`."
52+
)
53+
54+
# Check that the condition is a problem condition
55+
if self.condition_name not in solver.problem.conditions:
56+
raise ValueError(
57+
f"`{self.condition_name}` must be a problem condition."
58+
)
59+
60+
# Check that the initial value is not equal to the target value
61+
if self.initial_value == self.target_value:
62+
warnings.warn(
63+
"`initial_value` is equal to `target_value`. "
64+
"No effective adjustment will be performed.",
65+
UserWarning,
66+
)
67+
68+
# Check that the weighting schema is ScalarWeighting
69+
if not isinstance(solver.weighting, ScalarWeighting):
70+
raise ValueError("The weighting schema must be ScalarWeighting.")
71+
72+
# Initialize the weight of the condition
73+
solver.weighting.weights[self.condition_name] = self.initial_value
74+
75+
def on_train_epoch_start(self, trainer, solver):
76+
"""
77+
Adjust at each epoch the weight of the condition.
78+
79+
:param Trainer trainer: a pina:class:`Trainer` instance.
80+
:param SolverInterface solver: a pina:class:`SolverInterface` instance.
81+
"""
82+
if 0 < trainer.current_epoch <= self.target_epoch:
83+
solver.weighting.weights[self.condition_name] += (
84+
self.target_value - self.initial_value
85+
) / (self.target_epoch - 1)
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import pytest
2+
import math
3+
from pina.solver import PINN
4+
from pina.loss import ScalarWeighting
5+
from pina.trainer import Trainer
6+
from pina.model import FeedForward
7+
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
8+
from pina.callback import LinearWeightUpdate
9+
10+
11+
# Define the problem
12+
poisson_problem = Poisson()
13+
poisson_problem.discretise_domain(50, "grid")
14+
cond_name = list(poisson_problem.conditions.keys())[0]
15+
16+
# Define the model
17+
model = FeedForward(
18+
input_dimensions=len(poisson_problem.input_variables),
19+
output_dimensions=len(poisson_problem.output_variables),
20+
layers=[32, 32],
21+
)
22+
23+
# Define the weighting schema
24+
weights_dict = {key: 1 for key in poisson_problem.conditions.keys()}
25+
weighting = ScalarWeighting(weights=weights_dict)
26+
27+
# Define the solver
28+
solver = PINN(problem=poisson_problem, model=model, weighting=weighting)
29+
30+
# Value used for testing
31+
epochs = 10
32+
33+
34+
@pytest.mark.parametrize("initial_value", [1, 5.5])
35+
@pytest.mark.parametrize("target_value", [10, 25.5])
36+
def test_constructor(initial_value, target_value):
37+
LinearWeightUpdate(
38+
target_epoch=epochs,
39+
condition_name=cond_name,
40+
initial_value=initial_value,
41+
target_value=target_value,
42+
)
43+
44+
# Target_epoch must be int
45+
with pytest.raises(ValueError):
46+
LinearWeightUpdate(
47+
target_epoch=10.0,
48+
condition_name=cond_name,
49+
initial_value=0,
50+
target_value=1,
51+
)
52+
53+
# Condition_name must be str
54+
with pytest.raises(ValueError):
55+
LinearWeightUpdate(
56+
target_epoch=epochs,
57+
condition_name=100,
58+
initial_value=0,
59+
target_value=1,
60+
)
61+
62+
# Initial_value must be float or int
63+
with pytest.raises(ValueError):
64+
LinearWeightUpdate(
65+
target_epoch=epochs,
66+
condition_name=cond_name,
67+
initial_value="0",
68+
target_value=1,
69+
)
70+
71+
# Target_value must be float or int
72+
with pytest.raises(ValueError):
73+
LinearWeightUpdate(
74+
target_epoch=epochs,
75+
condition_name=cond_name,
76+
initial_value=0,
77+
target_value="1",
78+
)
79+
80+
81+
@pytest.mark.parametrize("initial_value, target_value", [(1, 10), (10, 1)])
82+
def test_training(initial_value, target_value):
83+
callback = LinearWeightUpdate(
84+
target_epoch=epochs,
85+
condition_name=cond_name,
86+
initial_value=initial_value,
87+
target_value=target_value,
88+
)
89+
trainer = Trainer(
90+
solver=solver,
91+
callbacks=[callback],
92+
accelerator="cpu",
93+
max_epochs=epochs,
94+
)
95+
trainer.train()
96+
97+
# Check that the final weight value matches the target value
98+
final_value = solver.weighting.weights[cond_name]
99+
assert math.isclose(final_value, target_value)
100+
101+
# Target_epoch must be greater than 0
102+
with pytest.raises(ValueError):
103+
callback = LinearWeightUpdate(
104+
target_epoch=0,
105+
condition_name=cond_name,
106+
initial_value=0,
107+
target_value=1,
108+
)
109+
trainer = Trainer(
110+
solver=solver,
111+
callbacks=[callback],
112+
accelerator="cpu",
113+
max_epochs=5,
114+
)
115+
trainer.train()
116+
117+
# Target_epoch must be less than or equal to max_epochs
118+
with pytest.raises(ValueError):
119+
callback = LinearWeightUpdate(
120+
target_epoch=epochs,
121+
condition_name=cond_name,
122+
initial_value=0,
123+
target_value=1,
124+
)
125+
trainer = Trainer(
126+
solver=solver,
127+
callbacks=[callback],
128+
accelerator="cpu",
129+
max_epochs=epochs - 1,
130+
)
131+
trainer.train()
132+
133+
# Condition_name must be a problem condition
134+
with pytest.raises(ValueError):
135+
callback = LinearWeightUpdate(
136+
target_epoch=epochs,
137+
condition_name="not_a_condition",
138+
initial_value=0,
139+
target_value=1,
140+
)
141+
trainer = Trainer(
142+
solver=solver,
143+
callbacks=[callback],
144+
accelerator="cpu",
145+
max_epochs=epochs,
146+
)
147+
trainer.train()
148+
149+
# Weighting schema must be ScalarWeighting
150+
with pytest.raises(ValueError):
151+
callback = LinearWeightUpdate(
152+
target_epoch=epochs,
153+
condition_name=cond_name,
154+
initial_value=0,
155+
target_value=1,
156+
)
157+
unweighted_solver = PINN(problem=poisson_problem, model=model)
158+
trainer = Trainer(
159+
solver=unweighted_solver,
160+
callbacks=[callback],
161+
accelerator="cpu",
162+
max_epochs=epochs,
163+
)
164+
trainer.train()

0 commit comments

Comments
 (0)