@@ -27,7 +27,7 @@ def __init__(self, update_every_n_epochs=1, alpha=0.5):
         :param int update_every_n_epochs: The number of training epochs between
             weight updates. If set to 1, the weights are updated at every epoch.
             Default is 1.
-        :param float alpha: The alpha parameter.
+        :param float alpha: The moving-average factor weighting the previously
+            saved weights against the new normalized gradient norms. Default is 0.5.
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
         super().__init__(update_every_n_epochs=update_every_n_epochs)
@@ -49,22 +49,28 @@ def weights_update(self, losses):
         :return: The updated weights.
         :rtype: dict
         """
-        # Define a dictionary to store the norms of the gradients
-        losses_norm = {}
+        # Get model parameters and define a dictionary to store the norms
+        params = [p for p in self.solver.model.parameters() if p.requires_grad]
+        norms = {}
 
-        # Compute the gradient norms for each loss component
+        # Iterate over conditions
         for condition, loss in losses.items():
-            loss.backward(retain_graph=True)
-            grads = torch.cat(
-                [p.grad.flatten() for p in self.solver.model.parameters()]
+
+            # Compute gradients
+            grads = torch.autograd.grad(
+                loss,
+                params,
+                retain_graph=True,
+                allow_unused=True,
             )
-            losses_norm[condition] = grads.norm()
 
-        # Update the weights
+            # Compute norms
+            norms[condition] = torch.cat(
+                [g.flatten() for g in grads if g is not None]
+            ).norm()
+
         return {
             condition: self.alpha * self.last_saved_weights().get(condition, 1)
-            + (1 - self.alpha)
-            * losses_norm[condition]
-            / sum(losses_norm.values())
+            + (1 - self.alpha) * norms[condition] / sum(norms.values())
             for condition in losses
         }
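
In short, the new hunk turns the per-condition weight into a moving average between the previously saved weight and that condition's normalized gradient norm. Below is a minimal, self-contained sketch of the same pattern; the toy linear model, the two loss "conditions", and the empty `last_saved_weights` dict are illustrative assumptions standing in for `self.solver.model`, the `losses` dictionary, and `self.last_saved_weights()`:

```python
import torch

# Toy stand-ins for the solver's model and its per-condition losses (assumptions).
model = torch.nn.Linear(2, 1)
params = [p for p in model.parameters() if p.requires_grad]

x = torch.randn(8, 2)
losses = {
    "data": (model(x) ** 2).mean(),
    "physics": model(x).abs().mean(),
}

alpha = 0.5
last_saved_weights = {}  # empty on the first update; .get(condition, 1) falls back to 1

# Per-condition gradient norms, as in the hunk above: torch.autograd.grad with
# retain_graph/allow_unused, skipping parameters that receive no gradient.
norms = {}
for condition, loss in losses.items():
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    norms[condition] = torch.cat([g.flatten() for g in grads if g is not None]).norm()

# Blend the previous weight with the normalized norm (moving average with factor alpha).
new_weights = {
    condition: alpha * last_saved_weights.get(condition, 1)
    + (1 - alpha) * norms[condition] / sum(norms.values())
    for condition in losses
}
print(new_weights)
```

Compared with the removed `loss.backward()` approach, `torch.autograd.grad` avoids accumulating gradients into `p.grad`, and `allow_unused=True` tolerates parameters that do not contribute to a given condition's loss.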