Commit fba553f

linting
Signed-off-by: Oliver Schacht <[email protected]>
1 parent e83e4f0 commit fba553f

File tree

1 file changed (+55, -55 lines)


causallearn/utils/FastKCI/FastKCI.py

Lines changed: 55 additions & 55 deletions
@@ -9,19 +9,21 @@
 from sklearn.gaussian_process.kernels import ConstantKernel, WhiteKernel, RBF
 import warnings
 
+
 class FastKCI_CInd(object):
     """
-    Python implementation of as speed-up version of the Kernel-based Conditional Independence (KCI) test. Unconditional version.
+    Python implementation of as speed-up version of the Kernel-based Conditional Independence (KCI) test.
+    Unconditional version.
 
     References
     ----------
     [1] K. Zhang, J. Peters, D. Janzing, and B. Schölkopf,
-    "A kernel-based conditional independence test and application in causal discovery," In UAI 2011.
+    "A kernel-based conditional independence test and application in causal discovery", In UAI 2011.
     [2] M. Zhang, and S. Williamson,
-    "Embarrassingly Parallel Inference for Gaussian Processes" In JMLR 20 (2019)
+    "Embarrassingly Parallel Inference for Gaussian Processes", In JMLR 20 (2019)
 
-    """
-    def __init__(self, K=10, J=8, alpha=500, epsilon=1e-3, eig_thresh = 1e-6, trimming_thresh = 1e-3, use_gp=False):
+    """
+    def __init__(self, K=10, J=8, alpha=500, epsilon=1e-3, eig_thresh=1e-6, trimming_thresh=1e-3, use_gp=False):
         """
         Initialize the FastKCI_CInd object.
 
@@ -33,7 +35,7 @@ def __init__(self, K=10, J=8, alpha=500, epsilon=1e-3, eig_thresh = 1e-6, trimmi
         epsilon: Penalty for the matrix ridge regression.
         eig_threshold: Threshold for Eigenvalues.
         use_gp: Whether to use Gaussian Process Regression to determine the kernel widths
-        """
+        """
         self.K = K
         self.J = J
         self.alpha = alpha
@@ -42,7 +44,6 @@ def __init__(self, K=10, J=8, alpha=500, epsilon=1e-3, eig_thresh = 1e-6, trimmi
         self.trimming_thresh = trimming_thresh
         self.use_gp = use_gp
         self.nullss = 5000
-        # TODO: Adjust to causal-learn API
 
     def compute_pvalue(self, data_x=None, data_y=None, data_z=None):
         """
@@ -67,21 +68,21 @@ def compute_pvalue(self, data_x=None, data_y=None, data_z=None):
         self.Z_proposal, self.prob_Z = zip(*Z_proposal)
         block_res = Parallel(n_jobs=-1)(delayed(self.pvalue_onblocks)(self.Z_proposal[i]) for i in range(self.J))
         test_stat, null_samples, log_likelihood = zip(*block_res)
-
+
         log_likelihood = np.array(log_likelihood)
         self.prob_Z += log_likelihood
-        self.prob_Z = np.around(np.exp(self.prob_Z-logsumexp(self.prob_Z)), 6) # experimental, not used
+        self.prob_Z = np.around(np.exp(self.prob_Z-logsumexp(self.prob_Z)), 6)  # experimental, not used
         self.all_null_samples = np.vstack(null_samples)
         self.all_p = np.array([np.sum(self.all_null_samples[i,] > test_stat[i]) / float(self.nullss) for i in range(self.J)])
         self.prob_weights = np.around(np.exp(log_likelihood-logsumexp(log_likelihood)), 6)
         self.all_test_stats = np.array(test_stat)
-        self.test_stat = np.average(np.array(test_stat), weights = self.prob_weights)
-        self.null_samples = np.average(null_samples, axis = 0, weights = self.prob_weights)
+        self.test_stat = np.average(np.array(test_stat), weights=self.prob_weights)
+        self.null_samples = np.average(null_samples, axis=0, weights=self.prob_weights)
         # experimental, not used
-        self.pvalue_alt = np.sum(np.average(null_samples, axis = 0, weights = self.prob_Z) > np.average(np.array(test_stat), weights = self.prob_Z)) / float(self.nullss)
+        self.pvalue_alt = np.sum(np.average(null_samples, axis=0, weights=self.prob_Z) > np.average(np.array(test_stat), weights=self.prob_Z)) / float(self.nullss)
         self.pvalue = np.sum(self.null_samples > self.test_stat) / float(self.nullss)
 
-        return self.pvalue, self.test_stat
+        return self.pvalue, self.test_stat
 
     def partition_data(self):
         """
@@ -94,26 +95,26 @@ def partition_data(self):
         """
         Z_mean = self.data_z.mean(axis=0)
         Z_sd = self.data_z.std(axis=0)
-        mu_k = np.random.normal(Z_mean, Z_sd, size = (self.K,self.data_z.shape[1]))
+        mu_k = np.random.normal(Z_mean, Z_sd, size=(self.K, self.data_z.shape[1]))
         sigma_k = np.eye(self.data_z.shape[1])
         pi_j = np.random.dirichlet([self.alpha]*self.K)
-        ll = np.tile(np.log(pi_j),(self.n,1))
+        ll = np.tile(np.log(pi_j), (self.n, 1))
         for k in range(self.K):
-            ll[:,k] += stats.multivariate_normal.logpdf(self.data_z, mu_k[k,:], cov=sigma_k, allow_singular=True)
-        Z = np.array([ np.random.multinomial(1,np.exp(ll[n,:]-logsumexp(ll[n,:]))).argmax() for n in range(self.n)])
+            ll[:, k] += stats.multivariate_normal.logpdf(self.data_z, mu_k[k, :], cov=sigma_k, allow_singular=True)
+        Z = np.array([np.random.multinomial(1, np.exp(ll[n, :]-logsumexp(ll[n, :]))).argmax() for n in range(self.n)])
         le = LabelEncoder()
         Z = le.fit_transform(Z)
-        prob_Z = np.take_along_axis(ll, Z[:, None], axis=1).sum() # experimental, might be removed
+        prob_Z = np.take_along_axis(ll, Z[:, None], axis=1).sum()  # experimental, might be removed
         return Z, prob_Z
-
+
     def pvalue_onblocks(self, Z_proposal):
         """
         Calculate p value on given partitions of the data.
 
         Parameters
         ----------
         Z_proposal: partition of the data into K clusters (nxd1 array)
-
+
         Returns
         _________
         test_stat: test statistic (scalar)
@@ -123,13 +124,13 @@ def pvalue_onblocks(self, Z_proposal):
         unique_Z_j = np.unique(Z_proposal)
         test_stat = 0
         log_likelihood = 0
-        null_samples = np.zeros((1,self.nullss))
+        null_samples = np.zeros((1, self.nullss))
         for k in unique_Z_j:
-            K_mask = (Z_proposal==k)
+            K_mask = (Z_proposal == k)
             X_k = np.copy(self.data[0][K_mask])
             Y_k = np.copy(self.data[1][K_mask])
             Z_k = np.copy(self.data_z[K_mask])
-            if (Z_k.shape[0]<6): # small blocks cause problems in GP, experimental
+            if (Z_k.shape[0] < 6):  # small blocks cause problems in GP, experimental
                 continue
             Kx, Ky, Kzx, Kzy, epsilon_x, epsilon_y, likelihood_x, likelihood_y = self.kernel_matrix(X_k, Y_k, Z_k)
             KxR, Rzx = Kernel.center_kernel_matrix_regression(Kx, Kzx, epsilon_x)
@@ -143,7 +144,6 @@ def pvalue_onblocks(self, Z_proposal):
             log_likelihood += likelihood_x + likelihood_y
         return test_stat, null_samples, log_likelihood
 
-
     def kernel_matrix(self, data_x, data_y, data_z):
         """
         Calculates the Gaussian Kernel for given data inputs as well as the shared kernel.
@@ -157,10 +157,10 @@ def kernel_matrix(self, data_x, data_y, data_z):
 
         data_x = stats.zscore(data_x, ddof=1, axis=0)
         data_x[np.isnan(data_x)] = 0.
-
+
         data_y = stats.zscore(data_y, ddof=1, axis=0)
         data_y[np.isnan(data_y)] = 0.
-
+
         data_z = stats.zscore(data_z, ddof=1, axis=0)
         data_z[np.isnan(data_z)] = 0.
 
@@ -189,23 +189,23 @@ def kernel_matrix(self, data_x, data_y, data_z):
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", category=Warning)
                 # P(X|Z)
-                gpx.fit(X = data_z, y = data_x)
+                gpx.fit(X=data_z, y=data_x)
             likelihood_x = gpx.log_marginal_likelihood_value_
             gpy = GaussianProcessRegressor()
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", category=Warning)
                 # P(Y|X,Z)
-                gpy.fit(X = np.c_[data_x,data_z], y=data_y)
+                gpy.fit(X=np.c_[data_x, data_z], y=data_y)
             likelihood_y = gpy.log_marginal_likelihood_value_
 
-        else:
+        else:
             n, Dz = data_z.shape
-
+
             widthz = np.sqrt(1.0 / (kernelX.width * data_x.shape[1]))
 
             # Instantiate a Gaussian Process model for x
             wx, vx = eigh(Kx)
-            topkx = int(np.max([np.min([400, np.floor(n / 4)]), np.min([n+1,8])]))
+            topkx = int(np.max([np.min([400, np.floor(n / 4)]), np.min([n+1, 8])]))
             idx = np.argsort(-wx)
             wx = wx[idx]
             vx = vx[:, idx]
@@ -228,7 +228,7 @@ def kernel_matrix(self, data_x, data_y, data_z):
 
             # Instantiate a Gaussian Process model for y
             wy, vy = eigh(Ky)
-            topky = int(np.max([np.min([400, np.floor(n / 4)]), np.min([n+1,8])]))
+            topky = int(np.max([np.min([400, np.floor(n / 4)]), np.min([n+1, 8])]))
             idy = np.argsort(-wy)
             wy = wy[idy]
             vy = vy[:, idy]
@@ -297,7 +297,7 @@ def get_uuprod(self, Kx, Ky):
         uu_prod = uu.T.dot(uu)
 
         return uu_prod, size_u
-
+
     def get_kappa(self, mean_appr, var_appr):
         """
        Get parameters for the approximated gamma distribution
@@ -333,11 +333,12 @@ def null_sample_spectral(self, uu_prod, size_u, T):
         eig_uu = -np.sort(-eig_uu)
         eig_uu = eig_uu[0:np.min((T, size_u))]
         eig_uu = eig_uu[eig_uu > np.max(eig_uu) * self.eig_thresh]
-
+
         f_rand = np.random.chisquare(1, (eig_uu.shape[0], self.nullss))
         null_dstr = eig_uu.T.dot(f_rand)
         return null_dstr
 
+
 class FastKCI_UInd(object):
     """
     Python implementation of as speed-up version of the Kernel-based Conditional Independence (KCI) test. Unconditional version.
@@ -348,8 +349,8 @@ class FastKCI_UInd(object):
     "A kernel-based conditional independence test and application in causal discovery," In UAI 2011.
     [2] M. Zhang, and S. Williamson,
     "Embarrassingly Parallel Inference for Gaussian Processes" In JMLR 20 (2019)
-    """
-    def __init__(self, K=10, J=8, alpha=500, trimming_thresh = 1e-3):
+    """
+    def __init__(self, K=10, J=8, alpha=500, trimming_thresh=1e-3):
         """
         Construct the FastKCI_UInd model.
 
@@ -359,14 +360,13 @@ def __init__(self, K=10, J=8, alpha=500, trimming_thresh = 1e-3):
         J: Number of independent repittitions.
         alpha: Parameter for the Dirichlet distribution.
         trimming_thresh: Threshold for trimming the propensity weights.
-        """
+        """
         self.K = K
         self.J = J
         self.alpha = alpha
         self.trimming_thresh = trimming_thresh
         self.nullss = 5000
         self.eig_thresh = 1e-5
-        # TODO: Adjust to causal-learn API
 
     def compute_pvalue(self, data_x=None, data_y=None):
         """
@@ -385,17 +385,17 @@ def compute_pvalue(self, data_x=None, data_y=None):
         self.data_x = data_x
         self.data_y = data_y
         self.n = data_x.shape[0]
-
+
         Z_proposal = Parallel(n_jobs=-1)(delayed(self.partition_data)() for i in range(self.J))
         self.Z_proposal, self.prob_Y = zip(*Z_proposal)
         block_res = Parallel(n_jobs=-1)(delayed(self.pvalue_onblocks)(self.Z_proposal[i]) for i in range(self.J))
         test_stat, null_samples, log_likelihood = zip(*block_res)
         self.prob_weights = np.around(np.exp(log_likelihood-logsumexp(log_likelihood)), 6)
-        self.test_stat = np.average(np.array(test_stat), weights = self.prob_weights)
-        self.null_samples = np.average(null_samples, axis = 0, weights = self.prob_weights)
+        self.test_stat = np.average(np.array(test_stat), weights=self.prob_weights)
+        self.null_samples = np.average(null_samples, axis=0, weights=self.prob_weights)
         self.pvalue = np.sum(self.null_samples > self.test_stat) / float(self.nullss)
 
-        return self.pvalue, self.test_stat
+        return self.pvalue, self.test_stat
 
     def partition_data(self):
         """
@@ -408,25 +408,25 @@ def partition_data(self):
         """
         Y_mean = self.data_y.mean(axis=0)
         Y_sd = self.data_y.std(axis=0)
-        mu_k = np.random.normal(Y_mean, Y_sd, size = (self.K,self.data_y.shape[1]))
+        mu_k = np.random.normal(Y_mean, Y_sd, size=(self.K, self.data_y.shape[1]))
         sigma_k = np.eye(self.data_y.shape[1])
         pi_j = np.random.dirichlet([self.alpha]*self.K)
-        ll = np.tile(np.log(pi_j),(self.n,1))
+        ll = np.tile(np.log(pi_j), (self.n, 1))
         for k in range(self.K):
-            ll[:,k] += stats.multivariate_normal.logpdf(self.data_y, mu_k[k,:], cov=sigma_k, allow_singular=True)
-        Z = np.array([ np.random.multinomial(1,np.exp(ll[n,:]-logsumexp(ll[n,:]))).argmax() for n in range(self.n)])
+            ll[:, k] += stats.multivariate_normal.logpdf(self.data_y, mu_k[k, :], cov=sigma_k, allow_singular=True)
+        Z = np.array([np.random.multinomial(1, np.exp(ll[n, :]-logsumexp(ll[n, :]))).argmax() for n in range(self.n)])
         prop_Y = np.take_along_axis(ll, Z[:, None], axis=1).sum()
         le = LabelEncoder()
         Z = le.fit_transform(Z)
         return (Z, prop_Y)
-
+
     def pvalue_onblocks(self, Z_proposal):
         unique_Z_j = np.unique(Z_proposal)
         test_stat = 0
         log_likelihood = 0
-        null_samples = np.zeros((1,self.nullss))
+        null_samples = np.zeros((1, self.nullss))
         for k in unique_Z_j:
-            K_mask = (Z_proposal==k)
+            K_mask = (Z_proposal == k)
             X_k = np.copy(self.data_x[K_mask])
             Y_k = np.copy(self.data_y[K_mask])
 
@@ -442,12 +442,12 @@ def pvalue_onblocks(self, Z_proposal):
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", category=Warning)
                 # P(X|Y)
-                gpx.fit(X = Y_k, y = X_k)
+                gpx.fit(X=Y_k, y=X_k)
             likelihood = gpx.log_marginal_likelihood_value_
             log_likelihood += likelihood
-
+
         return test_stat, null_samples, log_likelihood
-
+
     def kernel_matrix(self, data):
         """
         Calculates the Gaussian Kernel for given data inputs.
@@ -460,12 +460,12 @@ def kernel_matrix(self, data):
         kernel_obj.set_width_empirical_hsic(data)
 
         data = stats.zscore(data, ddof=1, axis=0)
-        data[np.isnan(data)] = 0.
+        data[np.isnan(data)] = 0.
 
         K = kernel_obj.kernel(data)
 
         return K
-
+
     def get_kappa(self, mean_appr, var_appr):
         """
         Get parameters for the approximated gamma distribution
@@ -513,7 +513,7 @@ def null_sample_spectral(self, Kxc, Kyc):
         f_rand = np.random.chisquare(1, (lambda_prod.shape[0], self.nullss))
         null_dstr = lambda_prod.T.dot(f_rand) / T
         return null_dstr
-
+
     def HSIC_V_statistic(self, Kx, Ky):
         """
         Compute V test statistic from kernel matrices Kx and Ky
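For reference, a minimal usage sketch of the two classes touched by this commit. It assumes the import path matches the file shown above and the constructor and compute_pvalue signatures visible in the diff; the synthetic data, seed, and variable names are illustrative only, not part of the commit.

# Minimal usage sketch (assumptions: import path and signatures as in the diff above).
import numpy as np
from causallearn.utils.FastKCI.FastKCI import FastKCI_CInd, FastKCI_UInd

rng = np.random.default_rng(0)
n = 500
Z = rng.normal(size=(n, 2))                    # conditioning variables
X = Z[:, [0]] + 0.1 * rng.normal(size=(n, 1))  # X driven by Z
Y = Z[:, [0]] + 0.1 * rng.normal(size=(n, 1))  # Y driven by Z, independent of X given Z

# Conditional test: is X independent of Y given Z?
# By construction X ⟂ Y | Z holds, so the p-value should typically be large.
ci_test = FastKCI_CInd(K=10, J=8, alpha=500)
pvalue_c, stat_c = ci_test.compute_pvalue(X, Y, Z)
print("conditional p-value:", pvalue_c)

# Unconditional test: is X independent of Y?
# Marginally X and Y are dependent (both follow Z), so this should tend to reject.
ui_test = FastKCI_UInd(K=10, J=8, alpha=500)
pvalue_u, stat_u = ui_test.compute_pvalue(X, Y)
print("unconditional p-value:", pvalue_u)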
