diff --git a/pymc/__init__.py b/pymc/__init__.py index 684feac11..69a29c97e 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -66,7 +66,7 @@ def __set_compiler_flags(): ) from pymc.model.core import * from pymc.model.transform.conditioning import do, observe -from pymc.model_graph import model_to_graphviz, model_to_networkx +from pymc.model_graph import model_to_graphviz, model_to_mermaid, model_to_networkx from pymc.plots import * from pymc.printing import * from pymc.pytensorf import * diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 3a241948a..9ec8a502d 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -429,6 +429,14 @@ def edges( for parent in parents ] + def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]: + """Get all nodes in the model graph.""" + plates = plates or self.get_plates() + nodes = [] + for plate in plates: + nodes.extend(plate.variables) + return nodes + def make_graph( name: str, @@ -785,3 +793,131 @@ def model_to_graphviz( if include_dim_lengths else create_plate_label_without_dim_length, ) + + +def _build_mermaid_node(node: NodeInfo) -> list[str]: + var = node.var + node_type = node.node_type + if node_type == NodeType.DATA: + return [ + f"{var.name}[{var.name} ~ Data]", + f"{var.name}@{{ shape: db }}", + ] + elif node_type == NodeType.OBSERVED_RV: + return [ + f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])", + f"{var.name}@{{ shape: rounded }}", + f"style {var.name} fill:#757575", + ] + + elif node_type == NodeType.FREE_RV: + return [ + f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])", + f"{var.name}@{{ shape: rounded }}", + ] + elif node_type == NodeType.DETERMINISTIC: + return [ + f"{var.name}([{var.name} ~ Deterministic])", + f"{var.name}@{{ shape: rect }}", + ] + elif node_type == NodeType.POTENTIAL: + return [ + f"{var.name}([{var.name} ~ Potential])", + f"{var.name}@{{ shape: diam }}", + f"style {var.name} fill:#f0f0f0", + ] + + return [] + + +def _build_mermaid_nodes(nodes) -> list[str]: + node_lines = [] + for node in nodes: + node_lines.extend(_build_mermaid_node(node)) + + return node_lines + + +def _build_mermaid_edges(edges) -> list[str]: + """Return a list of Mermaid edge definitions.""" + edge_lines = [] + for child, parent in edges: + child_id = str(child).replace(":", "_") + parent_id = str(parent).replace(":", "_") + edge_lines.append(f"{parent_id} --> {child_id}") + return edge_lines + + +def _build_mermaid_plates(plates, include_dim_lengths) -> list[str]: + plate_lines = [] + for plate in plates: + if not plate.dim_info: + continue + + plate_label_func = ( + create_plate_label_with_dim_length + if include_dim_lengths + else create_plate_label_without_dim_length + ) + plate_label = plate_label_func(plate.dim_info) + plate_name = f'subgraph "{plate_label}"' + plate_lines.append(plate_name) + for var in plate.variables: + plate_lines.append(f" {var.var.name}") + plate_lines.append("end") + + return plate_lines + + +def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool = True) -> str: + """Produce a Mermaid diagram string from a PyMC model. + + Parameters + ---------- + model : pm.Model + The model to plot. Not required when called from inside a modelcontext. + var_names : iterable of variable names, optional + Subset of variables to be plotted that identify a subgraph with respect to the entire model graph + include_dim_lengths : bool + Include the dim lengths in the plate label. Default is True. + + Returns + ------- + str + Mermaid diagram string representing the model graph. + + Examples + -------- + Visualize a simple PyMC model + + .. code-block:: python + + import pymc as pm + + with pm.Model() as model: + mu = pm.Normal("mu", mu=0, sigma=1) + sigma = pm.HalfNormal("sigma", sigma=1) + + pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3]) + + print(pm.model_to_mermaid(model)) + + + """ + model = pm.modelcontext(model) + graph = ModelGraph(model) + plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info)) + edges = sorted(graph.edges(var_names=var_names)) + nodes = sorted(graph.nodes(plates=plates), key=lambda node: cast(str, node.var.name)) + + return "\n".join( + [ + "graph TD", + "%% Nodes:", + *_build_mermaid_nodes(nodes), + "\n%% Edges:", + *_build_mermaid_edges(edges), + "\n%% Plates:", + *_build_mermaid_plates(plates, include_dim_lengths=include_dim_lengths), + ] + ) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 70c10d410..35a202215 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -13,6 +13,8 @@ # limitations under the License. import warnings +from textwrap import dedent + import numpy as np import pytensor import pytensor.tensor as pt @@ -31,6 +33,7 @@ NodeType, Plate, model_to_graphviz, + model_to_mermaid, model_to_networkx, ) @@ -629,3 +632,23 @@ def test_scalars_dim_info() -> None: ] assert graph.edges() == [] + + +def test_model_to_mermaid(simple_model): + expected_mermaid_string = dedent(""" + graph TD + %% Nodes: + a([a ~ Normal]) + a@{ shape: rounded } + b([b ~ Normal]) + b@{ shape: rounded } + c([c ~ Normal]) + c@{ shape: rounded } + + %% Edges: + a --> b + b --> c + + %% Plates: + """) + assert model_to_mermaid(simple_model) == expected_mermaid_string.strip()