@@ -16,7 +16,7 @@
 
 from typing import TYPE_CHECKING
 from typing import ClassVar
-from typing import List
+from typing import Dict
 
 import paddle
 
@@ -27,7 +27,35 @@
 
 
 class NTK(base.LossAggregator):
+    r"""Weighted Neural Tangent Kernel.
+
+    reference: [https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/jaxpi/models.py#L148-L158](https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/jaxpi/models.py#L148-L158)
+
+    Attributes:
+        should_persist(bool): Whether to persist the loss aggregator when saving.
+            Those loss aggregators with parameters and/or buffers should be persisted.
+
+    Args:
+        model (nn.Layer): Training model.
+        num_losses (int, optional): Number of losses. Defaults to 1.
+        update_freq (int, optional): Weight updating frequency. Defaults to 1000.
+
+    Examples:
+        >>> import paddle
+        >>> from ppsci.loss import mtl
+        >>> model = paddle.nn.Linear(3, 4)
+        >>> loss_aggregator = mtl.NTK(model, num_losses=2)
+        >>> for i in range(5):
+        ...     x1 = paddle.randn([8, 3])
+        ...     x2 = paddle.randn([8, 3])
+        ...     y1 = model(x1)
+        ...     y2 = model(x2)
+        ...     loss1 = paddle.sum(y1)
+        ...     loss2 = paddle.sum((y2 - 2) ** 2)
+        ...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
+    """
     should_persist: ClassVar[bool] = True
+    weight: paddle.Tensor
 
     def __init__(
         self,
@@ -41,18 +69,20 @@ def __init__(
         self.update_freq = update_freq
         self.register_buffer("weight", paddle.ones([num_losses]))
 
-    def _compute_weight(self, losses):
+    def _compute_weight(self, losses: Dict[str, paddle.Tensor]):
         ntk_sum = 0
         ntk_value = []
-        for loss in losses:
-            loss.backward(retain_graph=True)  # NOTE: Keep graph for loss backward
+        for loss in losses.values():
+            grads = paddle.grad(
+                loss,
+                self.model.parameters(),
+                create_graph=False,
+                retain_graph=True,
+                allow_unused=True,
+            )
             with paddle.no_grad():
                 grad = paddle.concat(
-                    [
-                        p.grad.reshape([-1])
-                        for p in self.model.parameters()
-                        if p.grad is not None
-                    ]
+                    [grad.reshape([-1]) for grad in grads if grad is not None]
                 )
             ntk_value.append(
                 paddle.sqrt(
@@ -65,17 +95,19 @@ def _compute_weight(self, losses):
 
         return ntk_weight
 
-    def __call__(self, losses: List["paddle.Tensor"], step: int = 0) -> "paddle.Tensor":
+    def __call__(
+        self, losses: Dict[str, "paddle.Tensor"], step: int = 0
+    ) -> "paddle.Tensor":
         assert len(losses) == self.num_losses, (
             f"Length of given losses({len(losses)}) should be equal to "
             f"num_losses({self.num_losses})."
         )
         self.step = step
 
         # compute current loss with moving weights
-        loss = self.weight[0] * losses[0]
-        for i in range(1, len(losses)):
-            loss += self.weight[i] * losses[i]
+        loss = 0
+        for i, (k, v) in enumerate(losses.items()):
+            loss = loss + self.weight[i] * v
 
         # update moving weights every 'update_freq' steps
        if self.step % self.update_freq == 0:
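
The key change in `_compute_weight` is swapping per-loss `loss.backward(retain_graph=True)`, which accumulates into the parameters' shared `.grad` buffers across losses, for `paddle.grad`, which returns each loss's parameter gradients as fresh tensors. A minimal standalone sketch of that pattern; the names `model`, `x`, and `losses` here are illustrative stand-ins, not taken from this file:

```python
import paddle

# Sketch of the paddle.grad pattern adopted in the hunk above.
model = paddle.nn.Linear(3, 4)
x = paddle.randn([8, 3])
losses = {
    "loss1": paddle.sum(model(x)),
    "loss2": paddle.sum((model(x) - 2) ** 2),
}
for loss in losses.values():
    # Returns this loss's gradients as new tensors; the parameters'
    # .grad buffers are untouched, so losses cannot contaminate each other.
    grads = paddle.grad(
        loss,
        model.parameters(),
        create_graph=False,
        retain_graph=True,  # keep graphs alive while iterating over losses
        allow_unused=True,  # params unused by a loss come back as None
    )
    flat = paddle.concat([g.reshape([-1]) for g in grads if g is not None])
    print(float(paddle.sqrt(paddle.sum(flat**2))))  # per-loss gradient norm
```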