Skip to content

Commit 2642628

Browse files
committed
add unit tests
Signed-off-by: Oliver Schacht <[email protected]>
1 parent cf9d425 commit 2642628

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

tests/TestCIT_FastKCI.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
import causallearn.utils.cit as cit
6+
7+
8+
class TestCIT_FastKCI(unittest.TestCase):
9+
def test_Gaussian_dist(self):
10+
np.random.seed(10)
11+
X = np.random.randn(1200, 1)
12+
X_prime = np.random.randn(1200, 1)
13+
Y = X + 0.5 * np.random.randn(1200, 1)
14+
Z = Y + 0.5 * np.random.randn(1200, 1)
15+
data = np.hstack((X, X_prime, Y, Z))
16+
17+
pvalue01 = []
18+
pvalue03 = []
19+
pvalue032 = []
20+
for K in [3, 10]:
21+
for J in [8, 16]:
22+
for use_gp in [True, False]:
23+
cit_CIT = cit.CIT(data, 'fastkci', K=K, J=J, use_gp=use_gp)
24+
pvalue01.append(round(cit_CIT(0, 1), 4))
25+
pvalue03.append(round(cit_CIT(0, 3), 4))
26+
pvalue032.append(round(cit_CIT(0, 3, {2}), 4))
27+
28+
self.assertTrue(np.all((0.0 <= pvalue01) & (pvalue01 <= 1.0)),
29+
"pvalue01 contains invalid values")
30+
self.assertTrue(np.all((0.0 <= pvalue03) & (pvalue03 <= 1.0)),
31+
"pvalue03 contains invalid values")
32+
self.assertTrue(np.all((0.0 <= pvalue032) & (pvalue032 <= 1.0)),
33+
"pvalue032 contains invalid values")

tests/TestCIT_RCIT.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
import causallearn.utils.cit as cit
6+
7+
8+
class TestCIT_RCIT(unittest.TestCase):
9+
def test_Gaussian_dist(self):
10+
np.random.seed(10)
11+
X = np.random.randn(300, 1)
12+
X_prime = np.random.randn(300, 1)
13+
Y = X + 0.5 * np.random.randn(300, 1)
14+
Z = Y + 0.5 * np.random.randn(300, 1)
15+
data = np.hstack((X, X_prime, Y, Z))
16+
17+
pvalue01 = []
18+
pvalue03 = []
19+
pvalue032 = []
20+
for approx in ["lpd4", "hbe", "gamma", "chi2", "perm"]:
21+
for num_f in [50, 100]:
22+
for num_f2 in [5, 10]:
23+
for rcit in [True, False]:
24+
cit_CIT = cit.CIT(data, 'rcit', approx=approx, num_f=num_f,
25+
num_f2=num_f2, rcit=rcit)
26+
pvalue01.append(round(cit_CIT(0, 1), 4))
27+
pvalue03.append(round(cit_CIT(0, 3), 4))
28+
pvalue032.append(round(cit_CIT(0, 3, {2}), 4))
29+
30+
self.assertTrue(np.all((0.0 <= pvalue01) & (pvalue01 <= 1.0)),
31+
"pvalue01 contains invalid values")
32+
self.assertTrue(np.all((0.0 <= pvalue03) & (pvalue03 <= 1.0)),
33+
"pvalue03 contains invalid values")
34+
self.assertTrue(np.all((0.0 <= pvalue032) & (pvalue032 <= 1.0)),
35+
"pvalue032 contains invalid values")

0 commit comments

Comments
 (0)