Skip to content

Commit 42dc8c4

Browse files
Update TestGraphVisualization.py
1 parent 19cd058 commit 42dc8c4

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

tests/TestGraphVisualization.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,35 @@
33
sys.path.append("")
44

55
import unittest
6+
import numpy as np
7+
import io
8+
69
from causallearn.graph.Dag import Dag
710
from causallearn.graph.GraphNode import GraphNode
811
from causallearn.utils.GraphUtils import GraphUtils
912
from causallearn.utils.DAG2PAG import dag2pag
13+
from causallearn.search.ConstraintBased.PC import pc
14+
from causallearn.utils.cit import fisherz
15+
import matplotlib.image as mpimg
16+
import matplotlib.pyplot as plt
1017

1118
class testGraphVisualization(unittest.TestCase):
1219

13-
def test_draw_dag(self):
20+
def test_draw_CPDAG(self):
21+
data_path = "data_linear_10.txt"
22+
data = np.loadtxt(data_path, skiprows=1) # Import the file at data_path as data
23+
cg = pc(data, 0.05, fisherz, True, 0,
24+
-1) # Run PC and obtain the estimated graph (CausalGraph object)
25+
pyd = GraphUtils.to_pydot(cg.G)
26+
tmp_png = pyd.create_png(f="png")
27+
fp = io.BytesIO(tmp_png)
28+
img = mpimg.imread(fp, format='png')
29+
plt.axis('off')
30+
plt.imshow(img)
31+
plt.show()
32+
print('finish')
33+
34+
def test_draw_DAG(self):
1435
nodes = []
1536
for i in range(3):
1637
nodes.append(GraphNode(f"X{i + 1}"))
@@ -20,10 +41,10 @@ def test_draw_dag(self):
2041
dag1.add_directed_edge(nodes[0], nodes[2])
2142
dag1.add_directed_edge(nodes[1], nodes[2])
2243

23-
pgv_g = GraphUtils.to_pgv(dag1)
24-
pgv_g.draw('dag.png', prog='dot', format='png')
44+
pyd = GraphUtils.to_pydot(dag1)
45+
pyd.write_png('dag.png')
2546

26-
def test_draw_pag(self):
47+
def test_draw_PAG(self):
2748
nodes = []
2849
for i in range(5):
2950
nodes.append(GraphNode(str(i)))
@@ -35,6 +56,6 @@ def test_draw_pag(self):
3556
dag.add_directed_edge(nodes[3], nodes[4])
3657
pag = dag2pag(dag, [nodes[0], nodes[2]])
3758

38-
pgv_g = GraphUtils.to_pgv(pag)
39-
pgv_g.draw('pag.png', prog='dot', format='png')
59+
pyd = GraphUtils.to_pydot(pag)
60+
pyd.write_png('pag.png')
4061

0 commit comments

Comments
 (0)