@@ -108,10 +108,11 @@ def nica_mnd(self, X, TotalEpoch):
108108 x1 , x2 = x_batch [:,0 ].reshape (- 1 ,1 ), x_batch [:,1 ].reshape (- 1 ,1 )
109109 x1 .requires_grad = True
110110 x2 .requires_grad = True
111- y2 = G2 (x2 ) - G1 (x1 )
112- loss_pdf = 0.5 * torch .sum (y2 ** 2 )
111+
112+ e = G2 (x2 ) - G1 (x1 )
113+ loss_pdf = 0.5 * torch .sum (e ** 2 )
113114
114- jacob = autograd .grad (outputs = y2 , inputs = x2 , grad_outputs = torch .ones (y2 .shape ), create_graph = True ,
115+ jacob = autograd .grad (outputs = e , inputs = x2 , grad_outputs = torch .ones (e .shape ), create_graph = True ,
115116 retain_graph = True , only_inputs = True )[0 ]
116117 loss_jacob = - torch .sum (torch .log (torch .abs (jacob ) + 1e-16 ))
117118
@@ -122,9 +123,9 @@ def nica_mnd(self, X, TotalEpoch):
122123
123124 X1_all = torch .tensor (X [:, 0 ].reshape (- 1 ,1 ))
124125 X2_all = torch .tensor (X [:, 1 ].reshape (- 1 ,1 ))
125- Final_y2 = G2 (X2_all ) - G1 (X1_all )
126+ e_estimated = G2 (X2_all ) - G1 (X1_all )
126127
127- return X1_all , Final_y2
128+ return X1_all , e_estimated
128129
129130 def cause_or_effect (self , data_x , data_y ):
130131 '''
0 commit comments