Commit 96402ba

weighting refactor

Co-authored-by: Dario Coscia <[email protected]>
1 parent: c42bdd5

File tree

12 files changed: +215 −389 lines


docs/source/_rst/callback/linear_weight_update_callback.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

pina/callback/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -4,11 +4,9 @@
     "SwitchOptimizer",
     "MetricTracker",
     "PINAProgressBar",
-    "LinearWeightUpdate",
     "R3Refinement",
 ]

 from .optimizer_callback import SwitchOptimizer
 from .processing_callback import MetricTracker, PINAProgressBar
-from .linear_weight_update_callback import LinearWeightUpdate
 from .refinement import R3Refinement

pina/callback/linear_weight_update_callback.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

pina/loss/ntk_weighting.py

Lines changed: 15 additions & 17 deletions
@@ -2,7 +2,7 @@

 import torch
 from .weighting_interface import WeightingInterface
-from ..utils import check_consistency
+from ..utils import check_consistency, in_range


 class NeuralTangentKernelWeighting(WeightingInterface):
@@ -20,32 +20,34 @@ class NeuralTangentKernelWeighting(WeightingInterface):

     """

-    def __init__(self, alpha=0.5):
+    def __init__(self, update_every_n_epochs=1, alpha=0.5):
         """
         Initialization of the :class:`NeuralTangentKernelWeighting` class.

+        :param int update_every_n_epochs: The number of training epochs between
+            weight updates. If set to 1, the weights are updated at every epoch.
+            Default is 1.
         :param float alpha: The alpha parameter.
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
-        super().__init__()
+        super().__init__(update_every_n_epochs=update_every_n_epochs)

         # Check consistency
         check_consistency(alpha, float)
-        if alpha < 0 or alpha > 1:
-            raise ValueError("alpha should be a value between 0 and 1")
+        if not in_range(alpha, [0, 1], strict=False):
+            raise ValueError("alpha must be in range (0, 1).")

         # Initialize parameters
         self.alpha = alpha
         self.weights = {}
-        self.default_value_weights = 1.0

-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Weight the losses according to the Neural Tangent Kernel algorithm.
+        Update the weighting scheme based on the given losses.

-        :param dict(torch.Tensor) input: The dictionary of losses.
-        :return: The aggregation of the losses. It should be a scalar Tensor.
-        :rtype: torch.Tensor
+        :param dict losses: The dictionary of losses.
+        :return: The updated weights.
+        :rtype: dict
         """
         # Define a dictionary to store the norms of the gradients
         losses_norm = {}
@@ -60,14 +62,10 @@ def aggregate(self, losses):

         # Update the weights
         self.weights = {
-            condition: self.alpha
-            * self.weights.get(condition, self.default_value_weights)
+            condition: self.alpha * self.weights.get(condition, 1)
             + (1 - self.alpha)
             * losses_norm[condition]
             / sum(losses_norm.values())
             for condition in losses
         }
-
-        return sum(
-            self.weights[condition] * loss for condition, loss in losses.items()
-        )
+        return self.weights
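
In practice, the refactored scheme keeps an exponential moving average of the normalized gradient norms and now returns the weights instead of the aggregated loss. Below is a minimal standalone sketch of that update rule (not part of the diff; condition names and numbers are made up for illustration):

    # NTK-style update: w_c <- alpha * w_c + (1 - alpha) * ||grad L_c|| / sum_k ||grad L_k||
    alpha = 0.5
    weights = {}  # previous weights; missing conditions default to 1
    grad_norms = {"physics": 4.0, "boundary": 1.0}  # made-up gradient norms

    total = sum(grad_norms.values())
    weights = {
        c: alpha * weights.get(c, 1) + (1 - alpha) * grad_norms[c] / total
        for c in grad_norms
    }
    # weights == {"physics": 0.9, "boundary": 0.6}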

pina/loss/scalar_weighting.py

Lines changed: 28 additions & 30 deletions
@@ -4,22 +4,6 @@
 from ..utils import check_consistency


-class _NoWeighting(WeightingInterface):
-    """
-    Weighting scheme that does not apply any weighting to the losses.
-    """
-
-    def aggregate(self, losses):
-        """
-        Aggregate the losses.
-
-        :param dict losses: The dictionary of losses.
-        :return: The aggregated losses.
-        :rtype: torch.Tensor
-        """
-        return sum(losses.values())
-
-
 class ScalarWeighting(WeightingInterface):
     """
     Weighting scheme that assigns a scalar weight to each loss term.
@@ -36,28 +20,42 @@ def __init__(self, weights):
             dictionary, the default value is used.
         :type weights: float | int | dict
         """
-        super().__init__()
+        super().__init__(update_every_n_epochs=1, aggregator="sum")

         # Check consistency
         check_consistency([weights], (float, dict, int))

-        # Weights initialization
-        if isinstance(weights, (float, int)):
+        # Initialization
+        if isinstance(weights, dict):
+            self.values = weights
+            self.default_value_weights = 1
+        elif isinstance(weights, (float, int)):
+            self.values = {}
             self.default_value_weights = weights
-            self.weights = {}
         else:
-            self.default_value_weights = 1.0
-            self.weights = weights
+            raise ValueError

-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Aggregate the losses.
+        Update the weighting scheme based on the given losses.

         :param dict losses: The dictionary of losses.
-        :return: The aggregated losses.
-        :rtype: torch.Tensor
+        :return: The updated weights.
+        :rtype: dict
+        """
+        return {
+            condition: self.values.get(condition, self.default_value_weights)
+            for condition in losses.keys()
+        }
+
+
+class _NoWeighting(ScalarWeighting):
+    """
+    Weighting scheme that does not apply any weighting to the losses.
+    """
+
+    def __init__(self):
+        """
+        Initialization of the :class:`_NoWeighting` class.
         """
-        return sum(
-            self.weights.get(condition, self.default_value_weights) * loss
-            for condition, loss in losses.items()
-        )
+        super().__init__(weights=1)
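
After the refactor, ScalarWeighting only resolves one weight per condition and leaves the summation of the weighted losses to the interface (via aggregator="sum"). A small sketch of the lookup behaviour; the import path, condition names, and loss values are assumptions for illustration:

    from pina.loss import ScalarWeighting  # import path assumed

    weighting = ScalarWeighting(weights={"physics": 10.0, "boundary": 1.0})
    losses = {"physics": 0.2, "boundary": 0.05, "data": 0.1}  # dummy loss values
    weights = weighting.weights_update(losses)
    # conditions missing from the dict fall back to the default weight of 1:
    # {"physics": 10.0, "boundary": 1.0, "data": 1}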

pina/loss/self_adaptive_weighting.py

Lines changed: 26 additions & 49 deletions
@@ -2,7 +2,6 @@

 import torch
 from .weighting_interface import WeightingInterface
-from ..utils import check_positive_integer


 class SelfAdaptiveWeighting(WeightingInterface):
@@ -22,59 +21,37 @@ class SelfAdaptiveWeighting(WeightingInterface):

     """

-    def __init__(self, k=100):
+    def __init__(self, update_every_n_epochs=1):
         """
         Initialization of the :class:`SelfAdaptiveWeighting` class.

-        :param int k: The number of epochs after which the weights are updated.
-            Default is 100.
-
-        :raises ValueError: If ``k`` is not a positive integer.
+        :param int update_every_n_epochs: The number of training epochs between
+            weight updates. If set to 1, the weights are updated at every epoch.
+            Default is 1.
         """
-        super().__init__()
-
-        # Check consistency
-        check_positive_integer(value=k, strict=True)
+        super().__init__(update_every_n_epochs=update_every_n_epochs)

-        # Initialize parameters
-        self.k = k
-        self.weights = {}
-        self.default_value_weights = 1.0
-
-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Weight the losses according to the self-adaptive algorithm.
+        Update the weighting scheme based on the given losses.

-        :param dict(torch.Tensor) losses: The dictionary of losses.
-        :return: The aggregation of the losses. It should be a scalar Tensor.
-        :rtype: torch.Tensor
+        :param dict losses: The dictionary of losses.
+        :return: The updated weights.
+        :rtype: dict
         """
-        # If weights have not been initialized, set them to 1
-        if not self.weights:
-            self.weights = {
-                condition: self.default_value_weights for condition in losses
-            }
-
-        # Update every k epochs
-        if self.solver.trainer.current_epoch % self.k == 0:
-
-            # Define a dictionary to store the norms of the gradients
-            losses_norm = {}
-
-            # Compute the gradient norms for each loss component
-            for condition, loss in losses.items():
-                loss.backward(retain_graph=True)
-                grads = torch.cat(
-                    [p.grad.flatten() for p in self.solver.model.parameters()]
-                )
-                losses_norm[condition] = grads.norm()
-
-            # Update the weights
-            self.weights = {
-                condition: sum(losses_norm.values()) / losses_norm[condition]
-                for condition in losses
-            }
-
-        return sum(
-            self.weights[condition] * loss for condition, loss in losses.items()
-        )
+        # Define a dictionary to store the norms of the gradients
+        losses_norm = {}
+
+        # Compute the gradient norms for each loss component
+        for condition, loss in losses.items():
+            loss.backward(retain_graph=True)
+            grads = torch.cat(
+                [p.grad.flatten() for p in self.solver.model.parameters()]
+            )
+            losses_norm[condition] = grads.norm()
+
+        # Update the weights
+        return {
+            condition: sum(losses_norm.values()) / losses_norm[condition]
+            for condition in losses
+        }
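
For reference, the self-adaptive rule weights each condition by the ratio between the total gradient norm and that condition's gradient norm, so terms with small gradients are boosted; the epoch-based scheduling previously handled by k now lives in the interface's update_every_n_epochs. A standalone sketch with made-up gradient norms (not part of the diff):

    # w_c = (sum_k ||grad L_k||) / ||grad L_c||, computed from made-up norms
    grad_norms = {"physics": 4.0, "boundary": 1.0}
    total = sum(grad_norms.values())
    weights = {c: total / n for c, n in grad_norms.items()}
    # weights == {"physics": 1.25, "boundary": 5.0}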
