Skip to content

Commit 27d3a81

Browse files
committed
Update ANM to be compatible with the latest version of sklearn
1 parent ebcb1ac commit 27d3a81

File tree

2 files changed

+6
-20
lines changed

2 files changed

+6
-20
lines changed

causallearn/search/FCMBased/ANM/ANM.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
import os
2-
import sys
3-
4-
BASE_DIR = os.path.join(os.path.dirname(__file__), '..')
5-
sys.path.append(BASE_DIR)
61
from sklearn.gaussian_process import GaussianProcessRegressor
72
from sklearn.gaussian_process.kernels import RBF
83
from sklearn.gaussian_process.kernels import ConstantKernel as C
@@ -49,7 +44,7 @@ def fit_gp(self, X, y):
4944

5045
# fit Gaussian process, including hyperparameter optimization
5146
gpr.fit(X, y)
52-
pred_y = gpr.predict(X)
47+
pred_y = gpr.predict(X).reshape(-1, 1)
5348
return pred_y
5449

5550
def cause_or_effect(self, data_x, data_y):

tests/TestANM.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,5 @@
1-
import os
2-
import sys
3-
4-
BASE_DIR = os.path.join(os.path.dirname(__file__), '..')
5-
sys.path.append(BASE_DIR)
6-
import sys
71
import unittest
8-
from pickle import load
9-
102
import numpy as np
11-
123
from causallearn.search.FCMBased.ANM.ANM import ANM
134

145

@@ -33,7 +24,7 @@ class TestANM(unittest.TestCase):
3324
# Test ANM by some simulated data
3425
def test_anm_simulation_1(self):
3526
# simulated data y = x + 2x^3 + e
36-
simulated_dataset_1 = np.loadtxt('TestData/anm_simulation_1.txt', delimiter=',')
27+
simulated_dataset_1 = np.loadtxt('tests/TestData/anm_simulation_1.txt', delimiter=',')
3728
simulated_dataset_1_p_value_forward, simulated_dataset_1_p_value_backward = 0.99541, 0.0 # round(value, 5) results
3829
x_1 = simulated_dataset_1[:, 0].reshape(-1, 1)
3930
y_1 = simulated_dataset_1[:, 1].reshape(-1, 1)
@@ -46,7 +37,7 @@ def test_anm_simulation_1(self):
4637

4738
def test_anm_simulation_2(self):
4839
# simulated data y = 5 * exp(x) + e
49-
simulated_dataset_2 = np.loadtxt('TestData/anm_simulation_2.txt', delimiter=',')
40+
simulated_dataset_2 = np.loadtxt('tests/TestData/anm_simulation_2.txt', delimiter=',')
5041
simulated_dataset_2_p_value_forward, simulated_dataset_2_p_value_backward = 0.99348, 0.0 # round(value, 5) results
5142
x_2 = simulated_dataset_2[:, 0].reshape(-1, 1)
5243
y_2 = simulated_dataset_2[:, 1].reshape(-1, 1)
@@ -59,7 +50,7 @@ def test_anm_simulation_2(self):
5950

6051
def test_anm_simulation_3(self):
6152
# simulated data y = 3^x + e
62-
simulated_dataset_3 = np.loadtxt('TestData/anm_simulation_3.txt', delimiter=',')
53+
simulated_dataset_3 = np.loadtxt('tests/TestData/anm_simulation_3.txt', delimiter=',')
6354
simulated_dataset_3_p_value_forward, simulated_dataset_3_p_value_backward = 0.65933, 0.0 # round(value, 5) results
6455
x_3 = simulated_dataset_3[:, 0].reshape(-1, 1)
6556
y_3 = simulated_dataset_3[:, 1].reshape(-1, 1)
@@ -73,8 +64,8 @@ def test_anm_simulation_3(self):
7364
# data pair from the Tuebingen cause-effect pair dataset.
7465
def test_anm_pair(self):
7566

76-
dataset = np.loadtxt('TestData/pair0001.txt')
77-
dataset_p_value_forward, dataset_p_value_backward = 0.14736, 0.0 # round(value, 5) results
67+
dataset = np.loadtxt('tests/TestData/pair0001.txt')
68+
dataset_p_value_forward, dataset_p_value_backward = 0.14773, 0.0 # round(value, 5) results
7869
p_value_forward, p_value_backward = self.anm.cause_or_effect(dataset[:, 0].reshape(-1, 1), dataset[:, 1].reshape(-1, 1))
7970
self.assertTrue(round(p_value_forward, 5) == dataset_p_value_forward)
8071
self.assertTrue(round(p_value_backward, 5) == dataset_p_value_backward)

0 commit comments

Comments
 (0)