Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __set_compiler_flags():
)
from pymc.model.core import *
from pymc.model.transform.conditioning import do, observe
from pymc.model_graph import model_to_graphviz, model_to_networkx
from pymc.model_graph import model_to_graphviz, model_to_mermaid, model_to_networkx
from pymc.plots import *
from pymc.printing import *
from pymc.pytensorf import *
Expand Down
136 changes: 136 additions & 0 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,14 @@
for parent in parents
]

def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
"""Get all nodes in the model graph."""
plates = plates or self.get_plates()
nodes = []
for plate in plates:
nodes.extend(plate.variables)
return nodes


def make_graph(
name: str,
Expand Down Expand Up @@ -785,3 +793,131 @@
if include_dim_lengths
else create_plate_label_without_dim_length,
)


def _build_mermaid_node(node: NodeInfo) -> list[str]:
var = node.var
node_type = node.node_type
if node_type == NodeType.DATA:
return [

Check warning on line 802 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L802

Added line #L802 was not covered by tests
f"{var.name}[{var.name} ~ Data]",
f"{var.name}@{{ shape: db }}",
]
elif node_type == NodeType.OBSERVED_RV:
return [

Check warning on line 807 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L807

Added line #L807 was not covered by tests
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
f"{var.name}@{{ shape: rounded }}",
f"style {var.name} fill:#757575",
]

elif node_type == NodeType.FREE_RV:
return [
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
f"{var.name}@{{ shape: rounded }}",
]
elif node_type == NodeType.DETERMINISTIC:
return [

Check warning on line 819 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L818-L819

Added lines #L818 - L819 were not covered by tests
f"{var.name}([{var.name} ~ Deterministic])",
f"{var.name}@{{ shape: rect }}",
]
elif node_type == NodeType.POTENTIAL:
return [

Check warning on line 824 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L823-L824

Added lines #L823 - L824 were not covered by tests
f"{var.name}([{var.name} ~ Potential])",
f"{var.name}@{{ shape: diam }}",
f"style {var.name} fill:#f0f0f0",
]

return []

Check warning on line 830 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L830

Added line #L830 was not covered by tests


def _build_mermaid_nodes(nodes) -> list[str]:
node_lines = []
for node in nodes:
node_lines.extend(_build_mermaid_node(node))

return node_lines


def _build_mermaid_edges(edges) -> list[str]:
"""Return a list of Mermaid edge definitions."""
edge_lines = []
for child, parent in edges:
child_id = str(child).replace(":", "_")
parent_id = str(parent).replace(":", "_")
edge_lines.append(f"{parent_id} --> {child_id}")
return edge_lines


def _build_mermaid_plates(plates, include_dim_lengths) -> list[str]:
plate_lines = []
for plate in plates:
if not plate.dim_info:
continue

plate_label_func = (

Check warning on line 857 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L857

Added line #L857 was not covered by tests
create_plate_label_with_dim_length
if include_dim_lengths
else create_plate_label_without_dim_length
)
plate_label = plate_label_func(plate.dim_info)
plate_name = f'subgraph "{plate_label}"'
plate_lines.append(plate_name)
for var in plate.variables:
plate_lines.append(f" {var.var.name}")
plate_lines.append("end")

Check warning on line 867 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L862-L867

Added lines #L862 - L867 were not covered by tests

return plate_lines


def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool = True) -> str:
"""Produce a Mermaid diagram string from a PyMC model.

Parameters
----------
model : pm.Model
The model to plot. Not required when called from inside a modelcontext.
var_names : iterable of variable names, optional
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
include_dim_lengths : bool
Include the dim lengths in the plate label. Default is True.

Returns
-------
str
Mermaid diagram string representing the model graph.

Examples
--------
Visualize a simple PyMC model

.. code-block:: python

import pymc as pm

with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
sigma = pm.HalfNormal("sigma", sigma=1)

pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])

print(pm.model_to_mermaid(model))


"""
model = pm.modelcontext(model)
graph = ModelGraph(model)
plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info))
edges = sorted(graph.edges(var_names=var_names))
nodes = sorted(graph.nodes(plates=plates), key=lambda node: cast(str, node.var.name))

return "\n".join(
[
"graph TD",
"%% Nodes:",
*_build_mermaid_nodes(nodes),
"\n%% Edges:",
*_build_mermaid_edges(edges),
"\n%% Plates:",
*_build_mermaid_plates(plates, include_dim_lengths=include_dim_lengths),
]
)
23 changes: 23 additions & 0 deletions tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
import warnings

from textwrap import dedent

import numpy as np
import pytensor
import pytensor.tensor as pt
Expand All @@ -31,6 +33,7 @@
NodeType,
Plate,
model_to_graphviz,
model_to_mermaid,
model_to_networkx,
)

Expand Down Expand Up @@ -629,3 +632,23 @@ def test_scalars_dim_info() -> None:
]

assert graph.edges() == []


def test_model_to_mermaid(simple_model):
expected_mermaid_string = dedent("""
graph TD
%% Nodes:
a([a ~ Normal])
a@{ shape: rounded }
b([b ~ Normal])
b@{ shape: rounded }
c([c ~ Normal])
c@{ shape: rounded }

%% Edges:
a --> b
b --> c

%% Plates:
""")
assert model_to_mermaid(simple_model) == expected_mermaid_string.strip()