Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pina/callback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
"R3Refinement",
"MetricTracker",
"PINAProgressBar",
"LinearWeightUpdate",
]

from .optimizer_callback import SwitchOptimizer
from .adaptive_refinement_callback import R3Refinement
from .processing_callback import MetricTracker, PINAProgressBar
from .linear_weight_update_callback import LinearWeightUpdate
85 changes: 85 additions & 0 deletions pina/callback/linear_weight_update_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""PINA Callbacks Implementations"""

import warnings
from lightning.pytorch.callbacks import Callback
from ..utils import check_consistency
from ..loss import ScalarWeighting


class LinearWeightUpdate(Callback):
"""
Callback to linearly adjust the weight of a condition from an
initial value to a target value over a specified number of epochs.
"""

def __init__(
self, target_epoch, condition_name, initial_value, target_value
):
"""
Callback initialization.

:param int target_epoch: The epoch at which the weight of the condition
should reach the target value.
:param str condition_name: The name of the condition whose weight
should be adjusted.
:param float initial_value: The initial value of the weight.
:param float target_value: The target value of the weight.
"""
super().__init__()
self.target_epoch = target_epoch
self.condition_name = condition_name
self.initial_value = initial_value
self.target_value = target_value

# Check consistency
check_consistency(self.target_epoch, int, subclass=False)
check_consistency(self.condition_name, str, subclass=False)
check_consistency(self.initial_value, (float, int), subclass=False)
check_consistency(self.target_value, (float, int), subclass=False)

def on_train_start(self, trainer, solver):
"""
Initialize the weight of the condition to the specified `initial_value`.

:param Trainer trainer: a pina:class:`Trainer` instance.
:param SolverInterface solver: a pina:class:`SolverInterface` instance.
"""
# Check that the target epoch is valid
if not 0 < self.target_epoch <= trainer.max_epochs:
raise ValueError(
"`target_epoch` must be greater than 0"
" and less than or equal to `max_epochs`."
)

# Check that the condition is a problem condition
if self.condition_name not in solver.problem.conditions:
raise ValueError(
f"`{self.condition_name}` must be a problem condition."
)

# Check that the initial value is not equal to the target value
if self.initial_value == self.target_value:
warnings.warn(
"`initial_value` is equal to `target_value`. "
"No effective adjustment will be performed.",
UserWarning,
)

# Check that the weighting schema is ScalarWeighting
if not isinstance(solver.weighting, ScalarWeighting):
raise ValueError("The weighting schema must be ScalarWeighting.")

# Initialize the weight of the condition
solver.weighting.weights[self.condition_name] = self.initial_value

def on_train_epoch_start(self, trainer, solver):
"""
Adjust at each epoch the weight of the condition.

:param Trainer trainer: a pina:class:`Trainer` instance.
:param SolverInterface solver: a pina:class:`SolverInterface` instance.
"""
if 0 < trainer.current_epoch <= self.target_epoch:
solver.weighting.weights[self.condition_name] += (
self.target_value - self.initial_value
) / (self.target_epoch - 1)
164 changes: 164 additions & 0 deletions tests/test_callback/test_linear_weight_update_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import pytest
import math
from pina.solver import PINN
from pina.loss import ScalarWeighting
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.callback import LinearWeightUpdate


# Define the problem
poisson_problem = Poisson()
poisson_problem.discretise_domain(50, "grid")
cond_name = list(poisson_problem.conditions.keys())[0]

# Define the model
model = FeedForward(
input_dimensions=len(poisson_problem.input_variables),
output_dimensions=len(poisson_problem.output_variables),
layers=[32, 32],
)

# Define the weighting schema
weights_dict = {key: 1 for key in poisson_problem.conditions.keys()}
weighting = ScalarWeighting(weights=weights_dict)

# Define the solver
solver = PINN(problem=poisson_problem, model=model, weighting=weighting)

# Value used for testing
epochs = 10


@pytest.mark.parametrize("initial_value", [1, 5.5])
@pytest.mark.parametrize("target_value", [10, 25.5])
def test_constructor(initial_value, target_value):
LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=initial_value,
target_value=target_value,
)

# Target_epoch must be int
with pytest.raises(ValueError):
LinearWeightUpdate(
target_epoch=10.0,
condition_name=cond_name,
initial_value=0,
target_value=1,
)

# Condition_name must be str
with pytest.raises(ValueError):
LinearWeightUpdate(
target_epoch=epochs,
condition_name=100,
initial_value=0,
target_value=1,
)

# Initial_value must be float or int
with pytest.raises(ValueError):
LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value="0",
target_value=1,
)

# Target_value must be float or int
with pytest.raises(ValueError):
LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=0,
target_value="1",
)


@pytest.mark.parametrize("initial_value, target_value", [(1, 10), (10, 1)])
def test_training(initial_value, target_value):
callback = LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=initial_value,
target_value=target_value,
)
trainer = Trainer(
solver=solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=epochs,
)
trainer.train()

# Check that the final weight value matches the target value
final_value = solver.weighting.weights[cond_name]
assert math.isclose(final_value, target_value)

# Target_epoch must be greater than 0
with pytest.raises(ValueError):
callback = LinearWeightUpdate(
target_epoch=0,
condition_name=cond_name,
initial_value=0,
target_value=1,
)
trainer = Trainer(
solver=solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=5,
)
trainer.train()

# Target_epoch must be less than or equal to max_epochs
with pytest.raises(ValueError):
callback = LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=0,
target_value=1,
)
trainer = Trainer(
solver=solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=epochs - 1,
)
trainer.train()

# Condition_name must be a problem condition
with pytest.raises(ValueError):
callback = LinearWeightUpdate(
target_epoch=epochs,
condition_name="not_a_condition",
initial_value=0,
target_value=1,
)
trainer = Trainer(
solver=solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=epochs,
)
trainer.train()

# Weighting schema must be ScalarWeighting
with pytest.raises(ValueError):
callback = LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=0,
target_value=1,
)
unweighted_solver = PINN(problem=poisson_problem, model=model)
trainer = Trainer(
solver=unweighted_solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=epochs,
)
trainer.train()
Loading