Skip to content

Commit e0ba5b1

Browse files
committed
update pnl
1 parent 58e83d8 commit e0ba5b1

File tree

1 file changed

+6
-5
lines changed
  • causallearn/search/FCMBased/PNL

1 file changed

+6
-5
lines changed

causallearn/search/FCMBased/PNL/PNL.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,11 @@ def nica_mnd(self, X, TotalEpoch):
108108
x1, x2 = x_batch[:,0].reshape(-1,1), x_batch[:,1].reshape(-1,1)
109109
x1.requires_grad = True
110110
x2.requires_grad = True
111-
y2 = G2(x2) - G1(x1)
112-
loss_pdf = 0.5 * torch.sum(y2**2)
111+
112+
e = G2(x2) - G1(x1)
113+
loss_pdf = 0.5 * torch.sum(e**2)
113114

114-
jacob = autograd.grad(outputs=y2, inputs=x2, grad_outputs=torch.ones(y2.shape), create_graph=True,
115+
jacob = autograd.grad(outputs=e, inputs=x2, grad_outputs=torch.ones(e.shape), create_graph=True,
115116
retain_graph=True, only_inputs=True)[0]
116117
loss_jacob = - torch.sum(torch.log(torch.abs(jacob) + 1e-16))
117118

@@ -122,9 +123,9 @@ def nica_mnd(self, X, TotalEpoch):
122123

123124
X1_all = torch.tensor(X[:, 0].reshape(-1,1))
124125
X2_all = torch.tensor(X[:, 1].reshape(-1,1))
125-
Final_y2 = G2(X2_all) - G1(X1_all)
126+
e_estimated = G2(X2_all) - G1(X1_all)
126127

127-
return X1_all, Final_y2
128+
return X1_all, e_estimated
128129

129130
def cause_or_effect(self, data_x, data_y):
130131
'''

0 commit comments

Comments
 (0)