Skip to content

Commit d7ebfcf

Browse files
committed
visualization for GIN
1 parent 9436373 commit d7ebfcf

File tree

3 files changed

+54
-1
lines changed
  • docs/source/search_methods_index
    • Hidden causal representation learning
    • Score-based causal discovery methods
  • tests

3 files changed

+54
-1
lines changed

docs/source/search_methods_index/Hidden causal representation learning/gin.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@ Usage
1515
from causallearn.search.FCMBased.GIN.GIN import GIN
1616
G, K = GIN(data)
1717
18+
# Visualization using pydot
19+
from causallearn.utils.GraphUtils import GraphUtils
20+
import matplotlib.image as mpimg
21+
import matplotlib.pyplot as plt
22+
import io
23+
24+
pyd = GraphUtils.to_pydot(G)
25+
tmp_png = pyd.create_png(f="png")
26+
fp = io.BytesIO(tmp_png)
27+
img = mpimg.imread(fp, format='png')
28+
plt.axis('off')
29+
plt.imshow(img)
30+
plt.show()
31+
32+
Visualization using pydot is recommended (`usage example <https://github.com/cmu-phil/causal-learn/blob/main/tests/TestGIN.py>`_). If specific label names are needed, please refer to this `usage example <https://github.com/cmu-phil/causal-learn/blob/e4e73f8b58510a3cd5a9125ba50c0ac62a425ef3/tests/TestGraphVisualization.py#L106>`_ (e.g., GraphUtils.to_pydot(G, labels=["A", "B", "C"]).
33+
1834
Parameters
1935
-----------------------------------------------------------
2036
**data**: numpy.ndarray, shape (n_samples, n_features). Data, where n_samples is the number of samples

docs/source/search_methods_index/Score-based causal discovery methods/GES.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ Usage
1818
1919
# Visualization using pydot
2020
from causallearn.utils.GraphUtils import GraphUtils
21+
import matplotlib.image as mpimg
22+
import matplotlib.pyplot as plt
23+
import io
24+
2125
pyd = GraphUtils.to_pydot(Record['G'])
2226
tmp_png = pyd.create_png(f="png")
2327
fp = io.BytesIO(tmp_png)

tests/TestGIN.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import random
22
import sys
3+
import io
34

45
sys.path.append("")
56
import unittest
67

78
import numpy as np
9+
import matplotlib.image as mpimg
10+
import matplotlib.pyplot as plt
811

912
from causallearn.search.HiddenCausal.GIN.GIN import GIN
1013

@@ -25,6 +28,16 @@ def test_case1(self):
2528
g, k = GIN(data)
2629
print(g, k)
2730

31+
# Visualization using pydot
32+
from causallearn.utils.GraphUtils import GraphUtils
33+
pyd = GraphUtils.to_pydot(g)
34+
tmp_png = pyd.create_png(f="png")
35+
fp = io.BytesIO(tmp_png)
36+
img = mpimg.imread(fp, format='png')
37+
plt.axis('off')
38+
plt.imshow(img)
39+
plt.show()
40+
2841
def test_case2(self):
2942
sample_size = 1000
3043
np.random.seed(0)
@@ -47,6 +60,16 @@ def test_case2(self):
4760
g, k = GIN(data)
4861
print(g, k)
4962

63+
# Visualization using pydot
64+
from causallearn.utils.GraphUtils import GraphUtils
65+
pyd = GraphUtils.to_pydot(g)
66+
tmp_png = pyd.create_png(f="png")
67+
fp = io.BytesIO(tmp_png)
68+
img = mpimg.imread(fp, format='png')
69+
plt.axis('off')
70+
plt.imshow(img)
71+
plt.show()
72+
5073
def test_case3(self):
5174
sample_size = 1000
5275
random.seed(42)
@@ -68,4 +91,14 @@ def test_case3(self):
6891
data = np.array([X1, X2, X3, X4, X5, X6, X7, X8]).T
6992
data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
7093
g, k = GIN(data)
71-
print(g, k)
94+
print(g, k)
95+
96+
# Visualization using pydot
97+
from causallearn.utils.GraphUtils import GraphUtils
98+
pyd = GraphUtils.to_pydot(g)
99+
tmp_png = pyd.create_png(f="png")
100+
fp = io.BytesIO(tmp_png)
101+
img = mpimg.imread(fp, format='png')
102+
plt.axis('off')
103+
plt.imshow(img)
104+
plt.show()

0 commit comments

Comments
 (0)