Skip to content

Commit 31acc6f

Browse files
committed
add function_to_mermaid
1 parent d3bbc20 commit 31acc6f

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

pytensor/mermaid.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from pytensor.d3viz.formatting import PyDotFormatter
2+
3+
4+
def function_to_mermaid(fn):
5+
formatter = PyDotFormatter()
6+
dot = formatter(fn)
7+
8+
nodes = dot.get_nodes()
9+
edges = dot.get_edges()
10+
11+
mermaid_lines = ["graph TD"]
12+
mermaid_lines.append("%% Nodes:")
13+
for node in nodes:
14+
name = node.get_name()
15+
label = node.get_label()
16+
shape = node.get_shape()
17+
18+
if label.endswith("."):
19+
label = f"{label}0"
20+
21+
if shape == "box":
22+
shape = "rect"
23+
else:
24+
shape = "rounded"
25+
26+
mermaid_lines.extend(
27+
[
28+
f'{name}["{label}"]',
29+
f"{name}@{{ shape: {shape} }}",
30+
]
31+
)
32+
33+
fillcolor = node.get_fillcolor()
34+
if fillcolor is not None and not fillcolor.startswith("#"):
35+
fillcolor = _color_to_hex(fillcolor)
36+
mermaid_lines.append(f"style {name} fill:{fillcolor}")
37+
38+
mermaid_lines.append("%% Edges:")
39+
for edge in edges:
40+
source = edge.get_source()
41+
target = edge.get_destination()
42+
43+
mermaid_lines.append(f"{source} --> {target}")
44+
45+
return "\n".join(mermaid_lines)
46+
47+
48+
def _color_to_hex(color_name):
49+
"""Based on the colors in d3viz module."""
50+
return {
51+
"limegreen": "#32CD32",
52+
"SpringGreen": "#00FF7F",
53+
"YellowGreen": "#9ACD32",
54+
"dodgerblue": "#1E90FF",
55+
"lightgrey": "#D3D3D3",
56+
"yellow": "#FFFF00",
57+
"cyan": "#00FFFF",
58+
"magenta": "#FF00FF",
59+
"red": "#FF0000",
60+
"blue": "#0000FF",
61+
"green": "#008000",
62+
"grey": "#808080",
63+
}.get(color_name)

0 commit comments

Comments
 (0)