import torch.autograd as autograd
import torch.nn as nn
from scipy import stats
+from torch.utils.data import Dataset, DataLoader
+
+class PairDataset(Dataset):
+
+    def __init__(self, data):
+        super(PairDataset, self).__init__()
+        self.data = data
+        self.num_data = data.shape[0]
+
+    def __len__(self):
+        return self.num_data
+
+    def __getitem__(self, index):
+        return self.data[index, :]
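For context, a minimal sketch of how the new `PairDataset` is meant to be consumed; the array shape and batch size are illustrative, mirroring the loader that `nica_mnd` builds below:

```python
# Illustrative only; not part of this commit.
import numpy as np
from torch.utils.data import DataLoader

data = np.random.randn(1000, 2).astype(np.float32)  # n samples of (x1, x2)
loader = DataLoader(PairDataset(data), batch_size=128, drop_last=True)

for batch in loader:
    # batch has shape (128, 2); column 0 is x1, column 1 is x2
    x1, x2 = batch[:, 0].reshape(-1, 1), batch[:, 1].reshape(-1, 1)
```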


class MLP(nn.Module):
@@ -52,7 +66,7 @@ class PNL(object):
    independent from the learned disturbance.
    """

-    def __init__(self, epochs=100000):
+    def __init__(self, epochs=3000):
        '''
        Construct the PNL model.

@@ -62,23 +76,6 @@ def __init__(self, epochs=100000):
        '''

        self.epochs = epochs
-
-    def dele_abnormal(self, data_x, data_y):
-
-        mean_x = np.mean(data_x)
-        sigma_x = np.std(data_x)
-        remove_idx_x = np.where(abs(data_x - mean_x) > 3 * sigma_x)[0]
-
-        mean_y = np.mean(data_y)
-        sigma_y = np.std(data_y)
-        remove_idx_y = np.where(abs(data_y - mean_y) > 3 * sigma_y)[0]
-
-        remove_idx = np.append(remove_idx_x, remove_idx_y)
-
-        data_x = np.delete(data_x, remove_idx)
-        data_y = np.delete(data_y, remove_idx)
-
-        return data_x.reshape(len(data_x), 1), data_y.reshape(len(data_y), 1)

    def nica_mnd(self, X, TotalEpoch):
        """
@@ -93,56 +90,41 @@ def nica_mnd(self, X, TotalEpoch):
        ---------
        Y (n*T): the separation result.
        """
-        trpattern = X.T
+        X = X.astype(np.float32)

-        # --------------------------------------------------------
-        x1 = torch.from_numpy(trpattern[0, :]).type(torch.FloatTensor).reshape(-1, 1)
-        x2 = torch.from_numpy(trpattern[1, :]).type(torch.FloatTensor).reshape(-1, 1)
-        x1.requires_grad = True
-        x2.requires_grad = True
-
-        y1 = x1
-
-        Final_y2 = x2
-        Min_loss = float('inf')
+        train_dataset = PairDataset(X)
+        train_loader = DataLoader(train_dataset, batch_size=128, drop_last=True)

        G1 = MLP(1, 1, n_layers=3, n_units=12)
        G2 = MLP(1, 1, n_layers=1, n_units=12)
        optimizer = torch.optim.Adam([
            {'params': G1.parameters()},
-            {'params': G2.parameters()}], lr=1e-5, betas=(0.9, 0.99))
-
-        loss_all = torch.zeros(TotalEpoch)
-        loss_pdf_all = torch.zeros(TotalEpoch)
-        loss_jacob_all = torch.zeros(TotalEpoch)
-
-        for epoch in range(TotalEpoch):
-            G1.zero_grad()
-            G2.zero_grad()
-
-            y2 = G2(x2) - G1(x1)
-
-            loss_pdf = 0.5 * torch.sum(y2 ** 2)
-
-            jacob = autograd.grad(outputs=y2, inputs=x2, grad_outputs=torch.ones(y2.shape), create_graph=True,
-                                  retain_graph=True, only_inputs=True)[0]
-
-            loss_jacob = - torch.sum(torch.log(torch.abs(jacob) + 1e-16))
-
-            loss = loss_jacob + loss_pdf
-
-            loss_all[epoch] = loss
-            loss_pdf_all[epoch] = loss_pdf
-            loss_jacob_all[epoch] = loss_jacob
-
-            if loss < Min_loss:
-                Min_loss = loss
-                Final_y2 = y2
-
-            loss.backward()
-            optimizer.step()
-
-        return y1, Final_y2
+            {'params': G2.parameters()}], lr=1e-4, betas=(0.9, 0.99))
+
+        for _ in range(TotalEpoch):
+            for x_batch in train_loader:
+                # Reset gradients once per batch (not once per epoch), so each
+                # step uses only the current batch's gradients.
+                optimizer.zero_grad()
+
+                x1, x2 = x_batch[:, 0].reshape(-1, 1), x_batch[:, 1].reshape(-1, 1)
+                x1.requires_grad = True
+                x2.requires_grad = True
+                y2 = G2(x2) - G1(x1)
+
+                # Negative log-likelihood of a standard Gaussian disturbance...
+                loss_pdf = 0.5 * torch.sum(y2 ** 2)
+
+                # ...plus the log-Jacobian correction for the transform of x2.
+                jacob = autograd.grad(outputs=y2, inputs=x2, grad_outputs=torch.ones(y2.shape),
+                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
+                loss_jacob = -torch.sum(torch.log(torch.abs(jacob) + 1e-16))
+
+                loss = loss_jacob + loss_pdf
+
+                loss.backward()
+                optimizer.step()
+
+        # Evaluate the learned disturbance on the full sample.
+        X1_all = torch.tensor(X[:, 0].reshape(-1, 1))
+        X2_all = torch.tensor(X[:, 1].reshape(-1, 1))
+        Final_y2 = G2(X2_all) - G1(X1_all)
+
+        return X1_all, Final_y2

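For reference, the per-batch objective above is the negative log-likelihood of the post-nonlinear model under a standard Gaussian disturbance: up to constants, -log p(x2 | x1) = 0.5 * y2**2 - log|d y2 / d x2|. A self-contained sketch of the Jacobian term follows; the `tanh` here is a stand-in for `G2(x2) - G1(x1)`, assumed purely for illustration:

```python
# Standalone sketch of the log-Jacobian term; tanh stands in for G2(x2) - G1(x1).
import torch
import torch.autograd as autograd

x2 = torch.randn(128, 1, requires_grad=True)
y2 = torch.tanh(x2)

# Each y2[i] depends only on x2[i], so grad with grad_outputs=ones yields the
# elementwise derivative d y2[i] / d x2[i].
jacob = autograd.grad(outputs=y2, inputs=x2, grad_outputs=torch.ones_like(y2),
                      create_graph=True)[0]
loss = 0.5 * torch.sum(y2 ** 2) - torch.sum(torch.log(torch.abs(jacob) + 1e-16))
loss.backward()  # create_graph=True lets gradients flow through the Jacobian term
```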
    def cause_or_effect(self, data_x, data_y):
        '''
@@ -159,10 +141,6 @@ def cause_or_effect(self, data_x, data_y):
        pval_backward: p value in the y->x direction
        '''
        torch.manual_seed(0)
-
-        # Delete the abnormal samples
-        data_x, data_y = self.dele_abnormal(data_x, data_y)
-
        # Now let's see if x1 -> x2 is plausible
        data = np.concatenate((data_x, data_y), axis=1)
        # print('To see if x1 -> x2...')
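With `dele_abnormal` removed, `cause_or_effect` now consumes the raw columns directly. A hedged end-to-end sketch, assuming the class is importable as in causal-learn (the import path and the data-generating mechanism below are illustrative, not part of this commit):

```python
# Hypothetical usage; import path and data are assumptions for illustration.
import numpy as np
# from causallearn.search.FCMBased.PNL.PNL import PNL

np.random.seed(0)
x = np.random.randn(500, 1)
y = (x + 0.5 * np.random.randn(500, 1)) ** 3  # post-nonlinear mechanism

pnl = PNL(epochs=3000)
pval_forward, pval_backward = pnl.cause_or_effect(x, y)
# Independence of cause and disturbance should be easier to accept in the
# true direction, i.e. a larger pval_forward supports x -> y.
```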