
Commit c91e48b

Fix gradients computation
1 parent 23d35f9 commit c91e48b

2 files changed: 35 additions & 20 deletions

pina/loss/ntk_weighting.py

Lines changed: 18 additions & 12 deletions
@@ -27,7 +27,7 @@ def __init__(self, update_every_n_epochs=1, alpha=0.5):
         :param int update_every_n_epochs: The number of training epochs between
             weight updates. If set to 1, the weights are updated at every epoch.
             Default is 1.
-        :param float alpha: The alpha parameter.
+        :param float alpha: The alpha parameter. Default is 0.5.
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
         super().__init__(update_every_n_epochs=update_every_n_epochs)
@@ -49,22 +49,28 @@ def weights_update(self, losses):
         :return: The updated weights.
         :rtype: dict
         """
-        # Define a dictionary to store the norms of the gradients
-        losses_norm = {}
+        # Get model parameters and define a dictionary to store the norms
+        params = [p for p in self.solver.model.parameters() if p.requires_grad]
+        norms = {}
 
-        # Compute the gradient norms for each loss component
+        # Iterate over conditions
         for condition, loss in losses.items():
-            loss.backward(retain_graph=True)
-            grads = torch.cat(
-                [p.grad.flatten() for p in self.solver.model.parameters()]
+
+            # Compute gradients
+            grads = torch.autograd.grad(
+                loss,
+                params,
+                retain_graph=True,
+                allow_unused=True,
             )
-            losses_norm[condition] = grads.norm()
 
-        # Update the weights
+            # Compute norms
+            norms[condition] = torch.cat(
+                [g.flatten() for g in grads if g is not None]
+            ).norm()
+
         return {
             condition: self.alpha * self.last_saved_weights().get(condition, 1)
-            + (1 - self.alpha)
-            * losses_norm[condition]
-            / sum(losses_norm.values())
+            + (1 - self.alpha) * norms[condition] / sum(norms.values())
             for condition in losses
         }
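Context for the change above: the previous code called loss.backward(retain_graph=True) for every condition and then read p.grad, but .grad accumulates across backward calls (and may be None for parameters a loss does not touch), so each condition's norm could include gradients from the conditions processed before it. torch.autograd.grad computes the gradients of a single loss in isolation and, with allow_unused=True, returns None for unused parameters. Below is a minimal, self-contained sketch of that per-condition norm computation; the toy model and loss terms are illustrative stand-ins, not PINA code.

import torch

# Illustrative stand-ins for self.solver.model and the per-condition losses.
model = torch.nn.Linear(2, 1)
x = torch.randn(8, 2)
losses = {
    "physics": (model(x) ** 2).mean(),
    "boundary": (model(x) - 1.0).pow(2).mean(),
}

params = [p for p in model.parameters() if p.requires_grad]
norms = {}
for condition, loss in losses.items():
    # Gradients of this loss only; p.grad is never touched, so norms stay independent.
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    norms[condition] = torch.cat([g.flatten() for g in grads if g is not None]).norm()

print({c: float(n) for c, n in norms.items()})

retain_graph=True keeps the computational graph alive so the remaining conditions, and any later backward pass, can still differentiate through it.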

pina/loss/self_adaptive_weighting.py

Lines changed: 17 additions & 8 deletions
@@ -39,19 +39,28 @@ def weights_update(self, losses):
         :return: The updated weights.
         :rtype: dict
         """
-        # Define a dictionary to store the norms of the gradients
-        losses_norm = {}
+        # Get model parameters and define a dictionary to store the norms
+        params = [p for p in self.solver.model.parameters() if p.requires_grad]
+        norms = {}
 
-        # Compute the gradient norms for each loss component
+        # Iterate over conditions
         for condition, loss in losses.items():
-            loss.backward(retain_graph=True)
-            grads = torch.cat(
-                [p.grad.flatten() for p in self.solver.model.parameters()]
+
+            # Compute gradients
+            grads = torch.autograd.grad(
+                loss,
+                params,
+                retain_graph=True,
+                allow_unused=True,
            )
-            losses_norm[condition] = grads.norm()
+
+            # Compute norms
+            norms[condition] = torch.cat(
+                [g.flatten() for g in grads if g is not None]
+            ).norm()
 
         # Update the weights
         return {
-            condition: sum(losses_norm.values()) / losses_norm[condition]
+            condition: sum(norms.values()) / norms[condition]
             for condition in losses
         }
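For comparison, a short sketch of how the two updated weights_update methods turn these gradient norms into loss weights; the norms values below are dummy placeholders, and alpha / last_weights stand in for self.alpha and self.last_saved_weights().

import torch

# Dummy gradient norms; in PINA these come from the loop shown in the diffs above.
norms = {"physics": torch.tensor(4.0), "boundary": torch.tensor(1.0)}
alpha, last_weights = 0.5, {}  # stand-ins for self.alpha and self.last_saved_weights()

# NTK weighting: moving average between the previous weight and the normalized norm.
ntk_weights = {
    c: alpha * last_weights.get(c, 1) + (1 - alpha) * norms[c] / sum(norms.values())
    for c in norms
}

# Self-adaptive weighting: inverse-norm scaling, so conditions with small
# gradients receive larger weights.
self_adaptive_weights = {c: sum(norms.values()) / norms[c] for c in norms}

print(ntk_weights, self_adaptive_weights)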
