Skip to content

Commit 21ccd97

Browse files
Add labels parameter for to_pydot function
1 parent f57f711 commit 21ccd97

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

causallearn/utils/GraphUtils.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -509,16 +509,36 @@ def get_g_arrow_type(endpoint):
509509
return graphviz_g
510510

511511
@staticmethod
512-
def to_pydot(G: Graph, edges: List[Edge] | None = None, title: str = "", dpi: int = 200):
512+
def to_pydot(G: Graph, edges: List[Edge] | None = None, labels: List[str] | None = None, title: str = "", dpi: float = 200):
513+
'''
514+
Convert a graph object to a DOT object.
515+
516+
Parameters
517+
----------
518+
G : Graph
519+
A graph object of causal-learn
520+
edges : list, optional (default=None)
521+
Edges list of graph G
522+
labels : list, optional (default=None)
523+
Nodes labels of graph G
524+
title : str, optional (default="")
525+
The name of graph G
526+
dpi : float, optional (default=200)
527+
The dots per inch of dot object
528+
Returns
529+
-------
530+
pydot_g : Dot
531+
'''
513532
pydot_g = pydot.Dot(title, graph_type="digraph", fontsize=18)
514533
pydot_g.obj_dict["attributes"]["dpi"] = dpi
515534
nodes = G.get_nodes()
516535
for i, node in enumerate(nodes):
536+
node_name = labels[i] if labels is not None and len(labels) > i else node.get_name()
517537
pydot_g.add_node(pydot.Node(i, label=node.get_name()))
518538
if node.get_node_type() == NodeType.LATENT:
519-
pydot_g.add_node(pydot.Node(i, label=node.get_name(), shape='square'))
539+
pydot_g.add_node(pydot.Node(i, label=node_name, shape='square'))
520540
else:
521-
pydot_g.add_node(pydot.Node(i, label=node.get_name()))
541+
pydot_g.add_node(pydot.Node(i, label=node_name))
522542

523543
def get_g_arrow_type(endpoint):
524544
if endpoint == Endpoint.TAIL:

tests/TestGraphVisualization.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,21 @@ def test_color(self):
102102
edges = [edge]
103103
pyd = GraphUtils.to_pydot(pag, edges)
104104
pyd.write_png('green.png')
105+
106+
def test_plot_with_labels(self):
107+
nodes = []
108+
for i in range(3):
109+
nodes.append(GraphNode(f"X{i + 1}"))
110+
dag1 = Dag(nodes)
111+
112+
dag1.add_directed_edge(nodes[0], nodes[1])
113+
dag1.add_directed_edge(nodes[0], nodes[2])
114+
dag1.add_directed_edge(nodes[1], nodes[2])
115+
116+
pyd = GraphUtils.to_pydot(dag1, labels=["A", "B", "C"])
117+
tmp_png = pyd.create_png(f="png")
118+
fp = io.BytesIO(tmp_png)
119+
img = mpimg.imread(fp, format='png')
120+
plt.axis('off')
121+
plt.imshow(img)
122+
plt.show()

0 commit comments

Comments
 (0)