Skip to content

Commit 0a1d81c

Browse files
committed
implement mermaid logic
1 parent cd2e1a3 commit 0a1d81c

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed

pymc/model_graph.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,14 @@ def edges(
429429
for parent in parents
430430
]
431431

432+
def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
433+
"""Get all nodes in the model graph."""
434+
plates = plates or self.get_plates()
435+
nodes = []
436+
for plate in plates:
437+
nodes.extend(plate.variables)
438+
return nodes
439+
432440

433441
def make_graph(
434442
name: str,
@@ -785,3 +793,131 @@ def model_to_graphviz(
785793
if include_dim_lengths
786794
else create_plate_label_without_dim_length,
787795
)
796+
797+
798+
def _build_mermaid_node(node: NodeInfo) -> list[str]:
799+
var = node.var
800+
node_type = node.node_type
801+
if node_type == NodeType.DATA:
802+
return [
803+
f"{var.name}[{var.name} ~ Data]",
804+
f"{var.name}@{{ shape: db }}",
805+
]
806+
elif node_type == NodeType.OBSERVED_RV:
807+
return [
808+
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
809+
f"{var.name}@{{ shape: rounded }}",
810+
f"style {var.name} fill:#757575",
811+
]
812+
813+
elif node_type == NodeType.FREE_RV:
814+
return [
815+
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
816+
f"{var.name}@{{ shape: rounded }}",
817+
]
818+
elif node_type == NodeType.DETERMINISTIC:
819+
return [
820+
f"{var.name}([{var.name} ~ Deterministic])",
821+
f"{var.name}@{{ shape: rect }}",
822+
]
823+
elif node_type == NodeType.POTENTIAL:
824+
return [
825+
f"{var.name}([{var.name} ~ Potential])",
826+
f"{var.name}@{{ shape: diam }}",
827+
f"style {var.name} fill:#f0f0f0",
828+
]
829+
830+
return []
831+
832+
833+
def _build_mermaid_nodes(nodes) -> list[str]:
834+
node_lines = []
835+
for node in nodes:
836+
node_lines.extend(_build_mermaid_node(node))
837+
838+
return node_lines
839+
840+
841+
def _build_mermaid_edges(edges) -> list[str]:
842+
"""Return a list of Mermaid edge definitions."""
843+
edge_lines = []
844+
for child, parent in edges:
845+
child_id = str(child).replace(":", "_")
846+
parent_id = str(parent).replace(":", "_")
847+
edge_lines.append(f"{parent_id} --> {child_id}")
848+
return edge_lines
849+
850+
851+
def _build_mermaid_plates(plates, include_dim_lengths) -> list[str]:
852+
plate_lines = []
853+
for plate in plates:
854+
if not plate.dim_info:
855+
continue
856+
857+
plate_label_func = (
858+
create_plate_label_with_dim_length
859+
if include_dim_lengths
860+
else create_plate_label_without_dim_length
861+
)
862+
plate_label = plate_label_func(plate.dim_info)
863+
plate_name = f'subgraph "{plate_label}"'
864+
plate_lines.append(plate_name)
865+
for var in plate.variables:
866+
plate_lines.append(f" {var.var.name}")
867+
plate_lines.append("end")
868+
869+
return plate_lines
870+
871+
872+
def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool = True) -> str:
873+
"""Produce a Mermaid diagram string from a PyMC model.
874+
875+
Parameters
876+
----------
877+
model : pm.Model
878+
The model to plot. Not required when called from inside a modelcontext.
879+
var_names : iterable of variable names, optional
880+
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
881+
include_dim_lengths : bool
882+
Include the dim lengths in the plate label. Default is True.
883+
884+
Returns
885+
-------
886+
str
887+
Mermaid diagram string representing the model graph.
888+
889+
Examples
890+
--------
891+
Visualize a simple PyMC model
892+
893+
.. code-block:: python
894+
895+
import pymc as pm
896+
897+
with pm.Model() as model:
898+
mu = pm.Normal("mu", mu=0, sigma=1)
899+
sigma = pm.HalfNormal("sigma", sigma=1)
900+
901+
pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])
902+
903+
print(pm.model_to_mermaid(model))
904+
905+
906+
"""
907+
model = pm.modelcontext(model)
908+
graph = ModelGraph(model)
909+
plates = graph.get_plates(var_names=var_names)
910+
edges = graph.edges(var_names=var_names)
911+
nodes = graph.nodes(plates=plates)
912+
913+
return "\n".join(
914+
[
915+
"graph TD",
916+
"%% Nodes:",
917+
*_build_mermaid_nodes(nodes),
918+
"\n%% Edges:",
919+
*_build_mermaid_edges(edges),
920+
"\n%% Plates:",
921+
*_build_mermaid_plates(plates, include_dim_lengths=include_dim_lengths),
922+
]
923+
)

0 commit comments

Comments
 (0)