@@ -429,6 +429,14 @@ def edges(
429429 for parent in parents
430430 ]
431431
432+ def nodes (self , plates : list [Plate ] | None = None ) -> list [NodeInfo ]:
433+ """Get all nodes in the model graph."""
434+ plates = plates or self .get_plates ()
435+ nodes = []
436+ for plate in plates :
437+ nodes .extend (plate .variables )
438+ return nodes
439+
432440
433441def make_graph (
434442 name : str ,
@@ -785,3 +793,131 @@ def model_to_graphviz(
785793 if include_dim_lengths
786794 else create_plate_label_without_dim_length ,
787795 )
796+
797+
798+ def _build_mermaid_node (node : NodeInfo ) -> list [str ]:
799+ var = node .var
800+ node_type = node .node_type
801+ if node_type == NodeType .DATA :
802+ return [
803+ f"{ var .name } [{ var .name } ~ Data]" ,
804+ f"{ var .name } @{{ shape: db }}" ,
805+ ]
806+ elif node_type == NodeType .OBSERVED_RV :
807+ return [
808+ f"{ var .name } ([{ var .name } ~ { random_variable_symbol (var )} ])" ,
809+ f"{ var .name } @{{ shape: rounded }}" ,
810+ f"style { var .name } fill:#757575" ,
811+ ]
812+
813+ elif node_type == NodeType .FREE_RV :
814+ return [
815+ f"{ var .name } ([{ var .name } ~ { random_variable_symbol (var )} ])" ,
816+ f"{ var .name } @{{ shape: rounded }}" ,
817+ ]
818+ elif node_type == NodeType .DETERMINISTIC :
819+ return [
820+ f"{ var .name } ([{ var .name } ~ Deterministic])" ,
821+ f"{ var .name } @{{ shape: rect }}" ,
822+ ]
823+ elif node_type == NodeType .POTENTIAL :
824+ return [
825+ f"{ var .name } ([{ var .name } ~ Potential])" ,
826+ f"{ var .name } @{{ shape: diam }}" ,
827+ f"style { var .name } fill:#f0f0f0" ,
828+ ]
829+
830+ return []
831+
832+
833+ def _build_mermaid_nodes (nodes ) -> list [str ]:
834+ node_lines = []
835+ for node in nodes :
836+ node_lines .extend (_build_mermaid_node (node ))
837+
838+ return node_lines
839+
840+
841+ def _build_mermaid_edges (edges ) -> list [str ]:
842+ """Return a list of Mermaid edge definitions."""
843+ edge_lines = []
844+ for child , parent in edges :
845+ child_id = str (child ).replace (":" , "_" )
846+ parent_id = str (parent ).replace (":" , "_" )
847+ edge_lines .append (f"{ parent_id } --> { child_id } " )
848+ return edge_lines
849+
850+
851+ def _build_mermaid_plates (plates , include_dim_lengths ) -> list [str ]:
852+ plate_lines = []
853+ for plate in plates :
854+ if not plate .dim_info :
855+ continue
856+
857+ plate_label_func = (
858+ create_plate_label_with_dim_length
859+ if include_dim_lengths
860+ else create_plate_label_without_dim_length
861+ )
862+ plate_label = plate_label_func (plate .dim_info )
863+ plate_name = f'subgraph "{ plate_label } "'
864+ plate_lines .append (plate_name )
865+ for var in plate .variables :
866+ plate_lines .append (f" { var .var .name } " )
867+ plate_lines .append ("end" )
868+
869+ return plate_lines
870+
871+
872+ def model_to_mermaid (model = None , * , var_names = None , include_dim_lengths : bool = True ) -> str :
873+ """Produce a Mermaid diagram string from a PyMC model.
874+
875+ Parameters
876+ ----------
877+ model : pm.Model
878+ The model to plot. Not required when called from inside a modelcontext.
879+ var_names : iterable of variable names, optional
880+ Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
881+ include_dim_lengths : bool
882+ Include the dim lengths in the plate label. Default is True.
883+
884+ Returns
885+ -------
886+ str
887+ Mermaid diagram string representing the model graph.
888+
889+ Examples
890+ --------
891+ Visualize a simple PyMC model
892+
893+ .. code-block:: python
894+
895+ import pymc as pm
896+
897+ with pm.Model() as model:
898+ mu = pm.Normal("mu", mu=0, sigma=1)
899+ sigma = pm.HalfNormal("sigma", sigma=1)
900+
901+ pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])
902+
903+ print(pm.model_to_mermaid(model))
904+
905+
906+ """
907+ model = pm .modelcontext (model )
908+ graph = ModelGraph (model )
909+ plates = sorted (graph .get_plates (var_names = var_names ), key = lambda plate : hash (plate .dim_info ))
910+ edges = sorted (graph .edges (var_names = var_names ))
911+ nodes = sorted (graph .nodes (plates = plates ), key = lambda node : cast (str , node .var .name ))
912+
913+ return "\n " .join (
914+ [
915+ "graph TD" ,
916+ "%% Nodes:" ,
917+ * _build_mermaid_nodes (nodes ),
918+ "\n %% Edges:" ,
919+ * _build_mermaid_edges (edges ),
920+ "\n %% Plates:" ,
921+ * _build_mermaid_plates (plates , include_dim_lengths = include_dim_lengths ),
922+ ]
923+ )
0 commit comments