@@ -3,17 +3,12 @@
 
 BASE_DIR = os.path.join(os.path.dirname(__file__), '..')
 sys.path.append(BASE_DIR)
-import math
 
 import numpy as np
 import torch
 import torch.autograd as autograd
 import torch.nn as nn
-import torch.nn.functional as F
-
-from causallearn.utils.KCI.KCI import KCI_UInd
-import torch.autograd as autograd
-import matplotlib.pyplot as plt
+from scipy import stats
 
 
 class MLP(nn.Module):
@@ -38,7 +33,7 @@ def __init__(self, n_inputs, n_outputs, n_layers=1, n_units=100):
 
         # create layers
         layers = [nn.Linear(n_inputs, n_units)]
-        for i in range(n_layers):
+        for _ in range(n_layers):
             layers.append(nn.ReLU())
             layers.append(nn.Linear(n_units, n_units))
         layers.append(nn.ReLU())
@@ -49,62 +44,52 @@ def forward(self, x):
         x = self.layers(x)
         return x
 
-
-class MixGaussianLayer(nn.Module):
-    def __init__(self, Mix_K=3):
-        super(MixGaussianLayer, self).__init__()
-        self.Mix_K = Mix_K
-        self.Pi = nn.Parameter(torch.randn(self.Mix_K, 1))
-        self.Mu = nn.Parameter(torch.randn(self.Mix_K, 1))
-        self.Var = nn.Parameter(torch.randn(self.Mix_K, 1))
-
-    def forward(self, x):
-        Constraint_Pi = F.softmax(self.Pi, 0)
-        # -(x-u)**2/(2var**2)
-        Middle1 = -((x.expand(len(x), self.Mix_K) - self.Mu.T.expand(len(x), self.Mix_K)).pow(2)).div(
-            2 * (self.Var.T.expand(len(x), self.Mix_K)).pow(2))
-        # sum Pi*Middle/var
-        Middle2 = torch.exp(Middle1).mm(Constraint_Pi.div(torch.sqrt(2 * math.pi * self.Var.pow(2))))
-        # log sum
-        out = sum(torch.log(Middle2))
-
-        return out
-
-
 class PNL(object):
     """
     Use of constrained nonlinear ICA for distinguishing cause from effect.
     Python Version 3.7
     PURPOSE:
           To find which one of xi (i=1,2) is the cause. In particular, this
           function does
-            1) preprocessing to make xi rather clear to Gaussian,
+            1) preprocessing to make xi rather close to Gaussian,
             2) learn the corresponding 'disturbance' under each assumed causal
             direction, and
             3) performs the independence tests to see if the assumed cause is
             independent from the learned disturbance.
     """
 
-    def __init__(self, kernelX='Gaussian', kernelY='Gaussian', mix_K=3, epochs=100000):
+    def __init__(self, epochs=100000):
         '''
-        Construct the ANM model.
+        Construct the PNL model.
 
         Parameters:
         ----------
-        kernelX: kernel function for hypothetical cause
-        kernelY: kernel function for estimated noise
-        mix_K: number of Gaussian mixtures for independent components
         epochs: training epochs.
         '''
-        self.kernelX = kernelX
-        self.kernelY = kernelY
-        self.mix_K = mix_K
+
         self.epochs = epochs
+
+    def dele_abnormal(self, data_x, data_y):
+
+        mean_x = np.mean(data_x)
+        sigma_x = np.std(data_x)
+        remove_idx_x = np.where(abs(data_x - mean_x) > 3 * sigma_x)[0]
+
+        mean_y = np.mean(data_y)
+        sigma_y = np.std(data_y)
+        remove_idx_y = np.where(abs(data_y - mean_y) > 3 * sigma_y)[0]
+
+        remove_idx = np.append(remove_idx_x, remove_idx_y)
+
+        data_x = np.delete(data_x, remove_idx)
+        data_y = np.delete(data_y, remove_idx)
 
-    def nica_mnd(self, X, TotalEpoch, KofMix):
+        return data_x.reshape(len(data_x), 1), data_y.reshape(len(data_y), 1)
+
+    def nica_mnd(self, X, TotalEpoch):
         """
-        Use of "Nonlinear ICA with MND for Matlab" for distinguishing cause from effect
-        PURPOSE: Performing nonlinear ICA with the minimal nonlinear distortion or smoothness regularization.
+        Use of "Nonlinear ICA" for distinguishing cause from effect
+        PURPOSE: Performing nonlinear ICA.
 
         Parameters
         ----------
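
A note on the dele_abnormal helper introduced above: it is a plain 3-sigma outlier filter. Any sample in which either variable deviates from its mean by more than three standard deviations is dropped from both arrays, with the removal indices pooled so the (x, y) pairing stays aligned. A minimal equivalent sketch (not part of the commit; the helper name and the column-vector assumption are mine):

import numpy as np

def three_sigma_filter(x, y):
    # Keep only rows where BOTH variables lie within 3 standard
    # deviations of their mean, so (x_i, y_i) pairs stay aligned.
    keep = ((np.abs(x - x.mean()) <= 3 * x.std()) &
            (np.abs(y - y.mean()) <= 3 * y.std())).ravel()
    return x[keep].reshape(-1, 1), y[keep].reshape(-1, 1)
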
@@ -115,56 +100,54 @@ def nica_mnd(self, X, TotalEpoch, KofMix):
         Y (n*T): the separation result.
         """
         trpattern = X.T
-        trpattern = trpattern - np.tile(np.mean(trpattern, axis=1).reshape(2, 1), (1, len(trpattern[0])))
-        trpattern = np.dot(np.diag(1.5 / np.std(trpattern, axis=1)), trpattern)
+
         # --------------------------------------------------------
         x1 = torch.from_numpy(trpattern[0, :]).type(torch.FloatTensor).reshape(-1, 1)
         x2 = torch.from_numpy(trpattern[1, :]).type(torch.FloatTensor).reshape(-1, 1)
         x1.requires_grad = True
         x2.requires_grad = True
+
         y1 = x1
 
         Final_y2 = x2
         Min_loss = float('inf')
 
-        G1 = MLP(1, 1, n_layers=1, n_units=20)
-        G2 = MLP(1, 1, n_layers=1, n_units=20)
-        # MixGaussian = MixGaussianLayer(Mix_K=KofMix)
-        G3 = MLP(1, 1, n_layers=1, n_units=20)
+        G1 = MLP(1, 1, n_layers=3, n_units=12)
+        G2 = MLP(1, 1, n_layers=1, n_units=12)
         optimizer = torch.optim.Adam([
             {'params': G1.parameters()},
-            {'params': G2.parameters()},
-            {'params': G3.parameters()}], lr=1e-4, betas=(0.9, 0.99))
+            {'params': G2.parameters()}], lr=1e-5, betas=(0.9, 0.99))
+
+        loss_all = torch.zeros(TotalEpoch)
+        loss_pdf_all = torch.zeros(TotalEpoch)
+        loss_jacob_all = torch.zeros(TotalEpoch)
 
         for epoch in range(TotalEpoch):
+            G1.zero_grad()
+            G2.zero_grad()
 
             y2 = G2(x2) - G1(x1)
-            # y2 = x2 - G1(x1)
 
-            loss_pdf = torch.sum((y2) ** 2)
+            loss_pdf = 0.5 * torch.sum(y2 ** 2)
 
-            jacob = autograd.grad(outputs=G2(x2), inputs=x2, grad_outputs=torch.ones(y2.shape), create_graph=True,
+            jacob = autograd.grad(outputs=y2, inputs=x2, grad_outputs=torch.ones(y2.shape), create_graph=True,
                                   retain_graph=True, only_inputs=True)[0]
+
             loss_jacob = -torch.sum(torch.log(torch.abs(jacob) + 1e-16))
 
             loss = loss_jacob + loss_pdf
 
+            loss_all[epoch] = loss
+            loss_pdf_all[epoch] = loss_pdf
+            loss_jacob_all[epoch] = loss_jacob
+
             if loss < Min_loss:
                 Min_loss = loss
                 Final_y2 = y2
-
-            if epoch % 100 == 0:
-                print('---------------------------{}-th epoch-------------------------------'.format(epoch))
-                print('The Total loss is {}'.format(loss))
-                print('The jacob loss is {}'.format(loss_jacob))
-                print('The pdf loss is {}'.format(loss_pdf))
-
-            optimizer.zero_grad()
-            loss.backward(retain_graph=True)
+
+            loss.backward()
             optimizer.step()
-        plt.plot(x1.detach().numpy(), G1(x1).detach().numpy(), '.')
-        plt.plot(x2.detach().numpy(), G2(x2).detach().numpy(), '.')
-        plt.show()
+
         return y1, Final_y2
 
     def cause_or_effect(self, data_x, data_y):
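
A note on the objective in the rewritten training loop above. Writing the post-nonlinear model as x2 = g(f(x1) + e), the loop estimates the disturbance as y2 = G2(x2) - G1(x1), with G2 playing the role of g^{-1} and G1 the role of f. Under a standard-Gaussian model of the disturbance, the negative log-likelihood of x2 is, up to constants, 0.5 * y2**2 - log|dy2/dx2| per sample, which is exactly loss_pdf + loss_jacob. A self-contained sketch of one loss evaluation (not part of the commit; the Linear stand-ins for the MLPs are illustrative only):

import torch
import torch.autograd as autograd

torch.manual_seed(0)
x1 = torch.randn(100, 1)
x2 = torch.randn(100, 1, requires_grad=True)
G1 = torch.nn.Linear(1, 1)  # stand-in for MLP(1, 1, n_layers=3, n_units=12)
G2 = torch.nn.Linear(1, 1)  # stand-in for MLP(1, 1, n_layers=1, n_units=12)

y2 = G2(x2) - G1(x1)                 # estimated disturbance
loss_pdf = 0.5 * torch.sum(y2 ** 2)  # Gaussian negative log-density term
jacob = autograd.grad(outputs=y2, inputs=x2,
                      grad_outputs=torch.ones_like(y2), create_graph=True)[0]
loss_jacob = -torch.sum(torch.log(torch.abs(jacob) + 1e-16))  # -sum log|dy2/dx2|
loss = loss_pdf + loss_jacob         # negative log-likelihood of x2 given x1
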
@@ -181,28 +164,28 @@ def cause_or_effect(self, data_x, data_y):
         pval_forward: p value in the x->y direction
         pval_backward: p value in the y->x direction
         '''
+        torch.manual_seed(0)
 
-        raise SyntaxError('There are some potential issues in the current implementation of PNL. We are working on them and will update as soon as possible.')
-
-        kci = KCI_UInd(self.kernelX, self.kernelY)
+        # Delete the abnormal samples
+        data_x, data_y = self.dele_abnormal(data_x, data_y)
 
         # Now let's see if x1 -> x2 is plausible
         data = np.concatenate((data_x, data_y), axis=1)
-        y1, y2 = self.nica_mnd(data, self.epochs, self.mix_K)
-        print('To see if x1 -> x2...')
+        # print('To see if x1 -> x2...')
+        y1, y2 = self.nica_mnd(data, self.epochs)
 
         y1_np = y1.detach().numpy()
         y2_np = y2.detach().numpy()
 
-        pval_foward, _ = kci.compute_pvalue(y1_np, y2_np)
+        _, pval_forward = stats.ttest_ind(y1_np, y2_np)
 
         # Now let's see if x2 -> x1 is plausible
-        y1, y2 = self.nica_mnd(data[:, [1, 0]], self.epochs, self.mix_K)
-        print('To see if x2 -> x1...')
-
+        # print('To see if x2 -> x1...')
+        y1, y2 = self.nica_mnd(data[:, [1, 0]], self.epochs)
+
         y1_np = y1.detach().numpy()
         y2_np = y2.detach().numpy()
 
-        pval_backward, _ = kci.compute_pvalue(y1_np, y2_np)
-
-        return pval_foward, pval_backward
+        _, pval_backward = stats.ttest_ind(y1_np, y2_np)
+
+        return np.round(pval_forward, 3), np.round(pval_backward, 3)
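
Finally, a usage sketch for the updated class, assuming the module stays importable as causallearn.search.FCMBased.PNL.PNL; the synthetic generator below is purely illustrative. With the t-test used above, the direction returning the larger p-value is the one the method favors.

import numpy as np

from causallearn.search.FCMBased.PNL.PNL import PNL

np.random.seed(0)

# Synthetic post-nonlinear pair: x2 = g(f(x1) + e) with invertible g = cbrt.
n = 500
x1 = np.random.uniform(-1, 1, (n, 1))
e = 0.3 * np.random.randn(n, 1)
x2 = np.cbrt(x1 + x1 ** 3 + e)

pnl = PNL(epochs=5000)  # fewer epochs than the default 100000, for speed
pval_forward, pval_backward = pnl.cause_or_effect(x1, x2)
print(pval_forward, pval_backward)  # if the fit succeeds, the x1 -> x2 p-value should be the larger one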