|
| 1 | +import unittest |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +import causallearn.utils.cit as cit |
| 6 | + |
| 7 | + |
| 8 | +class TestCIT_RCIT(unittest.TestCase): |
| 9 | + def test_Gaussian_dist(self): |
| 10 | + np.random.seed(10) |
| 11 | + X = np.random.randn(300, 1) |
| 12 | + X_prime = np.random.randn(300, 1) |
| 13 | + Y = X + 0.5 * np.random.randn(300, 1) |
| 14 | + Z = Y + 0.5 * np.random.randn(300, 1) |
| 15 | + data = np.hstack((X, X_prime, Y, Z)) |
| 16 | + |
| 17 | + pvalue01 = [] |
| 18 | + pvalue03 = [] |
| 19 | + pvalue032 = [] |
| 20 | + for approx in ["lpd4", "hbe", "gamma", "chi2", "perm"]: |
| 21 | + for num_f in [50, 100]: |
| 22 | + for num_f2 in [5, 10]: |
| 23 | + for rcit in [True, False]: |
| 24 | + cit_CIT = cit.CIT(data, 'rcit', approx=approx, num_f=num_f, |
| 25 | + num_f2=num_f2, rcit=rcit) |
| 26 | + pvalue01.append(round(cit_CIT(0, 1), 4)) |
| 27 | + pvalue03.append(round(cit_CIT(0, 3), 4)) |
| 28 | + pvalue032.append(round(cit_CIT(0, 3, {2}), 4)) |
| 29 | + |
| 30 | + self.assertTrue(np.all((0.0 <= pvalue01) & (pvalue01 <= 1.0)), |
| 31 | + "pvalue01 contains invalid values") |
| 32 | + self.assertTrue(np.all((0.0 <= pvalue03) & (pvalue03 <= 1.0)), |
| 33 | + "pvalue03 contains invalid values") |
| 34 | + self.assertTrue(np.all((0.0 <= pvalue032) & (pvalue032 <= 1.0)), |
| 35 | + "pvalue032 contains invalid values") |
0 commit comments