Skip to content

Commit fff828d

Browse files
authored
Merge pull request #92 from wean2016/main
Fix Typo for GES Algorithm
2 parents 5eaf0f6 + 6aa83ca commit fff828d

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

causallearn/search/ScoreBased/GES.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
6060
if parameters is None:
6161
parameters = {'kfold': 10, 'lambda': 0.01, 'dlabel': {}} # regularization parameter
6262
for i in range(X.shape[1]):
63-
parameters['dlabel']['{}'.format(i)] = i
63+
parameters['dlabel'][i] = i
6464
if maxP is None:
6565
maxP = len(parameters['dlabel']) / 2
6666
N = len(parameters['dlabel'])
@@ -71,7 +71,7 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
7171
if parameters is None:
7272
parameters = {'dlabel': {}}
7373
for i in range(X.shape[1]):
74-
parameters['dlabel']['{}'.format(i)] = i
74+
parameters['dlabel'][i] = i
7575
if maxP is None:
7676
maxP = len(parameters['dlabel']) / 2
7777
N = len(parameters['dlabel'])

tests/TestLocalScore.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import unittest
2+
from causallearn.utils.GESUtils import *
3+
from causallearn.score.LocalScoreFunctionClass import LocalScoreClass
4+
5+
class TestLocalScore(unittest.TestCase):
6+
7+
np.random.seed(10)
8+
X = np.random.randn(300, 1)
9+
X_prime = np.random.randn(300, 1)
10+
Y = X + 0.5 * np.random.randn(300, 1)
11+
Z = Y + 0.5 * np.random.randn(300, 1)
12+
data = np.mat(np.hstack((X, X_prime, Y, Z)))
13+
14+
np.random.seed(10)
15+
X_slash = np.random.randint(0, 10, (300, 1))
16+
X_prime_slash = np.random.randint(0, 10, (300, 1))
17+
Y_slash = X_slash + np.random.randint(0, 10, (300, 1))
18+
Z_slash = Y_slash + np.random.randint(0, 10, (300, 1))
19+
data_slash = np.mat(np.hstack((X_slash, X_prime_slash, Y_slash, Z_slash)))
20+
21+
def test_local_score_marginal_multi(self):
22+
parameters = {'dlabel': {}}
23+
for i in range(self.data.shape[1]):
24+
parameters['dlabel'][i] = i
25+
26+
localScoreClass = LocalScoreClass(data=self.data, local_score_fun=local_score_marginal_multi, parameters=parameters)
27+
28+
q = localScoreClass.score(0, [0])
29+
p = localScoreClass.score(0, [1])
30+
v = localScoreClass.score(0, [2])
31+
32+
assert q < v < p
33+
34+
def test_local_score_CV_multi(self):
35+
parameters = {'kfold': 10, 'lambda': 0.01, 'dlabel': {}} # regularization parameter
36+
for i in range(self.data.shape[1]):
37+
parameters['dlabel'][i] = i
38+
39+
localScoreClass = LocalScoreClass(data=self.data, local_score_fun=local_score_cv_multi, parameters=parameters)
40+
41+
q = localScoreClass.score(0, [0])
42+
p = localScoreClass.score(0, [1])
43+
v = localScoreClass.score(0, [2])
44+
45+
assert q < v < p
46+
47+
def test_local_score_BIC(self):
48+
parameters = {}
49+
parameters["lambda_value"] = 2
50+
51+
localScoreClass = LocalScoreClass(data=self.data, local_score_fun=local_score_BIC, parameters=parameters)
52+
53+
q = localScoreClass.score(0, [0])
54+
p = localScoreClass.score(0, [1])
55+
v = localScoreClass.score(0, [2])
56+
57+
assert q < v < p
58+
59+
def test_local_score_CV_general(self):
60+
parameters = {'kfold': 10, # 10 fold cross validation
61+
'lambda': 0.01} # regularization parameter
62+
63+
localScoreClass = LocalScoreClass(data=self.data, local_score_fun=local_score_cv_general, parameters=parameters)
64+
65+
q = localScoreClass.score(0, [0])
66+
p = localScoreClass.score(0, [1])
67+
v = localScoreClass.score(0, [2])
68+
69+
assert q < v < p
70+
71+
def test_local_score_marginal_general(self):
72+
parameters = {}
73+
74+
localScoreClass = LocalScoreClass(data=self.data, local_score_fun=local_score_marginal_general, parameters=parameters)
75+
76+
q = localScoreClass.score(0, [0])
77+
p = localScoreClass.score(0, [1])
78+
v = localScoreClass.score(0, [2])
79+
80+
assert q < v < p
81+
82+
def test_local_score_BDeu(self):
83+
localScoreClass = LocalScoreClass(data=self.data_slash, local_score_fun=local_score_BDeu, parameters=None)
84+
85+
q = localScoreClass.score(0, [0])
86+
p = localScoreClass.score(0, [1])
87+
v = localScoreClass.score(0, [2])
88+
89+
assert q < v < p

0 commit comments

Comments
 (0)