Skip to content

Commit f4bdc6c

Browse files
authored
Convert model to mermaid diagram
1 parent 3b8500b commit f4bdc6c

File tree

3 files changed

+160
-1
lines changed

3 files changed

+160
-1
lines changed

pymc/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __set_compiler_flags():
6666
)
6767
from pymc.model.core import *
6868
from pymc.model.transform.conditioning import do, observe
69-
from pymc.model_graph import model_to_graphviz, model_to_networkx
69+
from pymc.model_graph import model_to_graphviz, model_to_mermaid, model_to_networkx
7070
from pymc.plots import *
7171
from pymc.printing import *
7272
from pymc.pytensorf import *

pymc/model_graph.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,14 @@ def edges(
429429
for parent in parents
430430
]
431431

432+
def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
433+
"""Get all nodes in the model graph."""
434+
plates = plates or self.get_plates()
435+
nodes = []
436+
for plate in plates:
437+
nodes.extend(plate.variables)
438+
return nodes
439+
432440

433441
def make_graph(
434442
name: str,
@@ -785,3 +793,131 @@ def model_to_graphviz(
785793
if include_dim_lengths
786794
else create_plate_label_without_dim_length,
787795
)
796+
797+
798+
def _build_mermaid_node(node: NodeInfo) -> list[str]:
799+
var = node.var
800+
node_type = node.node_type
801+
if node_type == NodeType.DATA:
802+
return [
803+
f"{var.name}[{var.name} ~ Data]",
804+
f"{var.name}@{{ shape: db }}",
805+
]
806+
elif node_type == NodeType.OBSERVED_RV:
807+
return [
808+
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
809+
f"{var.name}@{{ shape: rounded }}",
810+
f"style {var.name} fill:#757575",
811+
]
812+
813+
elif node_type == NodeType.FREE_RV:
814+
return [
815+
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
816+
f"{var.name}@{{ shape: rounded }}",
817+
]
818+
elif node_type == NodeType.DETERMINISTIC:
819+
return [
820+
f"{var.name}([{var.name} ~ Deterministic])",
821+
f"{var.name}@{{ shape: rect }}",
822+
]
823+
elif node_type == NodeType.POTENTIAL:
824+
return [
825+
f"{var.name}([{var.name} ~ Potential])",
826+
f"{var.name}@{{ shape: diam }}",
827+
f"style {var.name} fill:#f0f0f0",
828+
]
829+
830+
return []
831+
832+
833+
def _build_mermaid_nodes(nodes) -> list[str]:
834+
node_lines = []
835+
for node in nodes:
836+
node_lines.extend(_build_mermaid_node(node))
837+
838+
return node_lines
839+
840+
841+
def _build_mermaid_edges(edges) -> list[str]:
842+
"""Return a list of Mermaid edge definitions."""
843+
edge_lines = []
844+
for child, parent in edges:
845+
child_id = str(child).replace(":", "_")
846+
parent_id = str(parent).replace(":", "_")
847+
edge_lines.append(f"{parent_id} --> {child_id}")
848+
return edge_lines
849+
850+
851+
def _build_mermaid_plates(plates, include_dim_lengths) -> list[str]:
852+
plate_lines = []
853+
for plate in plates:
854+
if not plate.dim_info:
855+
continue
856+
857+
plate_label_func = (
858+
create_plate_label_with_dim_length
859+
if include_dim_lengths
860+
else create_plate_label_without_dim_length
861+
)
862+
plate_label = plate_label_func(plate.dim_info)
863+
plate_name = f'subgraph "{plate_label}"'
864+
plate_lines.append(plate_name)
865+
for var in plate.variables:
866+
plate_lines.append(f" {var.var.name}")
867+
plate_lines.append("end")
868+
869+
return plate_lines
870+
871+
872+
def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool = True) -> str:
873+
"""Produce a Mermaid diagram string from a PyMC model.
874+
875+
Parameters
876+
----------
877+
model : pm.Model
878+
The model to plot. Not required when called from inside a modelcontext.
879+
var_names : iterable of variable names, optional
880+
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
881+
include_dim_lengths : bool
882+
Include the dim lengths in the plate label. Default is True.
883+
884+
Returns
885+
-------
886+
str
887+
Mermaid diagram string representing the model graph.
888+
889+
Examples
890+
--------
891+
Visualize a simple PyMC model
892+
893+
.. code-block:: python
894+
895+
import pymc as pm
896+
897+
with pm.Model() as model:
898+
mu = pm.Normal("mu", mu=0, sigma=1)
899+
sigma = pm.HalfNormal("sigma", sigma=1)
900+
901+
pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])
902+
903+
print(pm.model_to_mermaid(model))
904+
905+
906+
"""
907+
model = pm.modelcontext(model)
908+
graph = ModelGraph(model)
909+
plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info))
910+
edges = sorted(graph.edges(var_names=var_names))
911+
nodes = sorted(graph.nodes(plates=plates), key=lambda node: cast(str, node.var.name))
912+
913+
return "\n".join(
914+
[
915+
"graph TD",
916+
"%% Nodes:",
917+
*_build_mermaid_nodes(nodes),
918+
"\n%% Edges:",
919+
*_build_mermaid_edges(edges),
920+
"\n%% Plates:",
921+
*_build_mermaid_plates(plates, include_dim_lengths=include_dim_lengths),
922+
]
923+
)

tests/test_model_graph.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
import warnings
1515

16+
from textwrap import dedent
17+
1618
import numpy as np
1719
import pytensor
1820
import pytensor.tensor as pt
@@ -31,6 +33,7 @@
3133
NodeType,
3234
Plate,
3335
model_to_graphviz,
36+
model_to_mermaid,
3437
model_to_networkx,
3538
)
3639

@@ -629,3 +632,23 @@ def test_scalars_dim_info() -> None:
629632
]
630633

631634
assert graph.edges() == []
635+
636+
637+
def test_model_to_mermaid(simple_model):
638+
expected_mermaid_string = dedent("""
639+
graph TD
640+
%% Nodes:
641+
a([a ~ Normal])
642+
a@{ shape: rounded }
643+
b([b ~ Normal])
644+
b@{ shape: rounded }
645+
c([c ~ Normal])
646+
c@{ shape: rounded }
647+
648+
%% Edges:
649+
a --> b
650+
b --> c
651+
652+
%% Plates:
653+
""")
654+
assert model_to_mermaid(simple_model) == expected_mermaid_string.strip()

0 commit comments

Comments
 (0)