1- """ Module for Neural Tangent Kernel Class """
1+ """Module for Neural Tangent Kernel Class"""
22
3+ import torch
34from torch .nn import Module
4- from torch import norm ,cat
55from .weighting_interface import WeightingInterface
66from ..operator import grad
77from ..utils import check_consistency
88
9- class Neural_Tangent_Kernel (WeightingInterface ):
9+
10+ class NeuralTangetKernelWeighting (WeightingInterface ):
1011 """
1112 TODO
1213 """
14+
1315 def __init__ (self , model , alpha = 0.5 ):
1416 super ().__init__ ()
1517 check_consistency (alpha , float )
16-
17- if alpha < 0 or alpha > 1 :
18+ check_consistency ( model , Module )
19+ if alpha < 0 or alpha > 1 :
1820 raise ValueError ("alpha should be a value between 0 and 1" )
19- if not isinstance (model , Module ):
20- raise TypeError ("Please pass a valid torch.nn.Module model" )
2121 self .alpha = alpha
2222 self .model = model
2323 self .weights = {}
2424 self .default_value_weights = 1
2525
2626 def aggregate (self , losses ):
2727 """
28- Weights the losses according to the Neural Tangent Kernel
28+ Weights the losses according to the Neural Tangent Kernel
2929
3030 :param dict(torch.Tensor) input: The dictionary of losses.
3131 :param alpha(float) input: The parameter alpha that regulates the moving average
32- between old and new weights.
32+ between old and new weights.
3333 :return: The losses aggregation. It should be a scalar Tensor.
3434 :rtype: torch.Tensor
3535
36- Reference: TODO
37- Wang, S., Sankaran, S., Wang, H., & Perdikaris, P. (2023).
38- An expert's guide to training physics-informed neural networks.
36+ Reference:
37+ Wang, S., Sankaran, S., Wang, H., & Perdikaris, P. (2023).
38+ An expert's guide to training physics-informed neural networks.
3939 arXiv preprint arXiv:2308.08468.
40- """
40+ """
4141 losses_norm = {}
4242 for condition in losses :
4343 losses [condition ].backward (retain_graph = True )
44- grads = []
44+ grads = []
4545 for param in self .model .parameters ():
4646 grads .append (param .grad .view (- 1 ))
47- grads = cat (grads )
48-
49- print (grads )
50- losses_norm [condition ]= norm (grads )
47+ grads = torch .cat (grads )
48+ losses_norm [condition ] = torch .norm (grads )
5149 self .weights = {
52- condition :
53- self .alpha * self .weights .get (
54- condition , self .default_value_weights ) +
55- (1 - self .alpha )* losses_norm [condition ]/ sum (losses_norm .values ())
50+ condition : self .alpha
51+ * self .weights .get (condition , self .default_value_weights )
52+ + (1 - self .alpha )
53+ * losses_norm [condition ]
54+ / torch .sum (losses_norm .values ())
5655 for condition in losses
5756 }
5857 return sum (
59- self .weights .get (condition , self .default_value_weights ) * loss for
60- condition , loss in losses .items ()
61- )
58+ self .weights [condition ] * loss for condition , loss in losses .items ()
59+ )
0 commit comments