
Commit c42bdd5

add self-adaptive weighting
1 parent bacd7e2 commit c42bdd5


5 files changed: +129 −0 lines changed


docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
@@ -267,3 +267,4 @@ Losses and Weightings
     WeightingInterface <loss/weighting_interface.rst>
     ScalarWeighting <loss/scalar_weighting.rst>
     NeuralTangentKernelWeighting <loss/ntk_weighting.rst>
+    SelfAdaptiveWeighting <loss/self_adaptive_weighting.rst>
docs/source/_rst/loss/self_adaptive_weighting.rst

Lines changed: 9 additions & 0 deletions
SelfAdaptiveWeighting
=============================
.. currentmodule:: pina.loss.self_adaptive_weighting

.. automodule:: pina.loss.self_adaptive_weighting

.. autoclass:: SelfAdaptiveWeighting
   :members:
   :show-inheritance:

pina/loss/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
     "WeightingInterface",
     "ScalarWeighting",
     "NeuralTangentKernelWeighting",
+    "SelfAdaptiveWeighting",
 ]

 from .loss_interface import LossInterface
@@ -15,3 +16,4 @@
 from .weighting_interface import WeightingInterface
 from .scalar_weighting import ScalarWeighting
 from .ntk_weighting import NeuralTangentKernelWeighting
+from .self_adaptive_weighting import SelfAdaptiveWeighting
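
With this export in place, the new class is importable straight from pina.loss, alongside the existing weightings. A quick sketch (assuming an installation that includes this commit):

from pina.loss import SelfAdaptiveWeighting

# Recompute the loss weights every 50 epochs instead of the default 100.
weighting = SelfAdaptiveWeighting(k=50)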
pina/loss/self_adaptive_weighting.py

Lines changed: 80 additions & 0 deletions
"""Module for Self-Adaptive Weighting class."""

import torch
from .weighting_interface import WeightingInterface
from ..utils import check_positive_integer


class SelfAdaptiveWeighting(WeightingInterface):
    """
    A self-adaptive weighting scheme to tackle the imbalance among the loss
    components. This formulation equalizes the gradient norms of the losses,
    preventing bias toward any particular term during training.

    .. seealso::

        **Original reference**:
        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
        *Simulating Three-dimensional Turbulence with Physics-informed Neural
        Networks*.
        DOI: `arXiv preprint arXiv:2507.08972.
        <https://arxiv.org/abs/2507.08972>`_

    """

    def __init__(self, k=100):
        """
        Initialization of the :class:`SelfAdaptiveWeighting` class.

        :param int k: The number of epochs after which the weights are updated.
            Default is 100.

        :raises AssertionError: If ``k`` is not a positive integer.
        """
        super().__init__()

        # Check consistency
        check_positive_integer(value=k, strict=True)

        # Initialize parameters
        self.k = k
        self.weights = {}
        self.default_value_weights = 1.0

    def aggregate(self, losses):
        """
        Weight the losses according to the self-adaptive algorithm.

        :param dict(torch.Tensor) losses: The dictionary of losses.
        :return: The aggregation of the losses. It should be a scalar Tensor.
        :rtype: torch.Tensor
        """
        # If weights have not been initialized, set them to 1
        if not self.weights:
            self.weights = {
                condition: self.default_value_weights for condition in losses
            }

        # Update every k epochs
        if self.solver.trainer.current_epoch % self.k == 0:

            # Define a dictionary to store the norms of the gradients
            losses_norm = {}

            # Compute the gradient norms for each loss component
            for condition, loss in losses.items():
                loss.backward(retain_graph=True)
                grads = torch.cat(
                    [p.grad.flatten() for p in self.solver.model.parameters()]
                )
                losses_norm[condition] = grads.norm()

            # Update the weights
            self.weights = {
                condition: sum(losses_norm.values()) / losses_norm[condition]
                for condition in losses
            }

        return sum(
            self.weights[condition] * loss for condition, loss in losses.items()
        )
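
The update rule in aggregate sets each weight to the sum of all gradient norms divided by that condition's own norm, so after reweighting every loss term contributes the same effective gradient magnitude. A minimal standalone sketch of just that step (toy norms and illustrative condition names, not part of the commit):

import torch

# Toy per-condition gradient norms: the physics residual dominates.
losses_norm = {"physics": torch.tensor(10.0), "boundary": torch.tensor(0.1)}

# Same update rule as in aggregate: weight = (sum of all norms) / own norm.
weights = {
    condition: sum(losses_norm.values()) / losses_norm[condition]
    for condition in losses_norm
}

# Every weighted gradient norm is now the same (10.1 in this example).
for condition, norm in losses_norm.items():
    print(condition, (weights[condition] * norm).item())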
Lines changed: 37 additions & 0 deletions
import pytest
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.loss import SelfAdaptiveWeighting
from pina.problem.zoo import Poisson2DSquareProblem


# Initialize problem and model
problem = Poisson2DSquareProblem()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))


@pytest.mark.parametrize("k", [10, 100, 1000])
def test_constructor(k):
    SelfAdaptiveWeighting(k=k)

    # Should fail if k is not an integer
    with pytest.raises(AssertionError):
        SelfAdaptiveWeighting(k=1.5)

    # Should fail if k is zero
    with pytest.raises(AssertionError):
        SelfAdaptiveWeighting(k=0)

    # Should fail if k is negative
    with pytest.raises(AssertionError):
        SelfAdaptiveWeighting(k=-3)


@pytest.mark.parametrize("k", [2, 3])
def test_train_aggregation(k):
    weighting = SelfAdaptiveWeighting(k=k)
    solver = PINN(problem=problem, model=model, weighting=weighting)
    trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
    trainer.train()
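
A hypothetical follow-up check (not part of the commit) could assert that training populated a positive weight for every condition, reusing the problem and model defined above:

# Hypothetical extra assertion on the state left behind by training.
weighting = SelfAdaptiveWeighting(k=2)
solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train()
assert all(w > 0 for w in weighting.weights.values())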
