1+ import random
2+ import sys
3+ import time
4+ import warnings
5+ from typing import Any , Dict , List , Optional
6+
7+ import numpy as np
8+ from causallearn .graph .GeneralGraph import GeneralGraph
9+ from causallearn .graph .GraphNode import GraphNode
10+ from causallearn .score .LocalScoreFunction import (
11+ local_score_BDeu ,
12+ local_score_BIC ,
13+ local_score_BIC_from_cov ,
14+ local_score_cv_general ,
15+ local_score_cv_multi ,
16+ local_score_marginal_general ,
17+ local_score_marginal_multi ,
18+ )
19+ from causallearn .search .PermutationBased .gst import GST ;
20+ from causallearn .score .LocalScoreFunctionClass import LocalScoreClass
21+ from causallearn .utils .DAG2CPDAG import dag2cpdag
22+
23+
def boss(
    X: np.ndarray,
    score_func: str = "local_score_BIC_from_cov",
    parameters: Optional[Dict[str, Any]] = None,
    verbose: Optional[bool] = True,
    node_names: Optional[List[str]] = None,
) -> GeneralGraph:
    """
    Perform a best order score search (BOSS) algorithm

    Starting from the identity ordering of the variables, each variable is
    repeatedly moved (in a freshly shuffled visiting order per sweep) to the
    position in the ordering that `better_mutation` reports as an
    improvement, until a full sweep changes nothing. The parents selected for
    every variable under the final ordering become directed edges, and the
    resulting DAG is converted with `dag2cpdag` before being returned.

    Parameters
    ----------
    X : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features.
        The array is copied; the caller's data is not modified here.
    score_func : the string name of score function. (str(one of 'local_score_CV_general', 'local_score_marginal_general',
                    'local_score_CV_multi', 'local_score_marginal_multi', 'local_score_BIC', 'local_score_BIC_from_cov', 'local_score_BDeu')).
    parameters : when using CV likelihood,
                  parameters['kfold']: k-fold cross validation
                  parameters['lambda']: regularization parameter
                  parameters['dlabel']: for variables with multi-dimensions,
                               indicate which dimensions belong to the i-th variable.
                 when using a BIC score and no parameters are given,
                  {'lambda_value': 2} is used.
                 Ignored for 'local_score_marginal_general' and 'local_score_BDeu'.
    verbose : whether to print the time cost and verbose output of the algorithm.
    node_names : names given to the graph nodes, one per column of X.
                 Defaults to 'X1', ..., 'Xp'.

    Returns
    -------
    G : learned causal graph, where G.graph[j,i] = 1 and G.graph[i,j] = -1 indicates i --> j, G.graph[i,j] = G.graph[j,i] = -1 indicates i --- j.

    Raises
    ------
    ValueError : if `score_func` is not one of the supported names, or if
                 `node_names` does not contain exactly one name per column of X.
    """

    X = X.copy()
    n, p = X.shape
    if n < p:
        warnings.warn("The number of features is much larger than the sample size!")

    if score_func == "local_score_CV_general":
        # k-fold negative cross validated likelihood based on regression in RKHS
        if parameters is None:
            parameters = {
                "kfold": 10,  # 10 fold cross validation
                "lambda": 0.01,  # regularization parameter
            }
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_cv_general, parameters=parameters
        )
    elif score_func == "local_score_marginal_general":
        # negative marginal likelihood based on regression in RKHS
        # NOTE: caller-supplied parameters are deliberately discarded here,
        # matching the existing behaviour of this branch.
        parameters = {}
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_marginal_general, parameters=parameters
        )
    elif score_func == "local_score_CV_multi":
        # k-fold negative cross validated likelihood based on regression in RKHS
        # for data with multi-variate dimensions
        if parameters is None:
            parameters = {
                "kfold": 10,
                "lambda": 0.01,  # regularization parameter
                # by default every column is its own one-dimensional variable
                "dlabel": {"{}".format(i): i for i in range(p)},
            }
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_cv_multi, parameters=parameters
        )
    elif score_func == "local_score_marginal_multi":
        # negative marginal likelihood based on regression in RKHS
        # for data with multi-variate dimensions
        if parameters is None:
            # by default every column is its own one-dimensional variable
            parameters = {"dlabel": {"{}".format(i): i for i in range(p)}}
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
        )
    elif score_func == "local_score_BIC":
        # SEM BIC score
        warnings.warn("Please use 'local_score_BIC_from_cov' instead")
        if parameters is None:
            parameters = {"lambda_value": 2}
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_BIC, parameters=parameters
        )
    elif score_func == "local_score_BIC_from_cov":
        # SEM BIC score
        if parameters is None:
            parameters = {"lambda_value": 2}
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
        )
    elif score_func == "local_score_BDeu":
        # BDeu score (always constructed without parameters)
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_BDeu, parameters=None
        )
    else:
        # ValueError is a subclass of Exception, so existing handlers still work.
        raise ValueError("Unknown function!")

    score = localScoreClass

    if node_names is None:
        node_names = [("X%d" % (i + 1)) for i in range(p)]
    elif len(node_names) != p:
        # Fail before the (potentially long) search rather than with an
        # IndexError while adding edges, or with silent extra isolated nodes.
        raise ValueError(
            "node_names must contain exactly one name per column of X "
            "(expected %d, got %d)" % (p, len(node_names))
        )

    nodes = [GraphNode(name) for name in node_names]
    G = GeneralGraph(nodes)

    runtime = time.perf_counter()

    order = list(range(p))

    # One grow-shrink tree per variable, indexed by variable id.
    gsts = [GST(v, score) for v in order]
    parents = {v: [] for v in order}

    def _retrace_parents() -> None:
        # Recompute the parent list of every variable from the variables that
        # precede it in the current ordering. `trace` is handed the list to
        # fill, so it is emptied first.
        for i, v in enumerate(order):
            parents[v].clear()
            gsts[v].trace(order[:i], parents[v])

    variables = list(order)
    while True:
        improved = False
        # Visit the variables in a fresh random order on every sweep.
        random.shuffle(variables)
        if verbose:
            # Parents are only needed here to report progress.
            _retrace_parents()
            sys.stdout.write("\r BOSS edge count: %i " % sum(len(parents[v]) for v in range(p)))
            sys.stdout.flush()

        for v in variables:
            improved |= better_mutation(v, order, gsts)
        if not improved:
            # A full sweep without any accepted move: local optimum reached.
            break

    _retrace_parents()

    runtime = time.perf_counter() - runtime

    if verbose:
        sys.stdout.write("\n BOSS completed in: %.2fs \n " % runtime)
        sys.stdout.flush()

    for y in range(p):
        for x in parents[y]:
            G.add_directed_edge(nodes[x], nodes[y])

    G = dag2cpdag(G)

    return G
170+
171+
def reversed_enumerate(iter, j):
    """
    Walk ``iter`` from its last element to its first, pairing each element
    with a counter that starts at ``j`` and decreases by one per element.

    Parameters
    ----------
    iter : a reversible sequence (anything accepted by ``reversed``).
    j : counter value paired with the last element of ``iter``.

    Yields
    ------
    (index, element) tuples, e.g. ``reversed_enumerate("abc", 2)`` yields
    ``(2, 'c'), (1, 'b'), (0, 'a')``.
    """
    for steps_back, element in enumerate(reversed(iter)):
        yield j - steps_back, element
176+
177+
def better_mutation(v, order, gsts):
    """
    Try to relocate vertex ``v`` to the highest-scoring slot of ``order``.

    For every slot k in 0..len(order) a total is accumulated in two passes:
    the forward pass adds the score of ``v`` given the other vertices placed
    before slot k plus the scores of those vertices given their own
    predecessors (with ``v`` absent); the backward pass adds the scores of
    the vertices placed at or after slot k given everything before them
    (with ``v`` present). Scores are obtained from ``gsts[w].trace(prefix)``.

    Parameters
    ----------
    v : the vertex to move; must be an element of ``order``.
    order : list of vertices, modified in place when a move is accepted.
    gsts : indexable collection whose entry ``gsts[w]`` offers ``trace(prefix)``.

    Returns
    -------
    True if ``v`` was moved; False if the best slot does not beat the current
    one by at least 1e-6 (``order`` is then left untouched).
    """
    current = order.index(v)
    size = len(order)
    totals = np.zeros(size + 1)

    # Forward pass: `ahead` holds the vertices other than v seen so far.
    ahead = []
    running = 0
    for slot, other in enumerate(order):
        totals[slot] = gsts[v].trace(ahead) + running
        if other == v:
            continue
        running += gsts[other].trace(ahead)
        ahead.append(other)

    # Slot `size` means "v goes last"; it is the initial candidate.
    totals[size] = gsts[v].trace(ahead) + running
    winner = size

    # Backward pass: v is now part of the prefix, and vertices are peeled
    # off from the end of the ordering one at a time.
    ahead.append(v)
    running = 0
    for slot in range(size - 1, -1, -1):
        other = order[slot]
        if other != v:
            ahead.remove(other)
            running += gsts[other].trace(ahead)
        totals[slot] += running
        if totals[slot] > totals[winner]:
            winner = slot

    # Reject moves whose gain is below the numerical tolerance.
    if totals[current] + 1e-6 > totals[winner]:
        return False

    order.remove(v)
    # Removing v shifts every slot after its old position down by one.
    target = winner - 1 if winner > current else winner
    order.insert(target, v)

    return True
0 commit comments