Skip to content

Commit 7c3a2ab

Browse files
committed
standardize mermaid node names
1 parent 4271195 commit 7c3a2ab

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

pymc/model_graph.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -795,36 +795,42 @@ def model_to_graphviz(
795795
)
796796

797797

798+
def _create_mermaid_node_name(name: str) -> str:
799+
return name.replace(":", "_").replace(" ", "_")
800+
801+
798802
def _build_mermaid_node(node: NodeInfo) -> list[str]:
799803
var = node.var
800804
node_type = node.node_type
805+
name = var.name
806+
node_name = _create_mermaid_node_name(name)
801807
if node_type == NodeType.DATA:
802808
return [
803-
f"{var.name}[{var.name} ~ Data]",
804-
f"{var.name}@{{ shape: db }}",
809+
f"{node_name}[{var.name} ~ Data]",
810+
f"{node_name}@{{ shape: db }}",
805811
]
806812
elif node_type == NodeType.OBSERVED_RV:
807813
return [
808-
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
809-
f"{var.name}@{{ shape: rounded }}",
810-
f"style {var.name} fill:#757575",
814+
f"{node_name}([{name} ~ {random_variable_symbol(var)}])",
815+
f"{node_name}@{{ shape: rounded }}",
816+
f"style {node_name} fill:#757575",
811817
]
812818

813819
elif node_type == NodeType.FREE_RV:
814820
return [
815-
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
816-
f"{var.name}@{{ shape: rounded }}",
821+
f"{node_name}([{name} ~ {random_variable_symbol(var)}])",
822+
f"{node_name}@{{ shape: rounded }}",
817823
]
818824
elif node_type == NodeType.DETERMINISTIC:
819825
return [
820-
f"{var.name}([{var.name} ~ Deterministic])",
821-
f"{var.name}@{{ shape: rect }}",
826+
f"{node_name}([{name} ~ Deterministic])",
827+
f"{node_name}@{{ shape: rect }}",
822828
]
823829
elif node_type == NodeType.POTENTIAL:
824830
return [
825-
f"{var.name}([{var.name} ~ Potential])",
826-
f"{var.name}@{{ shape: diam }}",
827-
f"style {var.name} fill:#f0f0f0",
831+
f"{node_name}([{name} ~ Potential])",
832+
f"{node_name}@{{ shape: diam }}",
833+
f"style {node_name} fill:#f0f0f0",
828834
]
829835

830836
return []
@@ -842,8 +848,8 @@ def _build_mermaid_edges(edges) -> list[str]:
842848
"""Return a list of Mermaid edge definitions."""
843849
edge_lines = []
844850
for child, parent in edges:
845-
child_id = str(child).replace(":", "_")
846-
parent_id = str(parent).replace(":", "_")
851+
child_id = _create_mermaid_node_name(child)
852+
parent_id = _create_mermaid_node_name(parent)
847853
edge_lines.append(f"{parent_id} --> {child_id}")
848854
return edge_lines
849855

0 commit comments

Comments
 (0)