
Commit e717ad1

Merge pull request #85 from ErdunGAO/testpnl
Update PNL
2 parents b0a6b6f + e0ba5b1

5 files changed: +676 −2065 lines

causallearn/search/FCMBased/PNL/PNL.py

Lines changed: 39 additions & 60 deletions
```diff
@@ -3,6 +3,20 @@
 import torch.autograd as autograd
 import torch.nn as nn
 from scipy import stats
+from torch.utils.data import Dataset, DataLoader
+
+class PairDataset(Dataset):
+
+    def __init__(self, data):
+        super(PairDataset, self).__init__()
+        self.data = data
+        self.num_data = data.shape[0]
+
+    def __len__(self):
+        return self.num_data
+
+    def __getitem__(self, index):
+        return self.data[index, :]
 
 
 class MLP(nn.Module):
```
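
The new `PairDataset` is a thin `torch.utils.data.Dataset` wrapper around an n×2 sample matrix, so the training loop further down can mini-batch it with `DataLoader`. A minimal sketch of how it is consumed, assuming `PairDataset` as defined above (the synthetic array is illustrative, not from the commit):

```python
import numpy as np
from torch.utils.data import DataLoader

# Illustrative 500x2 sample matrix (column 0: cause candidate, column 1: effect).
data = np.random.randn(500, 2).astype(np.float32)

loader = DataLoader(PairDataset(data), batch_size=128, drop_last=True)
for x_batch in loader:
    # Default collation turns each numpy row into a float tensor;
    # drop_last=True discards the final short batch (500 = 3*128 + 116).
    print(x_batch.shape)  # torch.Size([128, 2])
```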
```diff
@@ -52,7 +66,7 @@ class PNL(object):
     independent from the learned disturbance.
     """
 
-    def __init__(self, epochs=100000):
+    def __init__(self, epochs=3000):
        '''
        Construct the PNL model.
 
```
```diff
@@ -62,23 +76,6 @@ def __init__(self, epochs=100000):
        '''
 
        self.epochs = epochs
-
-    def dele_abnormal(self, data_x, data_y):
-
-        mean_x = np.mean(data_x)
-        sigma_x = np.std(data_x)
-        remove_idx_x = np.where(abs(data_x - mean_x) > 3*sigma_x)[0]
-
-        mean_y = np.mean(data_y)
-        sigma_y = np.std(data_y)
-        remove_idx_y = np.where(abs(data_y - mean_y) > 3*sigma_y)[0]
-
-        remove_idx = np.append(remove_idx_x, remove_idx_y)
-
-        data_x = np.delete(data_x, remove_idx)
-        data_y = np.delete(data_y, remove_idx)
-
-        return data_x.reshape(len(data_x), 1), data_y.reshape(len(data_y), 1)
 
     def nica_mnd(self, X, TotalEpoch):
        """
```
```diff
@@ -93,56 +90,42 @@ def nica_mnd(self, X, TotalEpoch):
        ---------
        Y (n*T): the separation result.
        """
-        trpattern = X.T
-
-        # --------------------------------------------------------
-        x1 = torch.from_numpy(trpattern[0, :]).type(torch.FloatTensor).reshape(-1, 1)
-        x2 = torch.from_numpy(trpattern[1, :]).type(torch.FloatTensor).reshape(-1, 1)
-        x1.requires_grad = True
-        x2.requires_grad = True
+        X = X.astype(np.float32)
 
-        y1 = x1
-
-        Final_y2 = x2
-        Min_loss = float('inf')
+        train_dataset = PairDataset(X)
+        train_loader = DataLoader(train_dataset, batch_size=128, drop_last=True)
 
        G1 = MLP(1, 1, n_layers=3, n_units=12)
        G2 = MLP(1, 1, n_layers=1, n_units=12)
        optimizer = torch.optim.Adam([
            {'params': G1.parameters()},
-            {'params': G2.parameters()}], lr=1e-5, betas=(0.9, 0.99))
-
-        loss_all = torch.zeros(TotalEpoch)
-        loss_pdf_all = torch.zeros(TotalEpoch)
-        loss_jacob_all = torch.zeros(TotalEpoch)
-
-        for epoch in range(TotalEpoch):
-            G1.zero_grad()
-            G2.zero_grad()
-
-            y2 = G2(x2) - G1(x1)
-
-            loss_pdf = 0.5 * torch.sum(y2**2)
+            {'params': G2.parameters()}], lr=1e-4, betas=(0.9, 0.99))
 
-            jacob = autograd.grad(outputs=y2, inputs=x2, grad_outputs=torch.ones(y2.shape), create_graph=True,
-                                  retain_graph=True, only_inputs=True)[0]
+        for _ in range(TotalEpoch):
+            optimizer.zero_grad()
+            for x_batch in train_loader:
 
-            loss_jacob = - torch.sum(torch.log(torch.abs(jacob) + 1e-16))
+                x1, x2 = x_batch[:,0].reshape(-1,1), x_batch[:,1].reshape(-1,1)
+                x1.requires_grad = True
+                x2.requires_grad = True
+
+                e = G2(x2) - G1(x1)
+                loss_pdf = 0.5 * torch.sum(e**2)
 
-            loss = loss_jacob + loss_pdf
+                jacob = autograd.grad(outputs=e, inputs=x2, grad_outputs=torch.ones(e.shape), create_graph=True,
+                                      retain_graph=True, only_inputs=True)[0]
+                loss_jacob = - torch.sum(torch.log(torch.abs(jacob) + 1e-16))
 
-            loss_all[epoch] = loss
-            loss_pdf_all[epoch] = loss_pdf
-            loss_jacob_all[epoch] = loss_jacob
+                loss = loss_jacob + loss_pdf
 
-            if loss < Min_loss:
-                Min_loss = loss
-                Final_y2 = y2
-
-            loss.backward()
-            optimizer.step()
+                loss.backward()
+                optimizer.step()
+
+        X1_all = torch.tensor(X[:, 0].reshape(-1,1))
+        X2_all = torch.tensor(X[:, 1].reshape(-1,1))
+        e_estimated = G2(X2_all) - G1(X1_all)
 
-        return y1, Final_y2
+        return X1_all, e_estimated
```
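
Two behavioural changes ride along with the switch to mini-batching: the learning rate rises from 1e-5 to 1e-4, and the best-loss snapshot (`Min_loss` / `Final_y2`) is replaced by re-estimating the disturbance on the full sample with the final weights. The objective itself is unchanged in substance: with e = G2(x2) − G1(x1) as the estimated disturbance, the per-batch loss is the negative log-likelihood of x2 given x1 under a standard-Gaussian disturbance, with the log-Jacobian as the change-of-variables correction (up to an additive constant; sketched here, not stated in the commit):

```latex
% Per-batch objective in nica_mnd, with e_i = G2(x_{2,i}) - G1(x_{1,i}):
\mathcal{L}
  = \underbrace{\tfrac{1}{2} \sum_i e_i^{2}}_{\texttt{loss\_pdf}}
  \; \underbrace{-\, \sum_i \log \bigl| \partial e_i / \partial x_{2,i} \bigr|}_{\texttt{loss\_jacob}}
```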
```diff
@@ -159,10 +142,6 @@ def cause_or_effect(self, data_x, data_y):
        pval_backward: p value in the y->x direction
        '''
        torch.manual_seed(0)
-
-        # Delete the abnormal samples
-        data_x, data_y = self.dele_abnormal(data_x, data_y)
-
        # Now let's see if x1 -> x2 is plausible
        data = np.concatenate((data_x, data_y), axis=1)
        # print('To see if x1 -> x2...')
```
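
For orientation, an end-to-end usage sketch of the updated API. The synthetic data, and the assumption that `cause_or_effect` returns the forward and backward p values named in its docstring, are illustrative, not part of the commit:

```python
import numpy as np
from causallearn.search.FCMBased.PNL.PNL import PNL

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, (500, 1)).astype(np.float32)
# Post-nonlinear mechanism: invertible distortion of (cause + noise).
y = (x + 0.3 * rng.standard_normal((500, 1), dtype=np.float32)) ** 3

pnl = PNL()  # epochs now defaults to 3000 after this change
pval_forward, pval_backward = pnl.cause_or_effect(x, y)
# The direction whose estimated disturbance is independent of the input
# (the larger p value) is taken as the plausible causal direction.
```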
