Skip to content

Commit 282f20d

Browse files
authored
Merge pull request #212 from py-why/add_boss
Add boss doc by Bryan
2 parents b07c316 + 2933d13 commit 282f20d

File tree

4 files changed

+86
-20
lines changed

4 files changed

+86
-20
lines changed

causallearn/search/PermutationBased/BOSS.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def boss(
5353
if n < p:
5454
warnings.warn("The number of features is much larger than the sample size!")
5555

56-
if score_func == "local_score_CV_general":
56+
if score_func == "local_score_CV_general":
5757
# % k-fold negative cross validated likelihood based on regression in RKHS
5858
if parameters is None:
5959
parameters = {
@@ -63,13 +63,13 @@ def boss(
6363
localScoreClass = LocalScoreClass(
6464
data=X, local_score_fun=local_score_cv_general, parameters=parameters
6565
)
66-
elif score_func == "local_score_marginal_general":
66+
elif score_func == "local_score_marginal_general":
6767
# negative marginal likelihood based on regression in RKHS
6868
parameters = {}
6969
localScoreClass = LocalScoreClass(
7070
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
7171
)
72-
elif score_func == "local_score_CV_multi":
72+
elif score_func == "local_score_CV_multi":
7373
# k-fold negative cross validated likelihood based on regression in RKHS
7474
# for data with multi-variate dimensions
7575
if parameters is None:
@@ -83,7 +83,7 @@ def boss(
8383
localScoreClass = LocalScoreClass(
8484
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
8585
)
86-
elif score_func == "local_score_marginal_multi":
86+
elif score_func == "local_score_marginal_multi":
8787
# negative marginal likelihood based on regression in RKHS
8888
# for data with multi-variate dimensions
8989
if parameters is None:
@@ -93,22 +93,22 @@ def boss(
9393
localScoreClass = LocalScoreClass(
9494
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
9595
)
96-
elif score_func == "local_score_BIC":
96+
elif score_func == "local_score_BIC":
9797
# SEM BIC score
9898
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
9999
if parameters is None:
100100
parameters = {"lambda_value": 2}
101101
localScoreClass = LocalScoreClass(
102102
data=X, local_score_fun=local_score_BIC, parameters=parameters
103103
)
104-
elif score_func == "local_score_BIC_from_cov":
104+
elif score_func == "local_score_BIC_from_cov":
105105
# SEM BIC score
106106
if parameters is None:
107107
parameters = {"lambda_value": 2}
108108
localScoreClass = LocalScoreClass(
109109
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
110110
)
111-
elif score_func == "local_score_BDeu":
111+
elif score_func == "local_score_BDeu":
112112
# BDeu score
113113
localScoreClass = LocalScoreClass(
114114
data=X, local_score_fun=local_score_BDeu, parameters=None
@@ -204,4 +204,4 @@ def better_mutation(v, order, gsts):
204204
order.remove(v)
205205
order.insert(best - int(best > i), v)
206206

207-
return True
207+
return True

causallearn/search/PermutationBased/GRaSP.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
local_score_marginal_general,
1717
local_score_marginal_multi,
1818
)
19-
from causallearn.search.PermutationBased.gst import GST;
19+
from causallearn.search.PermutationBased.gst import GST
2020
from causallearn.score.LocalScoreFunctionClass import LocalScoreClass
2121
from causallearn.utils.DAG2CPDAG import dag2cpdag
2222

@@ -111,7 +111,7 @@ def grasp(
111111
if n < p:
112112
warnings.warn("The number of features is much larger than the sample size!")
113113

114-
if score_func == "local_score_CV_general":
114+
if score_func == "local_score_CV_general":
115115
# k-fold negative cross validated likelihood based on regression in RKHS
116116
if parameters is None:
117117
parameters = {
@@ -127,7 +127,7 @@ def grasp(
127127
localScoreClass = LocalScoreClass(
128128
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
129129
)
130-
elif score_func == "local_score_CV_multi":
130+
elif score_func == "local_score_CV_multi":
131131
# k-fold negative cross validated likelihood based on regression in RKHS
132132
# for data with multi-variate dimensions
133133
if parameters is None:
@@ -141,7 +141,7 @@ def grasp(
141141
localScoreClass = LocalScoreClass(
142142
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
143143
)
144-
elif score_func == "local_score_marginal_multi":
144+
elif score_func == "local_score_marginal_multi":
145145
# negative marginal likelihood based on regression in RKHS
146146
# for data with multi-variate dimensions
147147
if parameters is None:
@@ -151,22 +151,22 @@ def grasp(
151151
localScoreClass = LocalScoreClass(
152152
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
153153
)
154-
elif score_func == "local_score_BIC":
154+
elif score_func == "local_score_BIC":
155155
# SEM BIC score
156156
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
157157
if parameters is None:
158158
parameters = {"lambda_value": 2}
159159
localScoreClass = LocalScoreClass(
160160
data=X, local_score_fun=local_score_BIC, parameters=parameters
161161
)
162-
elif score_func == "local_score_BIC_from_cov":
162+
elif score_func == "local_score_BIC_from_cov":
163163
# SEM BIC score
164164
if parameters is None:
165165
parameters = {"lambda_value": 2}
166166
localScoreClass = LocalScoreClass(
167167
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
168168
)
169-
elif score_func == "local_score_BDeu":
169+
elif score_func == "local_score_BDeu":
170170
# BDeu score
171171
localScoreClass = LocalScoreClass(
172172
data=X, local_score_fun=local_score_BDeu, parameters=None
@@ -204,7 +204,7 @@ def grasp(
204204
sys.stdout.flush()
205205

206206
runtime = time.perf_counter() - runtime
207-
207+
208208
if verbose:
209209
sys.stdout.write("\nGRaSP completed in: %.2fs \n" % runtime)
210210
sys.stdout.flush()

docs/source/search_methods_index/Permutation-based causal discovery methods/GRaSP.rst

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ GRaSP
66
Algorithm Introduction
77
--------------------------------------
88

9-
Greedy relaxation of the sparsest permutation (GRaSP) algorithm [1]_.
9+
Greedy relaxations of the sparsest permutation (GRaSP) algorithm [1]_.
1010

1111

1212
Usage
@@ -19,7 +19,7 @@ Usage
1919
G = grasp(X)
2020
2121
# or customized parameters
22-
G = grasp(X, score_func, depth, maxP, parameters)
22+
G = grasp(X, score_func, depth, parameters)
2323
2424
# Visualization using pydot
2525
from causallearn.utils.GraphUtils import GraphUtils
@@ -50,8 +50,6 @@ and n_features is the number of features.
5050
- ":ref:`local_score_CV_multi <Generalized score with cross validation>`": Generalized score with cross validation for data with multi-dimensional variables [2]_.
5151
- ":ref:`local_score_marginal_multi <Generalized score with marginal likelihood>`": Generalized score with marginal likelihood for data with multi-dimensional variables [2]_.
5252

53-
**maxP**: Allowed maximum number of parents when searching the graph. Default: None.
54-
5553
**parameters**: Needed when using CV likelihood. Default: None.
5654
- parameters['kfold']: k-fold cross validation.
5755
- parameters['lambda']: regularization parameter.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
.. _BOSS:
2+
3+
BOSS
4+
==============================================
5+
6+
Algorithm Introduction
7+
--------------------------------------
8+
9+
Best order score search (BOSS) algorithm [1]_.
10+
11+
12+
Usage
13+
----------------------------
14+
.. code-block:: python
15+
16+
from causallearn.search.PermutationBased.BOSS import boss
17+
18+
# default parameters
19+
G = boss(X)
20+
21+
# or customized parameters
22+
G = boss(X, score_func, parameters)
23+
24+
# Visualization using pydot
25+
from causallearn.utils.GraphUtils import GraphUtils
26+
import matplotlib.image as mpimg
27+
import matplotlib.pyplot as plt
28+
import io
29+
30+
pyd = GraphUtils.to_pydot(G)
31+
tmp_png = pyd.create_png(f="png")
32+
fp = io.BytesIO(tmp_png)
33+
img = mpimg.imread(fp, format='png')
34+
plt.axis('off')
35+
plt.imshow(img)
36+
plt.show()
37+
38+
Visualization using pydot is recommended (`usage example <https://github.com/cmu-phil/causal-learn/blob/main/tests/TestBOSS.py>`_). If specific label names are needed, please refer to this `usage example <https://github.com/cmu-phil/causal-learn/blob/e4e73f8b58510a3cd5a9125ba50c0ac62a425ef3/tests/TestGraphVisualization.py#L106>`_ (e.g., GraphUtils.to_pydot(G, labels=["A", "B", "C"]).
39+
40+
Parameters
41+
-------------------
42+
**X**: numpy.ndarray, shape (n_samples, n_features). Data, where n_samples is the number of samples
43+
and n_features is the number of features.
44+
45+
**score_func**: The score function you would like to use, including (see :ref:`score_functions`.). Default: 'local_score_BIC'.
46+
- ":ref:`local_score_BIC <BIC score>`": BIC score [3]_.
47+
- ":ref:`local_score_BDeu <BDeu score>`": BDeu score [4]_.
48+
- ":ref:`local_score_CV_general <Generalized score with cross validation>`": Generalized score with cross validation for data with single-dimensional variables [2]_.
49+
- ":ref:`local_score_marginal_general <Generalized score with marginal likelihood>`": Generalized score with marginal likelihood for data with single-dimensional variables [2]_.
50+
- ":ref:`local_score_CV_multi <Generalized score with cross validation>`": Generalized score with cross validation for data with multi-dimensional variables [2]_.
51+
- ":ref:`local_score_marginal_multi <Generalized score with marginal likelihood>`": Generalized score with marginal likelihood for data with multi-dimensional variables [2]_.
52+
53+
**parameters**: Needed when using CV likelihood. Default: None.
54+
- parameters['kfold']: k-fold cross validation.
55+
- parameters['lambda']: regularization parameter.
56+
- parameters['dlabel']: for variables with multi-dimensions, indicate which dimensions belong to the i-th variable.
57+
58+
59+
60+
Returns
61+
-------------------
62+
- **G**: learned general graph, where G.graph[j,i]=1 and G.graph[i,j]=-1 indicate i --> j; G.graph[i,j] = G.graph[j,i] = -1 indicates i --- j.
63+
64+
65+
.. [1] Andrews, B., Ramsey, J., Sanchez Romero, R., Camchong, J., & Kummerfeld, E. (2023). Fast scalable and accurate discovery of dags using the best order score search and grow shrink trees. Advances in Neural Information Processing Systems, 36, 63945-63956.
66+
.. [2] Huang, B., Zhang, K., Lin, Y., Schölkopf, B., & Glymour, C. (2018, July). Generalized score functions for causal discovery. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (pp. 1551-1560).
67+
.. [3] Schwarz, G. (1978). Estimating the dimension of a model. The annals of statistics, 461-464.
68+
.. [4] Buntine, W. (1991). Theory refinement on Bayesian networks. In Uncertainty proceedings 1991 (pp. 52-60). Morgan Kaufmann.

0 commit comments

Comments
 (0)