From 78fdd17222eccd1504e51ee2a2de673fb9ea82e7 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 30 Mar 2025 15:16:49 +0200 Subject: [PATCH 1/3] Add a way to pass in graph level attributes to graphviz --- pymc/model_graph.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index f31a33770..7cdcfd274 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -430,6 +430,9 @@ def edges( ] +GraphAttrMapping = dict[Any, Any] + + def make_graph( name: str, plates: list[Plate], @@ -439,6 +442,7 @@ def make_graph( figsize=None, dpi=300, node_formatters: NodeTypeFormatterMapping | None = None, + graph_attrs: GraphAttrMapping | None = None, create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length, ): """Make graphviz Digraph of PyMC model. @@ -460,6 +464,9 @@ def make_graph( node_formatters = update_node_formatters(node_formatters) graph = graphviz.Digraph(name) + if graph_attrs is not None: + graph.attr(**graph_attrs) + for plate in plates: if plate.dim_info: # must be preceded by 'cluster' to get a box around it @@ -676,6 +683,7 @@ def model_to_graphviz( figsize: tuple[int, int] | None = None, dpi: int = 300, node_formatters: NodeTypeFormatterMapping | None = None, + graph_attrs: GraphAttrMapping | None = None, include_dim_lengths: bool = True, ): """Produce a graphviz Digraph from a PyMC model. @@ -704,6 +712,10 @@ def model_to_graphviz( the size of the saved figure. dpi : int, optional Dots per inch. It only affects the resolution of the saved figure. The default is 300. + graph_attrs : dict, optional + A dictionary of top-level layout attributes for graphviz + Check out graphviz documentation for more information on available attributes + https://graphviz.org/doc/info/attrs.html node_formatters : dict, optional A dictionary mapping node types to functions that return a dictionary of node attributes. Check out graphviz documentation for more information on available @@ -773,6 +785,7 @@ def model_to_graphviz( save=save, figsize=figsize, dpi=dpi, + graph_attrs=graph_attrs, node_formatters=node_formatters, create_plate_label=create_plate_label_with_dim_length if include_dim_lengths From 574a48608ae1421ced539530210b7e1b94617939 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 30 Mar 2025 16:22:45 +0200 Subject: [PATCH 2/3] Update pymc/model_graph.py Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pymc/model_graph.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 7cdcfd274..80c797c1f 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -463,10 +463,7 @@ def make_graph( node_formatters = node_formatters or {} node_formatters = update_node_formatters(node_formatters) - graph = graphviz.Digraph(name) - if graph_attrs is not None: - graph.attr(**graph_attrs) - + graph = graphviz.Digraph(name, graph_attr=graph_attrs) for plate in plates: if plate.dim_info: # must be preceded by 'cluster' to get a box around it From dcf2b76be088be1e51b9a0790958c7a3dd200ec6 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 30 Mar 2025 16:27:19 +0200 Subject: [PATCH 3/3] Simplify code --- pymc/model_graph.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 80c797c1f..3a241948a 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -430,9 +430,6 @@ def edges( ] -GraphAttrMapping = dict[Any, Any] - - def make_graph( name: str, plates: list[Plate], @@ -442,7 +439,7 @@ def make_graph( figsize=None, dpi=300, node_formatters: NodeTypeFormatterMapping | None = None, - graph_attrs: GraphAttrMapping | None = None, + graph_attr: dict[str, Any] | None = None, create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length, ): """Make graphviz Digraph of PyMC model. @@ -463,7 +460,7 @@ def make_graph( node_formatters = node_formatters or {} node_formatters = update_node_formatters(node_formatters) - graph = graphviz.Digraph(name, graph_attr=graph_attrs) + graph = graphviz.Digraph(name, graph_attr=graph_attr) for plate in plates: if plate.dim_info: # must be preceded by 'cluster' to get a box around it @@ -680,7 +677,7 @@ def model_to_graphviz( figsize: tuple[int, int] | None = None, dpi: int = 300, node_formatters: NodeTypeFormatterMapping | None = None, - graph_attrs: GraphAttrMapping | None = None, + graph_attr: dict[str, Any] | None = None, include_dim_lengths: bool = True, ): """Produce a graphviz Digraph from a PyMC model. @@ -709,7 +706,7 @@ def model_to_graphviz( the size of the saved figure. dpi : int, optional Dots per inch. It only affects the resolution of the saved figure. The default is 300. - graph_attrs : dict, optional + graph_attr : dict, optional A dictionary of top-level layout attributes for graphviz Check out graphviz documentation for more information on available attributes https://graphviz.org/doc/info/attrs.html @@ -782,7 +779,7 @@ def model_to_graphviz( save=save, figsize=figsize, dpi=dpi, - graph_attrs=graph_attrs, + graph_attr=graph_attr, node_formatters=node_formatters, create_plate_label=create_plate_label_with_dim_length if include_dim_lengths