|
| 1 | +from pytensor.d3viz.formatting import PyDotFormatter |
| 2 | + |
| 3 | + |
| 4 | +def function_to_mermaid(fn): |
| 5 | + formatter = PyDotFormatter() |
| 6 | + dot = formatter(fn) |
| 7 | + |
| 8 | + nodes = dot.get_nodes() |
| 9 | + edges = dot.get_edges() |
| 10 | + |
| 11 | + mermaid_lines = ["graph TD"] |
| 12 | + mermaid_lines.append("%% Nodes:") |
| 13 | + for node in nodes: |
| 14 | + name = node.get_name() |
| 15 | + label = node.get_label() |
| 16 | + shape = node.get_shape() |
| 17 | + |
| 18 | + if label.endswith("."): |
| 19 | + label = f"{label}0" |
| 20 | + |
| 21 | + if shape == "box": |
| 22 | + shape = "rect" |
| 23 | + else: |
| 24 | + shape = "rounded" |
| 25 | + |
| 26 | + mermaid_lines.extend( |
| 27 | + [ |
| 28 | + f'{name}["{label}"]', |
| 29 | + f"{name}@{{ shape: {shape} }}", |
| 30 | + ] |
| 31 | + ) |
| 32 | + |
| 33 | + fillcolor = node.get_fillcolor() |
| 34 | + if fillcolor is not None and not fillcolor.startswith("#"): |
| 35 | + fillcolor = _color_to_hex(fillcolor) |
| 36 | + mermaid_lines.append(f"style {name} fill:{fillcolor}") |
| 37 | + |
| 38 | + mermaid_lines.append("%% Edges:") |
| 39 | + for edge in edges: |
| 40 | + source = edge.get_source() |
| 41 | + target = edge.get_destination() |
| 42 | + |
| 43 | + mermaid_lines.append(f"{source} --> {target}") |
| 44 | + |
| 45 | + return "\n".join(mermaid_lines) |
| 46 | + |
| 47 | + |
| 48 | +def _color_to_hex(color_name): |
| 49 | + """Based on the colors in d3viz module.""" |
| 50 | + return { |
| 51 | + "limegreen": "#32CD32", |
| 52 | + "SpringGreen": "#00FF7F", |
| 53 | + "YellowGreen": "#9ACD32", |
| 54 | + "dodgerblue": "#1E90FF", |
| 55 | + "lightgrey": "#D3D3D3", |
| 56 | + "yellow": "#FFFF00", |
| 57 | + "cyan": "#00FFFF", |
| 58 | + "magenta": "#FF00FF", |
| 59 | + "red": "#FF0000", |
| 60 | + "blue": "#0000FF", |
| 61 | + "green": "#008000", |
| 62 | + "grey": "#808080", |
| 63 | + }.get(color_name) |
0 commit comments