
Commit f19c866

Erdun Gao authored and committed
Adding the test file of PNL and updating the PNL function
1 parent 1ebf232 commit f19c866

File tree

4 files changed: +2105, -122 lines changed


causallearn/search/FCMBased/PNL/PNL.py

Lines changed: 59 additions & 76 deletions
@@ -3,17 +3,12 @@
 
 BASE_DIR = os.path.join(os.path.dirname(__file__), '..')
 sys.path.append(BASE_DIR)
-import math
 
 import numpy as np
 import torch
 import torch.autograd as autograd
 import torch.nn as nn
-import torch.nn.functional as F
-
-from causallearn.utils.KCI.KCI import KCI_UInd
-import torch.autograd as autograd
-import matplotlib.pyplot as plt
+from scipy import stats
 
 
 class MLP(nn.Module):
@@ -38,7 +33,7 @@ def __init__(self, n_inputs, n_outputs, n_layers=1, n_units=100):
 
         # create layers
         layers = [nn.Linear(n_inputs, n_units)]
-        for i in range(n_layers):
+        for _ in range(n_layers):
             layers.append(nn.ReLU())
             layers.append(nn.Linear(n_units, n_units))
         layers.append(nn.ReLU())
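
The loop above stacks n_layers ReLU/Linear blocks inside the MLP. A quick shape check of the configuration that nica_mnd uses further down; this snippet is illustrative and not part of the commit:

import torch
from causallearn.search.FCMBased.PNL.PNL import MLP

# Build the G1 configuration used later in nica_mnd and push a batch through
net = MLP(1, 1, n_layers=3, n_units=12)
out = net(torch.randn(8, 1))
print(out.shape)  # torch.Size([8, 1]): one scalar output per sample
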
@@ -49,62 +44,52 @@ def forward(self, x):
         x = self.layers(x)
         return x
 
-
-class MixGaussianLayer(nn.Module):
-    def __init__(self, Mix_K=3):
-        super(MixGaussianLayer, self).__init__()
-        self.Mix_K = Mix_K
-        self.Pi = nn.Parameter(torch.randn(self.Mix_K, 1))
-        self.Mu = nn.Parameter(torch.randn(self.Mix_K, 1))
-        self.Var = nn.Parameter(torch.randn(self.Mix_K, 1))
-
-    def forward(self, x):
-        Constraint_Pi = F.softmax(self.Pi, 0)
-        # -(x-u)**2/(2var**2)
-        Middle1 = -((x.expand(len(x), self.Mix_K) - self.Mu.T.expand(len(x), self.Mix_K)).pow(2)).div(
-            2 * (self.Var.T.expand(len(x), self.Mix_K)).pow(2))
-        # sum Pi*Middle/var
-        Middle2 = torch.exp(Middle1).mm(Constraint_Pi.div(torch.sqrt(2 * math.pi * self.Var.pow(2))))
-        # log sum
-        out = sum(torch.log(Middle2))
-
-        return out
-
-
 class PNL(object):
     """
     Use of constrained nonlinear ICA for distinguishing cause from effect.
     Python Version 3.7
     PURPOSE:
           To find which one of xi (i=1,2) is the cause. In particular, this
          function does
-          1) preprocessing to make xi rather clear to Gaussian,
+          1) preprocessing to make xi rather close to Gaussian,
           2) learn the corresponding 'disturbance' under each assumed causal
          direction, and
           3) performs the independence tests to see if the assumed cause is
          independent from the learned disturbance.
     """
 
-    def __init__(self, kernelX='Gaussian', kernelY='Gaussian', mix_K=3, epochs=100000):
+    def __init__(self, epochs=100000):
         '''
-        Construct the ANM model.
+        Construct the PNL model.
 
         Parameters:
         ----------
-        kernelX: kernel function for hypothetical cause
-        kernelY: kernel function for estimated noise
-        mix_K: number of Gaussian mixtures for independent components
         epochs: training epochs.
         '''
-        self.kernelX = kernelX
-        self.kernelY = kernelY
-        self.mix_K = mix_K
+
         self.epochs = epochs
+
+    def dele_abnormal(self, data_x, data_y):
+
+        mean_x = np.mean(data_x)
+        sigma_x = np.std(data_x)
+        remove_idx_x = np.where(abs(data_x - mean_x) > 3*sigma_x)[0]
+
+        mean_y = np.mean(data_y)
+        sigma_y = np.std(data_y)
+        remove_idx_y = np.where(abs(data_y - mean_y) > 3*sigma_y)[0]
+
+        remove_idx = np.append(remove_idx_x, remove_idx_y)
+
+        data_x = np.delete(data_x, remove_idx)
+        data_y = np.delete(data_y, remove_idx)
 
-    def nica_mnd(self, X, TotalEpoch, KofMix):
+        return data_x.reshape(len(data_x), 1), data_y.reshape(len(data_y), 1)
+
+    def nica_mnd(self, X, TotalEpoch):
         """
-        Use of "Nonlinear ICA with MND for Matlab" for distinguishing cause from effect
-        PURPOSE: Performing nonlinear ICA with the minimal nonlinear distortion or smoothness regularization.
+        Use of "Nonlinear ICA" for distinguishing cause from effect
+        PURPOSE: Performing nonlinear ICA.
 
         Parameters
         ----------
@@ -115,56 +100,54 @@ def nica_mnd(self, X, TotalEpoch):
         Y (n*T): the separation result.
         """
         trpattern = X.T
-        trpattern = trpattern - np.tile(np.mean(trpattern, axis=1).reshape(2, 1), (1, len(trpattern[0])))
-        trpattern = np.dot(np.diag(1.5 / np.std(trpattern, axis=1)), trpattern)
+
         # --------------------------------------------------------
         x1 = torch.from_numpy(trpattern[0, :]).type(torch.FloatTensor).reshape(-1, 1)
         x2 = torch.from_numpy(trpattern[1, :]).type(torch.FloatTensor).reshape(-1, 1)
         x1.requires_grad = True
         x2.requires_grad = True
+
         y1 = x1
 
         Final_y2 = x2
         Min_loss = float('inf')
 
-        G1 = MLP(1, 1, n_layers=1, n_units=20)
-        G2 = MLP(1, 1, n_layers=1, n_units=20)
-        # MixGaussian = MixGaussianLayer(Mix_K=KofMix)
-        G3 = MLP(1, 1, n_layers=1, n_units=20)
+        G1 = MLP(1, 1, n_layers=3, n_units=12)
+        G2 = MLP(1, 1, n_layers=1, n_units=12)
         optimizer = torch.optim.Adam([
             {'params': G1.parameters()},
-            {'params': G2.parameters()},
-            {'params': G3.parameters()}], lr=1e-4, betas=(0.9, 0.99))
+            {'params': G2.parameters()}], lr=1e-5, betas=(0.9, 0.99))
+
+        loss_all = torch.zeros(TotalEpoch)
+        loss_pdf_all = torch.zeros(TotalEpoch)
+        loss_jacob_all = torch.zeros(TotalEpoch)
 
         for epoch in range(TotalEpoch):
+            G1.zero_grad()
+            G2.zero_grad()
 
             y2 = G2(x2) - G1(x1)
-            # y2 = x2 - G1(x1)
 
-            loss_pdf = torch.sum((y2)**2)
+            loss_pdf = 0.5 * torch.sum(y2**2)
 
-            jacob = autograd.grad(outputs=G2(x2), inputs=x2, grad_outputs=torch.ones(y2.shape), create_graph=True,
+            jacob = autograd.grad(outputs=y2, inputs=x2, grad_outputs=torch.ones(y2.shape), create_graph=True,
                                   retain_graph=True, only_inputs=True)[0]
+
             loss_jacob = - torch.sum(torch.log(torch.abs(jacob) + 1e-16))
 
             loss = loss_jacob + loss_pdf
 
+            loss_all[epoch] = loss
+            loss_pdf_all[epoch] = loss_pdf
+            loss_jacob_all[epoch] = loss_jacob
+
             if loss < Min_loss:
                 Min_loss = loss
                 Final_y2 = y2
-
-            if epoch % 100 == 0:
-                print('---------------------------{}-th epoch-------------------------------'.format(epoch))
-                print('The Total loss is {}'.format(loss))
-                print('The jacob loss is {}'.format(loss_jacob))
-                print('The pdf loss is {}'.format(loss_pdf))
-
-            optimizer.zero_grad()
-            loss.backward(retain_graph=True)
+
+            loss.backward()
             optimizer.step()
-        plt.plot(x1.detach().numpy(), G1(x1).detach().numpy(), '.')
-        plt.plot(x2.detach().numpy(), G2(x2).detach().numpy(),'.')
-        plt.show()
+
         return y1, Final_y2
 
     def cause_or_effect(self, data_x, data_y):
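
The training objective above reads as a negative log-likelihood under the post-nonlinear model: 0.5 * sum(y2**2) is the standard-Gaussian log-density of the recovered disturbance (up to an additive constant), and -sum(log|dy2/dx2|) is the change-of-variables correction, obtained per sample through autograd.grad. A self-contained check of that autograd pattern on a toy function, illustrative only:

import torch
import torch.autograd as autograd

x = torch.linspace(-2.0, 2.0, 5).reshape(-1, 1)
x.requires_grad = True
y = x ** 3  # stand-in for y2 = G2(x2) - G1(x1)

# Same call as in nica_mnd: a ones vector recovers the per-sample derivative
jacob = autograd.grad(outputs=y, inputs=x, grad_outputs=torch.ones(y.shape),
                      create_graph=True, retain_graph=True, only_inputs=True)[0]
print(torch.allclose(jacob, 3 * x ** 2))  # True: matches the analytic dy/dx
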
@@ -181,28 +164,28 @@ def cause_or_effect(self, data_x, data_y):
         pval_forward: p value in the x->y direction
         pval_backward: p value in the y->x direction
         '''
+        torch.manual_seed(0)
 
-        raise SyntaxError('There are some potential issues in the current implementation of PNL. We are working on them and will update as soon as possible.')
-
-        kci = KCI_UInd(self.kernelX, self.kernelY)
+        # Delete the abnormal samples
+        data_x, data_y = self.dele_abnormal(data_x, data_y)
 
         # Now let's see if x1 -> x2 is plausible
         data = np.concatenate((data_x, data_y), axis=1)
-        y1, y2 = self.nica_mnd(data, self.epochs, self.mix_K)
-        print('To see if x1 -> x2...')
+        # print('To see if x1 -> x2...')
+        y1, y2 = self.nica_mnd(data, self.epochs)
 
         y1_np = y1.detach().numpy()
         y2_np = y2.detach().numpy()
 
-        pval_foward, _ = kci.compute_pvalue(y1_np, y2_np)
+        _, pval_forward = stats.ttest_ind(y1_np, y2_np)
 
         # Now let's see if x2 -> x1 is plausible
-        y1, y2 = self.nica_mnd(data[:, [1, 0]], self.epochs, self.mix_K)
-        print('To see if x2 -> x1...')
-
+        # print('To see if x2 -> x1...')
+        y1, y2 = self.nica_mnd(data[:, [1, 0]], self.epochs)
+
         y1_np = y1.detach().numpy()
         y2_np = y2.detach().numpy()
 
-        pval_backward, _ = kci.compute_pvalue(y1_np, y2_np)
-
-        return pval_foward, pval_backward
+        _, pval_backward = stats.ttest_ind(y1_np, y2_np)
+
+        return np.round(pval_forward, 3), np.round(pval_backward, 3)
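
End to end, the updated class would be driven roughly as follows: a sketch assuming the import path from the file header, with synthetic post-nonlinear data y = g(f(x) + e) generated here purely for illustration, and far fewer epochs than the 100000 default so the run stays short:

import numpy as np
from causallearn.search.FCMBased.PNL.PNL import PNL

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(500, 1))
e = rng.uniform(-0.2, 0.2, size=(500, 1))
y = np.tanh(2.0 * x + e)  # invertible outer nonlinearity, as PNL assumes

pnl = PNL(epochs=3000)  # default is 100000; reduced only for this demo
pval_forward, pval_backward = pnl.cause_or_effect(x, y)
print(pval_forward, pval_backward)  # p-values for x->y and y->x respectively
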
