11from typing import Optional
2-
2+ from causallearn . score . LocalScoreFunctionClass import LocalScoreClass
33from causallearn .graph .GeneralGraph import GeneralGraph
44from causallearn .graph .GraphNode import GraphNode
55from causallearn .utils .DAG2CPDAG import dag2cpdag
@@ -46,12 +46,14 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
4646 if maxP is None :
4747 maxP = X .shape [1 ] / 2 # maximum number of parents
4848 N = X .shape [1 ] # number of variables
49+ localScoreClass = LocalScoreClass (data = X , local_score_fun = local_score_cv_general , parameters = parameters )
4950
5051 elif score_func == 'local_score_marginal_general' : # negative marginal likelihood based on regression in RKHS
5152 parameters = {}
5253 if maxP is None :
5354 maxP = X .shape [1 ] / 2 # maximum number of parents
5455 N = X .shape [1 ] # number of variables
56+ localScoreClass = LocalScoreClass (data = X , local_score_fun = local_score_marginal_general , parameters = parameters )
5557
5658 elif score_func == 'local_score_CV_multi' : # k-fold negative cross validated likelihood based on regression in RKHS
5759 # for data with multi-variate dimensions
@@ -62,6 +64,7 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
6264 if maxP is None :
6365 maxP = len (parameters ['dlabel' ]) / 2
6466 N = len (parameters ['dlabel' ])
67+ localScoreClass = LocalScoreClass (data = X , local_score_fun = local_score_cv_multi , parameters = parameters )
6568
6669 elif score_func == 'local_score_marginal_multi' : # negative marginal likelihood based on regression in RKHS
6770 # for data with multi-variate dimensions
@@ -72,19 +75,23 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
7275 if maxP is None :
7376 maxP = len (parameters ['dlabel' ]) / 2
7477 N = len (parameters ['dlabel' ])
78+ localScoreClass = LocalScoreClass (data = X , local_score_fun = local_score_marginal_multi , parameters = parameters )
7579
7680 elif score_func == 'local_score_BIC' : # Greedy equivalence search with BIC score
7781 if maxP is None :
7882 maxP = X .shape [1 ] / 2
7983 N = X .shape [1 ] # number of variables
84+ localScoreClass = LocalScoreClass (data = X , local_score_fun = local_score_BIC , parameters = None )
8085
8186 elif score_func == 'local_score_BDeu' : # Greedy equivalence search with BDeu score
8287 if maxP is None :
8388 maxP = X .shape [1 ] / 2
8489 N = X .shape [1 ] # number of variables
90+ localScoreClass = LocalScoreClass (data = X , local_score_fun = local_score_BDeu , parameters = None )
8591
8692 else :
8793 raise Exception ('Unknown function!' )
94+ score_func = localScoreClass
8895
8996 node_names = [("x%d" % i ) for i in range (N )]
9097 nodes = []
0 commit comments