diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 9ec8a502d..e62e5e244 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -795,36 +795,42 @@ def model_to_graphviz( ) +def _create_mermaid_node_name(name: str) -> str: + return name.replace(":", "_").replace(" ", "_") + + def _build_mermaid_node(node: NodeInfo) -> list[str]: var = node.var node_type = node.node_type + name = cast(str, var.name) + node_name = _create_mermaid_node_name(name) if node_type == NodeType.DATA: return [ - f"{var.name}[{var.name} ~ Data]", - f"{var.name}@{{ shape: db }}", + f"{node_name}[{var.name} ~ Data]", + f"{node_name}@{{ shape: db }}", ] elif node_type == NodeType.OBSERVED_RV: return [ - f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])", - f"{var.name}@{{ shape: rounded }}", - f"style {var.name} fill:#757575", + f"{node_name}([{name} ~ {random_variable_symbol(var)}])", + f"{node_name}@{{ shape: rounded }}", + f"style {node_name} fill:#757575", ] elif node_type == NodeType.FREE_RV: return [ - f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])", - f"{var.name}@{{ shape: rounded }}", + f"{node_name}([{name} ~ {random_variable_symbol(var)}])", + f"{node_name}@{{ shape: rounded }}", ] elif node_type == NodeType.DETERMINISTIC: return [ - f"{var.name}([{var.name} ~ Deterministic])", - f"{var.name}@{{ shape: rect }}", + f"{node_name}([{name} ~ Deterministic])", + f"{node_name}@{{ shape: rect }}", ] elif node_type == NodeType.POTENTIAL: return [ - f"{var.name}([{var.name} ~ Potential])", - f"{var.name}@{{ shape: diam }}", - f"style {var.name} fill:#f0f0f0", + f"{node_name}([{name} ~ Potential])", + f"{node_name}@{{ shape: diam }}", + f"style {node_name} fill:#f0f0f0", ] return [] @@ -842,8 +848,8 @@ def _build_mermaid_edges(edges) -> list[str]: """Return a list of Mermaid edge definitions.""" edge_lines = [] for child, parent in edges: - child_id = str(child).replace(":", "_") - parent_id = str(parent).replace(":", "_") + child_id = _create_mermaid_node_name(child) + parent_id = _create_mermaid_node_name(parent) edge_lines.append(f"{parent_id} --> {child_id}") return edge_lines diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 35a202215..1ea1b694b 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -652,3 +652,20 @@ def test_model_to_mermaid(simple_model): %% Plates: """) assert model_to_mermaid(simple_model) == expected_mermaid_string.strip() + + +def test_model_to_mermaid_with_variable_with_space(): + with pm.Model() as variable_with_space: + pm.Normal("plant growth") + + expected_mermaid_string = dedent(""" + graph TD + %% Nodes: + plant_growth([plant growth ~ Normal]) + plant_growth@{ shape: rounded } + + %% Edges: + + %% Plates: + """) + assert model_to_mermaid(variable_with_space) == expected_mermaid_string.strip()