-
Notifications
You must be signed in to change notification settings - Fork 92
Linear weight update callback #474
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| """PINA Callbacks Implementations""" | ||
|
|
||
| import warnings | ||
| from lightning.pytorch.callbacks import Callback | ||
| from ..utils import check_consistency | ||
| from ..loss import ScalarWeighting | ||
|
|
||
|
|
||
| class LinearWeightUpdate(Callback): | ||
| """ | ||
| Callback to linearly adjust the weight of a condition from an | ||
| initial value to a target value over a specified number of epochs. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, target_epoch, condition_name, initial_value, target_value | ||
| ): | ||
| """ | ||
| Callback initialization. | ||
|
|
||
| :param int target_epoch: The epoch at which the weight of the condition | ||
| should reach the target value. | ||
| :param str condition_name: The name of the condition whose weight | ||
| should be adjusted. | ||
| :param float initial_value: The initial value of the weight. | ||
| :param float target_value: The target value of the weight. | ||
| """ | ||
| super().__init__() | ||
| self.target_epoch = target_epoch | ||
| self.condition_name = condition_name | ||
| self.initial_value = initial_value | ||
| self.target_value = target_value | ||
|
|
||
| # Check consistency | ||
| check_consistency(self.target_epoch, int, subclass=False) | ||
| check_consistency(self.condition_name, str, subclass=False) | ||
| check_consistency(self.initial_value, (float, int), subclass=False) | ||
| check_consistency(self.target_value, (float, int), subclass=False) | ||
|
|
||
| def on_train_start(self, trainer, solver): | ||
| """ | ||
| Initialize the weight of the condition to the specified `initial_value`. | ||
|
|
||
| :param Trainer trainer: a pina:class:`Trainer` instance. | ||
| :param SolverInterface solver: a pina:class:`SolverInterface` instance. | ||
| """ | ||
| # Check that the target epoch is valid | ||
| if not 0 < self.target_epoch <= trainer.max_epochs: | ||
| raise ValueError( | ||
| "`target_epoch` must be greater than 0" | ||
| " and less than or equal to `max_epochs`." | ||
| ) | ||
|
|
||
| # Check that the condition is a problem condition | ||
| if self.condition_name not in solver.problem.conditions: | ||
| raise ValueError( | ||
| f"`{self.condition_name}` must be a problem condition." | ||
| ) | ||
|
|
||
| # Check that the initial value is not equal to the target value | ||
| if self.initial_value == self.target_value: | ||
| warnings.warn( | ||
| "`initial_value` is equal to `target_value`. " | ||
| "No effective adjustment will be performed.", | ||
| UserWarning, | ||
| ) | ||
|
|
||
| # Check that the weighting schema is ScalarWeighting | ||
| if not isinstance(solver.weighting, ScalarWeighting): | ||
| raise ValueError("The weighting schema must be ScalarWeighting.") | ||
|
|
||
| # Initialize the weight of the condition | ||
| solver.weighting.weights[self.condition_name] = self.initial_value | ||
|
|
||
| def on_train_epoch_start(self, trainer, solver): | ||
| """ | ||
| Adjust at each epoch the weight of the condition. | ||
|
|
||
| :param Trainer trainer: a pina:class:`Trainer` instance. | ||
| :param SolverInterface solver: a pina:class:`SolverInterface` instance. | ||
| """ | ||
| if 0 < trainer.current_epoch <= self.target_epoch: | ||
| solver.weighting.weights[self.condition_name] += ( | ||
| self.target_value - self.initial_value | ||
| ) / (self.target_epoch - 1) | ||
164 changes: 164 additions & 0 deletions
164
tests/test_callback/test_linear_weight_update_callback.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| import pytest | ||
| import math | ||
| from pina.solver import PINN | ||
| from pina.loss import ScalarWeighting | ||
| from pina.trainer import Trainer | ||
| from pina.model import FeedForward | ||
| from pina.problem.zoo import Poisson2DSquareProblem as Poisson | ||
| from pina.callback import LinearWeightUpdate | ||
|
|
||
|
|
||
| # Define the problem | ||
| poisson_problem = Poisson() | ||
| poisson_problem.discretise_domain(50, "grid") | ||
| cond_name = list(poisson_problem.conditions.keys())[0] | ||
|
|
||
| # Define the model | ||
| model = FeedForward( | ||
| input_dimensions=len(poisson_problem.input_variables), | ||
| output_dimensions=len(poisson_problem.output_variables), | ||
| layers=[32, 32], | ||
| ) | ||
|
|
||
| # Define the weighting schema | ||
| weights_dict = {key: 1 for key in poisson_problem.conditions.keys()} | ||
| weighting = ScalarWeighting(weights=weights_dict) | ||
|
|
||
| # Define the solver | ||
| solver = PINN(problem=poisson_problem, model=model, weighting=weighting) | ||
|
|
||
| # Value used for testing | ||
| epochs = 10 | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("initial_value", [1, 5.5]) | ||
| @pytest.mark.parametrize("target_value", [10, 25.5]) | ||
| def test_constructor(initial_value, target_value): | ||
| LinearWeightUpdate( | ||
| target_epoch=epochs, | ||
| condition_name=cond_name, | ||
| initial_value=initial_value, | ||
| target_value=target_value, | ||
| ) | ||
|
|
||
| # Target_epoch must be int | ||
| with pytest.raises(ValueError): | ||
| LinearWeightUpdate( | ||
| target_epoch=10.0, | ||
| condition_name=cond_name, | ||
| initial_value=0, | ||
| target_value=1, | ||
| ) | ||
|
|
||
| # Condition_name must be str | ||
| with pytest.raises(ValueError): | ||
| LinearWeightUpdate( | ||
| target_epoch=epochs, | ||
| condition_name=100, | ||
| initial_value=0, | ||
| target_value=1, | ||
| ) | ||
|
|
||
| # Initial_value must be float or int | ||
| with pytest.raises(ValueError): | ||
| LinearWeightUpdate( | ||
| target_epoch=epochs, | ||
| condition_name=cond_name, | ||
| initial_value="0", | ||
| target_value=1, | ||
| ) | ||
|
|
||
| # Target_value must be float or int | ||
| with pytest.raises(ValueError): | ||
| LinearWeightUpdate( | ||
| target_epoch=epochs, | ||
| condition_name=cond_name, | ||
| initial_value=0, | ||
| target_value="1", | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("initial_value, target_value", [(1, 10), (10, 1)]) | ||
| def test_training(initial_value, target_value): | ||
| callback = LinearWeightUpdate( | ||
| target_epoch=epochs, | ||
| condition_name=cond_name, | ||
| initial_value=initial_value, | ||
| target_value=target_value, | ||
| ) | ||
| trainer = Trainer( | ||
| solver=solver, | ||
| callbacks=[callback], | ||
| accelerator="cpu", | ||
| max_epochs=epochs, | ||
| ) | ||
| trainer.train() | ||
|
|
||
| # Check that the final weight value matches the target value | ||
| final_value = solver.weighting.weights[cond_name] | ||
| assert math.isclose(final_value, target_value) | ||
|
|
||
| # Target_epoch must be greater than 0 | ||
| with pytest.raises(ValueError): | ||
| callback = LinearWeightUpdate( | ||
| target_epoch=0, | ||
| condition_name=cond_name, | ||
| initial_value=0, | ||
| target_value=1, | ||
| ) | ||
| trainer = Trainer( | ||
| solver=solver, | ||
| callbacks=[callback], | ||
| accelerator="cpu", | ||
| max_epochs=5, | ||
| ) | ||
| trainer.train() | ||
|
|
||
| # Target_epoch must be less than or equal to max_epochs | ||
| with pytest.raises(ValueError): | ||
| callback = LinearWeightUpdate( | ||
| target_epoch=epochs, | ||
| condition_name=cond_name, | ||
| initial_value=0, | ||
| target_value=1, | ||
| ) | ||
| trainer = Trainer( | ||
| solver=solver, | ||
| callbacks=[callback], | ||
| accelerator="cpu", | ||
| max_epochs=epochs - 1, | ||
| ) | ||
| trainer.train() | ||
|
|
||
| # Condition_name must be a problem condition | ||
| with pytest.raises(ValueError): | ||
| callback = LinearWeightUpdate( | ||
| target_epoch=epochs, | ||
| condition_name="not_a_condition", | ||
| initial_value=0, | ||
| target_value=1, | ||
| ) | ||
| trainer = Trainer( | ||
| solver=solver, | ||
| callbacks=[callback], | ||
| accelerator="cpu", | ||
| max_epochs=epochs, | ||
| ) | ||
| trainer.train() | ||
|
|
||
| # Weighting schema must be ScalarWeighting | ||
| with pytest.raises(ValueError): | ||
| callback = LinearWeightUpdate( | ||
| target_epoch=epochs, | ||
| condition_name=cond_name, | ||
| initial_value=0, | ||
| target_value=1, | ||
| ) | ||
| unweighted_solver = PINN(problem=poisson_problem, model=model) | ||
| trainer = Trainer( | ||
| solver=unweighted_solver, | ||
| callbacks=[callback], | ||
| accelerator="cpu", | ||
| max_epochs=epochs, | ||
| ) | ||
| trainer.train() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.