Skip to content

Commit e83e4f0

Browse files
committed
add docstrings
Signed-off-by: Oliver Schacht <[email protected]>
1 parent 173b913 commit e83e4f0

File tree

1 file changed

+154
-2
lines changed

1 file changed

+154
-2
lines changed

causallearn/utils/RCIT/RCIT.py

Lines changed: 154 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,56 @@
77

88

99
class RCIT(object):
10+
"""
11+
Python implementation of Randomized Conditional Independence Test (RCIT) test.
12+
The original R implementation can be found at https://github.com/ericstrobl/RCIT/tree/master
13+
14+
References
15+
----------
16+
[1] Strobl, E. V., Zhang, K., and Visweswaran, S. (2019). "Approximate kernel-based conditional
17+
independence tests for fast non-parametric causal discovery." Journal of Causal Inference, 7(1), 20180017.
18+
"""
1019
def __init__(self, approx="lpd4", num_f=100, num_f2=5, rcit=True):
20+
"""
21+
Initialize the RCIT object.
22+
23+
Parameters
24+
----------
25+
approx : str
26+
Method for approximating the null distribution.
27+
- "lpd4" for the Lindsay-Pilla-Basak method
28+
- "hbe" for the Hall-Buckley-Eagleson method
29+
- "gamma" for the Satterthwaite-Welch method
30+
- "chi2" for a normalized chi-squared statistic
31+
- "perm" for permutation testing
32+
Default is "lpd4".
33+
num_f : int
34+
Number of features for conditioning set. Default is 25.
35+
num_f2 : int
36+
Number of features for non-conditioning sets. Default is 5.
37+
rcit : bool
38+
Whether to use RCIT or RCoT. Default is True.
39+
"""
1140
self.approx = approx
1241
self.num_f = num_f
1342
self.num_f2 = num_f2
1443
self.rcit = rcit
1544

1645
def compute_pvalue(self, data_x, data_y, data_z):
46+
"""
47+
Compute the p value and return it together with the test statistic.
48+
49+
Parameters
50+
----------
51+
data_x: input data for x (nxd1 array)
52+
data_y: input data for y (nxd2 array)
53+
data_z: input data for z (nxd3 array)
54+
55+
Returns
56+
-------
57+
p: p value
58+
sta: test statistic
59+
"""
1760
d = data_z.shape[1]
1861
r = data_x.shape[0]
1962
r1 = 500 if (r > 500) else r
@@ -114,7 +157,27 @@ def compute_pvalue(self, data_x, data_y, data_z):
114157
return p, sta
115158

116159
def random_fourier_features(self, x, w=None, b=None, num_f=None, sigma=None):
117-
160+
"""
161+
Generate random Fourier features.
162+
163+
Parameters
164+
----------
165+
x : np.ndarray
166+
Random variable x.
167+
w : np.ndarray
168+
RRandom coefficients.
169+
b : np.ndarray
170+
Random offsets.
171+
num_f : int
172+
Number of random Fourier features.
173+
sigma : float
174+
Smooth parameter of RBF kernel.
175+
176+
Returns
177+
-------
178+
feat : np.ndarray
179+
Random Fourier features.
180+
"""
118181
if num_f is None:
119182
num_f = 25
120183

@@ -133,6 +196,22 @@ def random_fourier_features(self, x, w=None, b=None, num_f=None, sigma=None):
133196
return feat
134197

135198
def matrix_cov(self, mat_a, mat_b):
199+
"""
200+
Compute the covariance matrix between two matrices.
201+
Equivalent to ``cov()`` between two matrices in R.
202+
203+
Parameters
204+
----------
205+
mat_a : np.ndarray
206+
First data matrix.
207+
mat_b : np.ndarray
208+
Second data matrix.
209+
210+
Returns
211+
-------
212+
mat_cov : np.ndarray
213+
Covariance matrix.
214+
"""
136215
n_obs = mat_a.shape[0]
137216

138217
assert mat_a.shape == mat_b.shape
@@ -145,10 +224,47 @@ def matrix_cov(self, mat_a, mat_b):
145224

146225

147226
class RIT(object):
227+
"""
228+
Python implementation of Randomized Independence Test (RIT) test.
229+
The original R implementation can be found at https://github.com/ericstrobl/RCIT/tree/master
230+
231+
References
232+
----------
233+
[1] Strobl, E. V., Zhang, K., and Visweswaran, S. (2019). "Approximate kernel-based conditional
234+
independence tests for fast non-parametric causal discovery." Journal of Causal Inference, 7(1), 20180017.
235+
"""
148236
def __init__(self, approx="lpd4"):
237+
"""
238+
Initialize the RIT object.
239+
240+
Parameters
241+
----------
242+
approx : str
243+
Method for approximating the null distribution.
244+
- "lpd4" for the Lindsay-Pilla-Basak method
245+
- "hbe" for the Hall-Buckley-Eagleson method
246+
- "gamma" for the Satterthwaite-Welch method
247+
- "chi2" for a normalized chi-squared statistic
248+
- "perm" for permutation testing
249+
Default is "lpd4".
250+
"""
149251
self.approx = approx
150252

151253
def compute_pvalue(self, data_x, data_y):
254+
"""
255+
Compute the p value and return it together with the test statistic.
256+
257+
Parameters
258+
----------
259+
data_x: input data for x (nxd1 array)
260+
data_y: input data for y (nxd2 array)
261+
data_z: input data for z (nxd3 array)
262+
263+
Returns
264+
-------
265+
p: p value
266+
sta: test statistic
267+
"""
152268
r = data_x.shape[0]
153269
r1 = 500 if (r > 500) else r
154270

@@ -221,7 +337,27 @@ def compute_pvalue(self, data_x, data_y):
221337
return p, sta
222338

223339
def random_fourier_features(self, x, w=None, b=None, num_f=None, sigma=None):
224-
340+
"""
341+
Generate random Fourier features.
342+
343+
Parameters
344+
----------
345+
x : np.ndarray
346+
Random variable x.
347+
w : np.ndarray
348+
RRandom coefficients.
349+
b : np.ndarray
350+
Random offsets.
351+
num_f : int
352+
Number of random Fourier features.
353+
sigma : float
354+
Smooth parameter of RBF kernel.
355+
356+
Returns
357+
-------
358+
feat : np.ndarray
359+
Random Fourier features.
360+
"""
225361
if num_f is None:
226362
num_f = 25
227363

@@ -240,6 +376,22 @@ def random_fourier_features(self, x, w=None, b=None, num_f=None, sigma=None):
240376
return feat
241377

242378
def matrix_cov(self, mat_a, mat_b):
379+
"""
380+
Compute the covariance matrix between two matrices.
381+
Equivalent to ``cov()`` between two matrices in R.
382+
383+
Parameters
384+
----------
385+
mat_a : np.ndarray
386+
First data matrix.
387+
mat_b : np.ndarray
388+
Second data matrix.
389+
390+
Returns
391+
-------
392+
mat_cov : np.ndarray
393+
Covariance matrix.
394+
"""
243395
n_obs = mat_a.shape[0]
244396

245397
assert mat_a.shape == mat_b.shape

0 commit comments

Comments
 (0)