From 7c3a2abb4fc53da777b7ff6b899e48e99978f444 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 24 Jun 2025 17:43:46 -0400 Subject: [PATCH 1/4] standardize mermaid node names --- pymc/model_graph.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 9ec8a502d..291e14582 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -795,36 +795,42 @@ def model_to_graphviz( ) +def _create_mermaid_node_name(name: str) -> str: + return name.replace(":", "_").replace(" ", "_") + + def _build_mermaid_node(node: NodeInfo) -> list[str]: var = node.var node_type = node.node_type + name = var.name + node_name = _create_mermaid_node_name(name) if node_type == NodeType.DATA: return [ - f"{var.name}[{var.name} ~ Data]", - f"{var.name}@{{ shape: db }}", + f"{node_name}[{var.name} ~ Data]", + f"{node_name}@{{ shape: db }}", ] elif node_type == NodeType.OBSERVED_RV: return [ - f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])", - f"{var.name}@{{ shape: rounded }}", - f"style {var.name} fill:#757575", + f"{node_name}([{name} ~ {random_variable_symbol(var)}])", + f"{node_name}@{{ shape: rounded }}", + f"style {node_name} fill:#757575", ] elif node_type == NodeType.FREE_RV: return [ - f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])", - f"{var.name}@{{ shape: rounded }}", + f"{node_name}([{name} ~ {random_variable_symbol(var)}])", + f"{node_name}@{{ shape: rounded }}", ] elif node_type == NodeType.DETERMINISTIC: return [ - f"{var.name}([{var.name} ~ Deterministic])", - f"{var.name}@{{ shape: rect }}", + f"{node_name}([{name} ~ Deterministic])", + f"{node_name}@{{ shape: rect }}", ] elif node_type == NodeType.POTENTIAL: return [ - f"{var.name}([{var.name} ~ Potential])", - f"{var.name}@{{ shape: diam }}", - f"style {var.name} fill:#f0f0f0", + f"{node_name}([{name} ~ Potential])", + f"{node_name}@{{ shape: diam }}", + f"style {node_name} fill:#f0f0f0", ] return [] @@ -842,8 +848,8 @@ def _build_mermaid_edges(edges) -> list[str]: """Return a list of Mermaid edge definitions.""" edge_lines = [] for child, parent in edges: - child_id = str(child).replace(":", "_") - parent_id = str(parent).replace(":", "_") + child_id = _create_mermaid_node_name(child) + parent_id = _create_mermaid_node_name(parent) edge_lines.append(f"{parent_id} --> {child_id}") return edge_lines From 9090e1f4771b663c188e2d0a79fe196daadc1038 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 24 Jun 2025 17:44:11 -0400 Subject: [PATCH 2/4] add test with space in variable name --- tests/test_model_graph.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 35a202215..6ffbce5b3 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -536,6 +536,14 @@ def simple_model() -> pm.Model: return model +@pytest.fixture(scope="module") +def variable_with_space(): + with pm.Model() as model: + pm.Normal("plant growth") + + return model + + def test_unknown_node_type(simple_model): with pytest.raises(ValueError, match="Node formatters must be of type NodeType."): model_to_graphviz(simple_model, node_formatters={"Unknown Node Type": "dummy"}) @@ -652,3 +660,17 @@ def test_model_to_mermaid(simple_model): %% Plates: """) assert model_to_mermaid(simple_model) == expected_mermaid_string.strip() + + +def test_model_to_mermaid_with_variable_with_space(variable_with_space): + expected_mermaid_string = dedent(""" + graph TD + %% Nodes: + plant_growth([plant growth ~ Normal]) + plant_growth@{ shape: rounded } + + %% Edges: + + %% Plates: + """) + assert model_to_mermaid(variable_with_space) == expected_mermaid_string.strip() From d51d4d007e07ae96f6584c9ccc9f707ba30a4afd Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 25 Jun 2025 06:55:00 -0400 Subject: [PATCH 3/4] remove use of fixture --- tests/test_model_graph.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 6ffbce5b3..1ea1b694b 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -536,14 +536,6 @@ def simple_model() -> pm.Model: return model -@pytest.fixture(scope="module") -def variable_with_space(): - with pm.Model() as model: - pm.Normal("plant growth") - - return model - - def test_unknown_node_type(simple_model): with pytest.raises(ValueError, match="Node formatters must be of type NodeType."): model_to_graphviz(simple_model, node_formatters={"Unknown Node Type": "dummy"}) @@ -662,7 +654,10 @@ def test_model_to_mermaid(simple_model): assert model_to_mermaid(simple_model) == expected_mermaid_string.strip() -def test_model_to_mermaid_with_variable_with_space(variable_with_space): +def test_model_to_mermaid_with_variable_with_space(): + with pm.Model() as variable_with_space: + pm.Normal("plant growth") + expected_mermaid_string = dedent(""" graph TD %% Nodes: From a7bbe34586bd957209cba4f01e717f54affb6cec Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 25 Jun 2025 06:57:02 -0400 Subject: [PATCH 4/4] use the "when in doubt, cast" strategy --- pymc/model_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 291e14582..e62e5e244 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -802,7 +802,7 @@ def _create_mermaid_node_name(name: str) -> str: def _build_mermaid_node(node: NodeInfo) -> list[str]: var = node.var node_type = node.node_type - name = var.name + name = cast(str, var.name) node_name = _create_mermaid_node_name(name) if node_type == NodeType.DATA: return [