"""
Python implementation of the LiNGAM algorithms.
 * Slightly modified for speed, 04/26/2022.
The LiNGAM Project: https://sites.google.com/site/sshimizu06/lingam
"""
6+ import time
57
68import numpy as np
79from scipy .stats import gamma
@@ -34,16 +36,12 @@ def get_kernel_width(X):
3436 X_med = X
3537
3638 G = np .sum (X_med * X_med , 1 ).reshape (n_samples , 1 )
37- Q = np .tile (G , (1 , n_samples ))
38- R = np .tile (G .T , (n_samples , 1 ))
39-
40- dists = Q + R - 2 * np .dot (X_med , X_med .T )
39+ dists = G + G .T - 2 * np .dot (X_med , X_med .T )
4140 dists = dists - np .tril (dists )
4241 dists = dists .reshape (n_samples ** 2 , 1 )
4342
4443 return np .sqrt (0.5 * np .median (dists [dists > 0 ]))
4544
46-
4745def _rbf_dot (X , Y , width ):
4846 """Compute the inner product of radial basis functions."""
4947 n_samples_X = X .shape [0 ]
@@ -57,6 +55,11 @@ def _rbf_dot(X, Y, width):
5755
5856 return np .exp (- H / 2 / (width ** 2 ))
5957
def _rbf_dot_XX(X, width):
    """Compute the RBF kernel matrix of X with itself.

    Special case of the RBF inner product where both arguments are the same
    sample matrix, so the squared row norms only need to be computed once.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Sample matrix; the squared norms are taken along axis 1.
    width : float
        Kernel width (bandwidth) of the RBF kernel.

    Returns
    -------
    K : ndarray, shape (n_samples, n_samples)
        ``exp(-||x_i - x_j||^2 / (2 * width^2))`` for every pair of rows.
    """
    sq_norms = np.sum(X * X, axis=1)
    # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
    sq_dists = sq_norms[None, :] + sq_norms[:, None] - 2 * np.dot(X, X.T)
    # Floating-point cancellation can leave tiny negative entries, which
    # would yield kernel values slightly above 1; squared distances are
    # non-negative by definition, so clamp them at zero.
    np.maximum(sq_dists, 0, out=sq_dists)
    return np.exp(-sq_dists / 2 / (width ** 2))
6063
def get_gram_matrix(X, width):
    """Get the RBF gram matrix of X and its double-centered version.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Sample matrix; ``X.shape[0]`` is taken as the number of samples.
    width : float
        Kernel width passed through to the RBF kernel.

    Returns
    -------
    K : ndarray, shape (n_samples, n_samples)
        The (uncentered) gram matrix.
    Kc : ndarray, shape (n_samples, n_samples)
        The centered gram matrix.
    """
    n = X.shape[0]

    K = _rbf_dot_XX(X, width)

    # Double-centering, equivalent to H @ K @ H with
    # H = np.eye(n) - 1 / n * np.ones((n, n)), but expressed through row,
    # column and grand sums so the two dense n x n matrix products are
    # avoided.
    col_sums = K.sum(axis=0)
    row_sums = K.sum(axis=1)
    grand_sum = row_sums.sum()
    # The grand-mean term is a scalar; broadcasting it avoids allocating an
    # extra n x n array of ones.
    Kc = K - (col_sums[None, :] + row_sums[:, None]) / n + grand_sum / n ** 2
    return K, Kc
8590
8691
@@ -101,7 +106,7 @@ def hsic_teststat(Kc, Lc, n):
101106 the HSIC statistic.
102107 """
103108 # test statistic m*HSICb under H1
104- return 1 / n * np .sum (np . sum ( Kc .T * Lc ) )
109+ return 1 / n * np .sum (Kc .T * Lc )
105110
106111
107112def hsic_test_gamma (X , Y , bw_method = 'mdbs' ):
@@ -148,25 +153,47 @@ def hsic_test_gamma(X, Y, bw_method='mdbs'):
148153
149154 # test statistic m*HSICb under H1
150155 n = X .shape [0 ]
151- bone = np .ones ((n , 1 ))
152156 test_stat = hsic_teststat (Kc , Lc , n )
153157
154158 var = (1 / 6 * Kc * Lc ) ** 2
155159 # second subtracted term is bias correction
156- var = 1 / n / (n - 1 ) * (np .sum (np . sum ( var )) - np .sum ( np . diag ( var ) ))
160+ var = 1 / n / (n - 1 ) * (np .sum (var ) - np .trace ( var ))
157161 # variance under H0
158162 var = 72 * (n - 4 ) * (n - 5 ) / n / (n - 1 ) / (n - 2 ) / (n - 3 ) * var
159163
160- K = K - np .diag ( np . diag ( K ))
161- L = L - np .diag ( np . diag ( L ))
162- mu_X = 1 / n / (n - 1 ) * np . dot ( bone . T , np . dot ( K , bone ) )
163- mu_Y = 1 / n / (n - 1 ) * np . dot ( bone . T , np . dot ( L , bone ) )
164+ K [ np .diag_indices ( n )] = 0
165+ L [ np .diag_indices ( n )] = 0
166+ mu_X = 1 / n / (n - 1 ) * K . sum ( )
167+ mu_Y = 1 / n / (n - 1 ) * L . sum ( )
164168 # mean under H0
165169 mean = 1 / n * (1 + mu_X * mu_Y - mu_X - mu_Y )
166170
167171 alpha = mean ** 2 / var
168172 # threshold for hsicArr*m
169- beta = np . dot ( var , n ) / mean
170- p = 1 - gamma .cdf (test_stat , alpha , scale = beta )[ 0 ][ 0 ]
173+ beta = var * n / mean
174+ p = 1 - gamma .cdf (test_stat , alpha , scale = beta )
171175
172176 return test_stat , p
177+
178+
if __name__ == '__main__':
    # Ad-hoc benchmark and equivalence check: time this module's
    # hsic_test_gamma against the implementation shipped in causallearn on
    # the same random sample, then verify both return the same results.
    X = np.random.uniform(0, 1, (15000,))
    Y = X ** 2 + np.random.uniform(0, 1, (15000,))

    # perf_counter is the clock intended for measuring elapsed intervals.
    tic = time.perf_counter()
    test_stat, p = hsic_test_gamma(X, Y)
    print(f'now used: {time.perf_counter() - tic:.5f}s')

    # NOTE(review): this compares against whichever causallearn version is
    # installed; if that package already contains this implementation the
    # comparison is vacuous -- confirm the reference is the pre-change code.
    from causallearn.search.FCMBased.lingam.hsic import hsic_test_gamma as hsic_test_gamma_old
    tic = time.perf_counter()
    test_stat_old, p_old = hsic_test_gamma_old(X, Y)
    print(f'originally used: {time.perf_counter() - tic:.5f}s')

    # Explicit checks instead of bare `assert`, so the comparison still runs
    # under `python -O` and "passed" is never printed unchecked.
    if not np.isclose(test_stat, test_stat_old):
        raise AssertionError(
            f'test statistic mismatch: {test_stat} vs {test_stat_old}')
    if not np.isclose(p, p_old):
        raise AssertionError(f'p-value mismatch: {p} vs {p_old}')
    print('equivalent test passed.')

    # Output recorded from one run:
    #   now used: 6.78904s
    #   originally used: 65.28648s
    #   equivalent test passed.
0 commit comments