Commit 90aa9b0

[Fix] Align type of losses to Dict[str, Tensor] for mtl.NTK (#1088)
* align type of losses to Dict[str, Tensor] for mtl.NTK
* update code
1 parent: b02a2b1 · commit: 90aa9b0

File tree

1 file changed (+45 −13 lines)
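
For downstream code, the visible effect of this change is at the call site: `NTK.__call__` now takes a name-to-tensor `Dict` instead of a `List`. A minimal runnable sketch, adapted from the docstring example this commit adds (the linear model and the two losses are purely illustrative):

```python
import paddle
from ppsci.loss import mtl

model = paddle.nn.Linear(3, 4)
loss_aggregator = mtl.NTK(model, num_losses=2)

x = paddle.randn([8, 3])
y = model(x)
loss1 = paddle.sum(y)
loss2 = paddle.sum((y - 2) ** 2)

# Before this commit the aggregator took a list:
#     loss_aggregator([loss1, loss2])
# After it, losses are keyed by name:
total = loss_aggregator({"loss1": loss1, "loss2": loss2})
total.backward()
```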

ppsci/loss/mtl/ntk.py

Lines changed: 45 additions & 13 deletions
```diff
@@ -16,7 +16,7 @@
 
 from typing import TYPE_CHECKING
 from typing import ClassVar
-from typing import List
+from typing import Dict
 
 import paddle
 
```
```diff
@@ -27,7 +27,35 @@
 
 
 class NTK(base.LossAggregator):
+    r"""Weighted Neural Tangent Kernel.
+
+    reference: [https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/jaxpi/models.py#L148-L158](https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/jaxpi/models.py#L148-L158)
+
+    Attributes:
+        should_persist(bool): Whether to persist the loss aggregator when saving.
+            Those loss aggregators with parameters and/or buffers should be persisted.
+
+    Args:
+        model (nn.Layer): Training model.
+        num_losses (int, optional): Number of losses. Defaults to 1.
+        update_freq (int, optional): Weight updating frequency. Defaults to 1000.
+
+    Examples:
+        >>> import paddle
+        >>> from ppsci.loss import mtl
+        >>> model = paddle.nn.Linear(3, 4)
+        >>> loss_aggregator = mtl.NTK(model, num_losses=2)
+        >>> for i in range(5):
+        ...     x1 = paddle.randn([8, 3])
+        ...     x2 = paddle.randn([8, 3])
+        ...     y1 = model(x1)
+        ...     y2 = model(x2)
+        ...     loss1 = paddle.sum(y1)
+        ...     loss2 = paddle.sum((y2 - 2) ** 2)
+        ...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
+    """
     should_persist: ClassVar[bool] = True
+    weight: paddle.Tensor
 
     def __init__(
         self,
```
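
The docstring above notes that aggregators with parameters and/or buffers should be persisted, and NTK both sets `should_persist = True` and (in `__init__`, shown in the next hunk) registers a `weight` buffer. A small sketch of the consequence, assuming `base.LossAggregator` is a `paddle.nn.Layer` subclass (which the use of `register_buffer` suggests); the file name is illustrative:

```python
import paddle
from ppsci.loss import mtl

model = paddle.nn.Linear(3, 4)
agg = mtl.NTK(model, num_losses=2)

# The buffer registered as "weight" should show up in the state dict,
# so checkpointing the aggregator preserves the moving NTK weights.
state = agg.state_dict()
print([k for k in state if "weight" in k])
paddle.save(state, "ntk_aggregator.pdparams")  # illustrative path
```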
```diff
@@ -41,18 +69,20 @@ def __init__(
         self.update_freq = update_freq
         self.register_buffer("weight", paddle.ones([num_losses]))
 
-    def _compute_weight(self, losses):
+    def _compute_weight(self, losses: Dict[str, paddle.Tensor]):
         ntk_sum = 0
         ntk_value = []
-        for loss in losses:
-            loss.backward(retain_graph=True)  # NOTE: Keep graph for loss backward
+        for loss in losses.values():
+            grads = paddle.grad(
+                loss,
+                self.model.parameters(),
+                create_graph=False,
+                retain_graph=True,
+                allow_unused=True,
+            )
             with paddle.no_grad():
                 grad = paddle.concat(
-                    [
-                        p.grad.reshape([-1])
-                        for p in self.model.parameters()
-                        if p.grad is not None
-                    ]
+                    [grad.reshape([-1]) for grad in grads if grad is not None]
                 )
                 ntk_value.append(
                     paddle.sqrt(
```
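
This hunk is the behavioral core of the commit: calling `loss.backward(retain_graph=True)` per loss accumulates gradients into the parameters' `.grad` buffers (unless they are cleared between losses), so a later loss can see earlier losses' contributions mixed into its NTK estimate; `paddle.grad` instead returns fresh per-loss gradients without touching `.grad`. A standalone sketch of the new pattern, with an illustrative model and losses:

```python
import paddle

model = paddle.nn.Linear(3, 4)
x = paddle.randn([8, 3])
losses = {
    "loss1": paddle.sum(model(x)),
    "loss2": paddle.sum((model(x) - 2) ** 2),
}

for name, loss in losses.items():
    # paddle.grad returns this loss's gradients directly instead of
    # accumulating them into p.grad, so the losses stay isolated.
    grads = paddle.grad(
        loss,
        model.parameters(),
        create_graph=False,
        retain_graph=True,   # keep the graph alive for the next loss
        allow_unused=True,   # a parameter may not affect every loss
    )
    flat = paddle.concat([g.reshape([-1]) for g in grads if g is not None])
    print(name, float(paddle.norm(flat)))  # per-loss gradient norm
```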
```diff
@@ -65,17 +95,19 @@ def _compute_weight(self, losses):
 
         return ntk_weight
 
-    def __call__(self, losses: List["paddle.Tensor"], step: int = 0) -> "paddle.Tensor":
+    def __call__(
+        self, losses: Dict[str, "paddle.Tensor"], step: int = 0
+    ) -> "paddle.Tensor":
         assert len(losses) == self.num_losses, (
             f"Length of given losses({len(losses)}) should be equal to "
             f"num_losses({self.num_losses})."
         )
         self.step = step
 
         # compute current loss with moving weights
-        loss = self.weight[0] * losses[0]
-        for i in range(1, len(losses)):
-            loss += self.weight[i] * losses[i]
+        loss = 0
+        for i, (k, v) in enumerate(losses.items()):
+            loss = loss + self.weight[i] * v
 
         # update moving weights every 'update_freq' steps
         if self.step % self.update_freq == 0:
```
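
Putting it together: a hedged end-to-end sketch of how the dict-typed `__call__` slots into a training step. The optimizer wiring is ordinary Paddle usage rather than anything this commit prescribes, and `update_freq=100` is an arbitrary illustrative value:

```python
import paddle
from ppsci.loss import mtl

model = paddle.nn.Linear(3, 4)
agg = mtl.NTK(model, num_losses=2, update_freq=100)
opt = paddle.optimizer.Adam(parameters=model.parameters())

for step in range(5):
    x = paddle.randn([8, 3])
    y = model(x)
    losses = {
        "loss1": paddle.sum(y),
        "loss2": paddle.sum((y - 2) ** 2),
    }
    # Passing `step` lets the aggregator refresh its NTK weights
    # whenever step % update_freq == 0.
    total = agg(losses, step=step)
    total.backward()
    opt.step()
    opt.clear_grad()
```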
