Skip to content

Commit f458c73

Browse files
authored
Update ntk_weighting.py
1 parent 1fc3ad8 commit f458c73

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

pina/loss/ntk_weighting.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,27 @@
66
from ..utils import check_consistency
77

88

9-
class _NoWeighting(WeightingInterface):
10-
def aggregate(self, losses):
11-
return sum(losses.values())
9+
class NeuralTangentKernelWeighting(WeightingInterface):
10+
"""
11+
A neural tangent kernel scheme for weighting different losses to
12+
boost the convergence.
1213
14+
.. seealso::
1315
14-
class NeuralTangetKernelWeighting(WeightingInterface):
15-
"""
16-
TODO
16+
**Original reference**: Jacot, Gabriel, Hongler, *ANeural Tangent
17+
Kernel: Convergence and Generalization in Neural Networks*.
18+
arXiv preprint arXiv:1806.07572 (2018).
19+
DOI: `arXiv:1806.07572 <https://arxiv.org/abs/1806.07572>`_.
20+
1721
"""
1822

1923
def __init__(self, model, alpha=0.5):
24+
"""
25+
Initialization of the :class:`NeuralTangentKernelWeighting` class.
26+
27+
:param float alpha: The alpha parameter.
28+
"""
29+
2030
super().__init__()
2131
check_consistency(alpha, float)
2232
check_consistency(model, Module)
@@ -30,15 +40,11 @@ def __init__(self, model, alpha=0.5):
3040
def aggregate(self, losses):
3141
"""
3242
Weights the losses according to the Neural Tangent Kernel
43+
algorithm.
3344
3445
:param dict(torch.Tensor) input: The dictionary of losses.
3546
:return: The losses aggregation. It should be a scalar Tensor.
3647
:rtype: torch.Tensor
37-
38-
Reference:
39-
Wang, S., Sankaran, S., Wang, H., & Perdikaris, P. (2023).
40-
An expert's guide to training physics-informed neural networks.
41-
arXiv preprint arXiv:2308.08468.
4248
"""
4349
losses_norm = {}
4450
for condition in losses:

0 commit comments

Comments
 (0)