diff --git a/pina/loss/ntk_weighting.py b/pina/loss/ntk_weighting.py
index fe671157a..fe1c4fc6a 100644
--- a/pina/loss/ntk_weighting.py
+++ b/pina/loss/ntk_weighting.py
@@ -27,7 +27,7 @@ def __init__(self, update_every_n_epochs=1, alpha=0.5):
         :param int update_every_n_epochs: The number of training epochs
             between weight updates. If set to 1, the weights are updated
             at every epoch. Default is 1.
-        :param float alpha: The alpha parameter.
+        :param float alpha: The alpha parameter. Default is 0.5.
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
         super().__init__(update_every_n_epochs=update_every_n_epochs)
@@ -49,22 +49,28 @@ def weights_update(self, losses):
         :return: The updated weights.
         :rtype: dict
         """
-        # Define a dictionary to store the norms of the gradients
-        losses_norm = {}
+        # Get model parameters and define a dictionary to store the norms
+        params = [p for p in self.solver.model.parameters() if p.requires_grad]
+        norms = {}
 
-        # Compute the gradient norms for each loss component
+        # Iterate over conditions
         for condition, loss in losses.items():
-            loss.backward(retain_graph=True)
-            grads = torch.cat(
-                [p.grad.flatten() for p in self.solver.model.parameters()]
+
+            # Compute gradients
+            grads = torch.autograd.grad(
+                loss,
+                params,
+                retain_graph=True,
+                allow_unused=True,
             )
-            losses_norm[condition] = grads.norm()
 
-        # Update the weights
+            # Compute norms
+            norms[condition] = torch.cat(
+                [g.flatten() for g in grads if g is not None]
+            ).norm()
+
         return {
             condition: self.alpha * self.last_saved_weights().get(condition, 1)
-            + (1 - self.alpha)
-            * losses_norm[condition]
-            / sum(losses_norm.values())
+            + (1 - self.alpha) * norms[condition] / sum(norms.values())
             for condition in losses
         }
diff --git a/pina/loss/self_adaptive_weighting.py b/pina/loss/self_adaptive_weighting.py
index 62196c529..c796d359f 100644
--- a/pina/loss/self_adaptive_weighting.py
+++ b/pina/loss/self_adaptive_weighting.py
@@ -39,19 +39,28 @@ def weights_update(self, losses):
         :return: The updated weights.
         :rtype: dict
         """
-        # Define a dictionary to store the norms of the gradients
-        losses_norm = {}
+        # Get model parameters and define a dictionary to store the norms
+        params = [p for p in self.solver.model.parameters() if p.requires_grad]
+        norms = {}
 
-        # Compute the gradient norms for each loss component
+        # Iterate over conditions
         for condition, loss in losses.items():
-            loss.backward(retain_graph=True)
-            grads = torch.cat(
-                [p.grad.flatten() for p in self.solver.model.parameters()]
+
+            # Compute gradients
+            grads = torch.autograd.grad(
+                loss,
+                params,
+                retain_graph=True,
+                allow_unused=True,
             )
-            losses_norm[condition] = grads.norm()
+
+            # Compute norms
+            norms[condition] = torch.cat(
+                [g.flatten() for g in grads if g is not None]
+            ).norm()
 
         # Update the weights
         return {
-            condition: sum(losses_norm.values()) / losses_norm[condition]
+            condition: sum(norms.values()) / norms[condition]
             for condition in losses
         }
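
For context, a minimal standalone sketch (not part of the patch and independent of PINA internals) contrasting the two approaches on a toy setup: loss.backward() accumulates into the parameters' .grad, so the norm computed for a later condition also contains the earlier conditions' gradients unless .grad is zeroed between iterations, whereas torch.autograd.grad returns each loss's gradients in isolation and allow_unused=True tolerates parameters a given loss does not reach. The model, data, and loss names below are hypothetical.

# Standalone sketch; only loss.backward / torch.autograd.grad are real PyTorch APIs,
# everything else (toy model, loss names) is made up for illustration.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 1)                    # hypothetical toy model
x = torch.randn(8, 2)
losses = {
    "physics": (model(x) ** 2).mean(),           # stand-ins for per-condition losses
    "boundary": ((model(x) - 1.0) ** 2).mean(),
}
params = [p for p in model.parameters() if p.requires_grad]

# Old approach: backward() accumulates into .grad, so the norm computed for the
# second condition also includes the first condition's gradients.
accumulated = {}
for name, loss in losses.items():
    loss.backward(retain_graph=True)
    accumulated[name] = torch.cat([p.grad.flatten() for p in params]).norm()

# New approach: autograd.grad returns the gradients of this loss only, without
# touching .grad; allow_unused=True yields None for parameters the loss does not reach.
isolated = {}
for name, loss in losses.items():
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    isolated[name] = torch.cat([g.flatten() for g in grads if g is not None]).norm()

print({k: v.item() for k, v in accumulated.items()})
print({k: v.item() for k, v in isolated.items()})  # differs for "boundary"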