66from ..utils import check_consistency
77
88
9- class _NoWeighting (WeightingInterface ):
10- def aggregate (self , losses ):
11- return sum (losses .values ())
9+ class NeuralTangentKernelWeighting (WeightingInterface ):
10+ """
11+ A neural tangent kernel scheme for weighting different losses to
12+ boost the convergence.
1213
14+ .. seealso::
1315
14- class NeuralTangetKernelWeighting (WeightingInterface ):
15- """
16- TODO
16+    **Original reference**: Jacot, Gabriel, Hongler, *Neural Tangent
17+ Kernel: Convergence and Generalization in Neural Networks*.
18+ arXiv preprint arXiv:1806.07572 (2018).
19+ DOI: `arXiv:1806.07572 <https://arxiv.org/abs/1806.07572>`_.
20+
1721 """
1822
1923 def __init__ (self , model , alpha = 0.5 ):
24+ """
25+ Initialization of the :class:`NeuralTangentKernelWeighting` class.
26+
27+        :param torch.nn.Module model: The model whose parameters are used
28+            by the weighting scheme (must be a :class:`torch.nn.Module`).
29+        :param float alpha: The alpha parameter.
28+ """
29+
2030 super ().__init__ ()
2131 check_consistency (alpha , float )
2232 check_consistency (model , Module )
@@ -30,15 +40,11 @@ def __init__(self, model, alpha=0.5):
3040 def aggregate (self , losses ):
3141 """
3242 Weights the losses according to the Neural Tangent Kernel
43+ algorithm.
3344
3445 :param dict(torch.Tensor) input: The dictionary of losses.
3546 :return: The losses aggregation. It should be a scalar Tensor.
3647 :rtype: torch.Tensor
37-
38- Reference:
39- Wang, S., Sankaran, S., Wang, H., & Perdikaris, P. (2023).
40- An expert's guide to training physics-informed neural networks.
41- arXiv preprint arXiv:2308.08468.
4248 """
4349 losses_norm = {}
4450 for condition in losses :
0 commit comments