Skip to content

Commit 2a4546b

Browse files
committed
2 parents b5dd1d3 + b3fab0c commit 2a4546b

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

causallearn/search/FCMBased/lingam/hsic.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""
22
Python implementation of the LiNGAM algorithms.
3+
* some slight modification for speedup, 04/26/2022
34
The LiNGAM Project: https://sites.google.com/site/sshimizu06/lingam
45
"""
6+
import time
57

68
import numpy as np
79
from scipy.stats import gamma
@@ -34,16 +36,12 @@ def get_kernel_width(X):
3436
X_med = X
3537

3638
G = np.sum(X_med * X_med, 1).reshape(n_samples, 1)
37-
Q = np.tile(G, (1, n_samples))
38-
R = np.tile(G.T, (n_samples, 1))
39-
40-
dists = Q + R - 2 * np.dot(X_med, X_med.T)
39+
dists = G + G.T - 2 * np.dot(X_med, X_med.T)
4140
dists = dists - np.tril(dists)
4241
dists = dists.reshape(n_samples ** 2, 1)
4342

4443
return np.sqrt(0.5 * np.median(dists[dists > 0]))
4544

46-
4745
def _rbf_dot(X, Y, width):
4846
"""Compute the inner product of radial basis functions."""
4947
n_samples_X = X.shape[0]
@@ -57,6 +55,11 @@ def _rbf_dot(X, Y, width):
5755

5856
return np.exp(-H / 2 / (width ** 2))
5957

58+
def _rbf_dot_XX(X, width):
59+
"""rbf dot, in special case with X dot X"""
60+
G = np.sum(X * X, axis=1)
61+
H = G[None, :] + G[:, None] - 2 * np.dot(X, X.T)
62+
return np.exp(-H / 2 / (width ** 2))
6063

6164
def get_gram_matrix(X, width):
6265
"""Get the centered gram matrices.
@@ -76,11 +79,13 @@ def get_gram_matrix(X, width):
7679
the centered gram matrices.
7780
"""
7881
n = X.shape[0]
79-
H = np.eye(n) - 1 / n * np.ones((n, n))
80-
81-
K = _rbf_dot(X, X, width)
82-
Kc = np.dot(np.dot(H, K), H)
8382

83+
K = _rbf_dot_XX(X, width)
84+
K_colsums = K.sum(axis=0)
85+
K_rowsums = K.sum(axis=1)
86+
K_allsum = K_rowsums.sum()
87+
Kc = K - (K_colsums[None, :] + K_rowsums[:, None]) / n + np.ones((n, n)) * (K_allsum / n ** 2)
88+
# equivalent to H @ K @ H, where H = np.eye(n) - 1 / n * np.ones((n, n)).
8489
return K, Kc
8590

8691

@@ -101,7 +106,7 @@ def hsic_teststat(Kc, Lc, n):
101106
the HSIC statistic.
102107
"""
103108
# test statistic m*HSICb under H1
104-
return 1 / n * np.sum(np.sum(Kc.T * Lc))
109+
return 1 / n * np.sum(Kc.T * Lc)
105110

106111

107112
def hsic_test_gamma(X, Y, bw_method='mdbs'):
@@ -148,25 +153,47 @@ def hsic_test_gamma(X, Y, bw_method='mdbs'):
148153

149154
# test statistic m*HSICb under H1
150155
n = X.shape[0]
151-
bone = np.ones((n, 1))
152156
test_stat = hsic_teststat(Kc, Lc, n)
153157

154158
var = (1 / 6 * Kc * Lc) ** 2
155159
# second subtracted term is bias correction
156-
var = 1 / n / (n - 1) * (np.sum(np.sum(var)) - np.sum(np.diag(var)))
160+
var = 1 / n / (n - 1) * (np.sum(var) - np.trace(var))
157161
# variance under H0
158162
var = 72 * (n - 4) * (n - 5) / n / (n - 1) / (n - 2) / (n - 3) * var
159163

160-
K = K - np.diag(np.diag(K))
161-
L = L - np.diag(np.diag(L))
162-
mu_X = 1 / n / (n - 1) * np.dot(bone.T, np.dot(K, bone))
163-
mu_Y = 1 / n / (n - 1) * np.dot(bone.T, np.dot(L, bone))
164+
K[np.diag_indices(n)] = 0
165+
L[np.diag_indices(n)] = 0
166+
mu_X = 1 / n / (n - 1) * K.sum()
167+
mu_Y = 1 / n / (n - 1) * L.sum()
164168
# mean under H0
165169
mean = 1 / n * (1 + mu_X * mu_Y - mu_X - mu_Y)
166170

167171
alpha = mean ** 2 / var
168172
# threshold for hsicArr*m
169-
beta = np.dot(var, n) / mean
170-
p = 1 - gamma.cdf(test_stat, alpha, scale=beta)[0][0]
173+
beta = var * n / mean
174+
p = 1 - gamma.cdf(test_stat, alpha, scale=beta)
171175

172176
return test_stat, p
177+
178+
179+
if __name__ == '__main__':
180+
X = np.random.uniform(0, 1, (15000,))
181+
Y = X ** 2 + np.random.uniform(0, 1, (15000,))
182+
tic = time.time()
183+
test_stat, p = hsic_test_gamma(X, Y)
184+
print(f'now used: {time.time() - tic: .5f}s')
185+
186+
from causallearn.search.FCMBased.lingam.hsic import hsic_test_gamma as hsic_test_gamma_old
187+
tic = time.time()
188+
test_stat_old, p_old = hsic_test_gamma_old(X, Y)
189+
print(f'originally used: {time.time() - tic: .5f}s')
190+
191+
assert np.isclose(test_stat, test_stat_old)
192+
assert np.isclose(p, p_old)
193+
print('equivalent test passed.')
194+
195+
'''
196+
now used: 6.78904s
197+
originally used: 65.28648s
198+
equivalent test passed.
199+
'''

causallearn/search/HiddenCausal/GIN/GIN.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@
2323

2424
def fisher_test(pvals):
2525
pvals = [pval if pval >= 1e-5 else 1e-5 for pval in pvals]
26-
return min(pvals)
27-
# fisher_stat = -2.0 * np.sum(np.log(pvals))
28-
# return 1 - chi2.cdf(fisher_stat, 2 * len(pvals))
26+
fisher_stat = -2.0 * np.sum(np.log(pvals))
27+
return 1 - chi2.cdf(fisher_stat, 2 * len(pvals))
2928

3029

3130
def GIN(data, indep_test=kci, alpha=0.05):

0 commit comments

Comments
 (0)