Skip to content

Commit 7260db0

Browse files
committed
add cg.draw_pydot_graph
1 parent 42dc8c4 commit 7260db0

File tree

3 files changed

+118
-27
lines changed

3 files changed

+118
-27
lines changed

causallearn/graph/GraphClass.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import warnings
22
from itertools import permutations
33

4+
import io
5+
import matplotlib.image as mpimg
46
import matplotlib.pyplot as plt
57
import networkx as nx
68
import numpy as np
@@ -10,6 +12,7 @@
1012
from causallearn.graph.Endpoint import Endpoint
1113
from causallearn.graph.GeneralGraph import GeneralGraph
1214
from causallearn.graph.GraphNode import GraphNode
15+
from causallearn.utils.GraphUtils import GraphUtils
1316
from causallearn.utils.PCUtils.Helper import list_union, powerset
1417

1518

@@ -184,3 +187,16 @@ def draw_nx_graph(self, skel=False):
184187
nx.draw(g_to_be_drawn, pos=pos, with_labels=True, labels=self.labels, edge_color=colors)
185188
plt.draw()
186189
plt.show()
190+
191+
def draw_pydot_graph(self):
192+
"""Draw nx_graph if skel = False and draw nx_skel otherwise"""
193+
warnings.filterwarnings("ignore", category=UserWarning)
194+
pyd = GraphUtils.to_pydot(self.G)
195+
tmp_png = pyd.create_png(f="png")
196+
fp = io.BytesIO(tmp_png)
197+
img = mpimg.imread(fp, format='png')
198+
plt.axis('off')
199+
plt.imshow(img)
200+
plt.show()
201+
202+

docs/source/search_methods_index/Constrained-based causal discovery methods/PC.rst

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,13 @@ Usage
1818
1919
from causallearn.search.ConstraintBased.PC import pc
2020
G = pc(data, alpha, indep_test, stable, uc_rule, uc_priority, mvpc, correction_name, background_knowledge)
21-
G.to_nx_graph()
22-
G.draw_nx_graph(skel=False)
21+
22+
# visualization using pydot
23+
cg.draw_pydot_graph()
24+
25+
# visualization using networkx
26+
# cg.to_nx_graph()
27+
# cg.draw_nx_graph(skel=False)
2328
2429
Parameters
2530
-------------------
@@ -59,7 +64,7 @@ For detailed usage, please kindly refer to its `usage example <https://github.co
5964

6065
Returns
6166
-------------------
62-
**cg** : a CausalGraph object. Nodes in the graph correspond to the column indices in the data. For visualization by networkx, green edges are undirected, blue edges are directed and red edges are bi-directed.
67+
**cg** : a CausalGraph object. Nodes in the graph correspond to the column indices in the data.
6368

6469
.. [1] Spirtes, P., Glymour, C. N., Scheines, R., & Heckerman, D. (2000). Causation, prediction, and search. MIT press.
6570
.. [2] Tu, R., Zhang, C., Ackermann, P., Mohan, K., Kjellström, H., & Zhang, K. (2019, April). Causal discovery in the presence of missing data. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 1762-1770). PMLR.

tests/TestPC.py

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@ def test_pc_with_fisher_z(self):
1515
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
1616
cg = pc(data, 0.05, fisherz, True, 0,
1717
-1) # Run PC and obtain the estimated graph (CausalGraph object)
18-
cg.to_nx_graph()
19-
cg.draw_nx_graph(skel=False)
18+
19+
# visualization using pydot
20+
cg.draw_pydot_graph()
21+
22+
# visualization using networkx
23+
# cg.to_nx_graph()
24+
# cg.draw_nx_graph(skel=False)
2025

2126
print('finish')
2227

@@ -26,8 +31,14 @@ def test_pc_with_g_sq(self):
2631
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
2732
cg = pc(data, 0.05, gsq, True, 0,
2833
-1) # Run PC and obtain the estimated graph (CausalGraph object)
29-
cg.to_nx_graph()
30-
cg.draw_nx_graph(skel=False)
34+
35+
# visualization using pydot
36+
cg.draw_pydot_graph()
37+
38+
# visualization using networkx
39+
# cg.to_nx_graph()
40+
# cg.draw_nx_graph(skel=False)
41+
3142
print('finish')
3243

3344
# example3
@@ -36,8 +47,14 @@ def test_pc_with_chi_sq(self):
3647
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
3748
cg = pc(data, 0.05, chisq, True, 0,
3849
-1) # Run PC and obtain the estimated graph (CausalGraph object)
39-
cg.to_nx_graph()
40-
cg.draw_nx_graph(skel=False)
50+
51+
# visualization using pydot
52+
cg.draw_pydot_graph()
53+
54+
# visualization using networkx
55+
# cg.to_nx_graph()
56+
# cg.draw_nx_graph(skel=False)
57+
4158
print('finish')
4259

4360
# example4
@@ -46,8 +63,14 @@ def test_pc_with_fisher_z_maxp(self):
4663
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
4764
cg = pc(data, 0.05, fisherz, True, 1,
4865
-1) # Run PC and obtain the estimated graph (CausalGraph object)
49-
cg.to_nx_graph()
50-
cg.draw_nx_graph(skel=False)
66+
67+
# visualization using pydot
68+
cg.draw_pydot_graph()
69+
70+
# visualization using networkx
71+
# cg.to_nx_graph()
72+
# cg.draw_nx_graph(skel=False)
73+
5174
print('finish')
5275

5376
# example5
@@ -56,8 +79,13 @@ def test_pc_with_fisher_z_definite_maxp(self):
5679
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
5780
cg = pc(data, 0.05, fisherz, True, 2,
5881
-1) # Run PC and obtain the estimated graph (CausalGraph object)
59-
cg.to_nx_graph()
60-
cg.draw_nx_graph(skel=False)
82+
83+
# visualization using pydot
84+
cg.draw_pydot_graph()
85+
86+
# visualization using networkx
87+
# cg.to_nx_graph()
88+
# cg.draw_nx_graph(skel=False)
6189

6290
print('finish')
6391

@@ -67,8 +95,14 @@ def test_pc_with_fisher_z_with_uc_priority0(self):
6795
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
6896
cg = pc(data, 0.05, fisherz, True, 0,
6997
0) # Run PC and obtain the estimated graph (CausalGraph object)
70-
cg.to_nx_graph()
71-
cg.draw_nx_graph(skel=False)
98+
99+
# visualization using pydot
100+
cg.draw_pydot_graph()
101+
102+
# visualization using networkx
103+
# cg.to_nx_graph()
104+
# cg.draw_nx_graph(skel=False)
105+
72106
print('finish')
73107

74108
# example7
@@ -77,8 +111,14 @@ def test_pc_with_fisher_z_with_uc_priority1(self):
77111
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
78112
cg = pc(data, 0.05, fisherz, True, 0,
79113
1) # Run PC and obtain the estimated graph (CausalGraph object)
80-
cg.to_nx_graph()
81-
cg.draw_nx_graph(skel=False)
114+
115+
# visualization using pydot
116+
cg.draw_pydot_graph()
117+
118+
# visualization using networkx
119+
# cg.to_nx_graph()
120+
# cg.draw_nx_graph(skel=False)
121+
82122
print('finish')
83123

84124
# example8
@@ -87,8 +127,14 @@ def test_pc_with_fisher_z_with_uc_priority2(self):
87127
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
88128
cg = pc(data, 0.05, fisherz, True, 0,
89129
2) # Run PC and obtain the estimated graph (CausalGraph object)
90-
cg.to_nx_graph()
91-
cg.draw_nx_graph(skel=False)
130+
131+
# visualization using pydot
132+
cg.draw_pydot_graph()
133+
134+
# visualization using networkx
135+
# cg.to_nx_graph()
136+
# cg.draw_nx_graph(skel=False)
137+
92138
print('finish')
93139

94140
# example9
@@ -97,8 +143,14 @@ def test_pc_with_fisher_z_with_uc_priority3(self):
97143
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
98144
cg = pc(data, 0.05, fisherz, True, 0,
99145
3) # Run PC and obtain the estimated graph (CausalGraph object)
100-
cg.to_nx_graph()
101-
cg.draw_nx_graph(skel=False)
146+
147+
# visualization using pydot
148+
cg.draw_pydot_graph()
149+
150+
# visualization using networkx
151+
# cg.to_nx_graph()
152+
# cg.draw_nx_graph(skel=False)
153+
102154
print('finish')
103155

104156
# example10
@@ -107,8 +159,14 @@ def test_pc_with_fisher_z_with_uc_priority4(self):
107159
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
108160
cg = pc(data, 0.05, fisherz, True, 0,
109161
4) # Run PC and obtain the estimated graph (CausalGraph object)
110-
cg.to_nx_graph()
111-
cg.draw_nx_graph(skel=False)
162+
163+
# visualization using pydot
164+
cg.draw_pydot_graph()
165+
166+
# visualization using networkx
167+
# cg.to_nx_graph()
168+
# cg.draw_nx_graph(skel=False)
169+
112170
print('finish')
113171

114172
# example11
@@ -118,8 +176,14 @@ def test_pc_with_mv_fisher_z_with_uc_priority4(self):
118176

119177
cg = pc(data, 0.05, mv_fisherz, True, 0,
120178
4) # Run PC and obtain the estimated graph (CausalGraph object)
121-
cg.to_nx_graph()
122-
cg.draw_nx_graph(skel=False)
179+
180+
# visualization using pydot
181+
cg.draw_pydot_graph()
182+
183+
# visualization using networkx
184+
# cg.to_nx_graph()
185+
# cg.draw_nx_graph(skel=False)
186+
123187
print('finish')
124188

125189
# example12
@@ -128,6 +192,12 @@ def test_pc_with_kci(self):
128192
data = np.loadtxt(data_path, skiprows=1)[:50, :] # Import the file at data_path as data
129193
cg = pc(data, 0.05, kci, True, 0,
130194
-1) # Run PC and obtain the estimated graph (CausalGraph object)
131-
cg.to_nx_graph()
132-
cg.draw_nx_graph(skel=False)
195+
196+
# visualization using pydot
197+
cg.draw_pydot_graph()
198+
199+
# visualization using networkx
200+
# cg.to_nx_graph()
201+
# cg.draw_nx_graph(skel=False)
202+
133203
print('finish')

0 commit comments

Comments
 (0)