Skip to content

Commit 2ccf786

Browse files
committed
Add a way to pass in graph level attributes to graphviz
1 parent 5db3779 commit 2ccf786

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

pymc/model_graph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def make_graph(
439439
figsize=None,
440440
dpi=300,
441441
node_formatters: NodeTypeFormatterMapping | None = None,
442+
graph_attrs : dict[str, Any] | None,
442443
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
443444
):
444445
"""Make graphviz Digraph of PyMC model.
@@ -460,6 +461,9 @@ def make_graph(
460461
node_formatters = update_node_formatters(node_formatters)
461462

462463
graph = graphviz.Digraph(name)
464+
if graph_attrs is not None:
465+
graph.attr(**graph_attrs)
466+
463467
for plate in plates:
464468
if plate.dim_info:
465469
# must be preceded by 'cluster' to get a box around it
@@ -676,6 +680,7 @@ def model_to_graphviz(
676680
figsize: tuple[int, int] | None = None,
677681
dpi: int = 300,
678682
node_formatters: NodeTypeFormatterMapping | None = None,
683+
graph_attrs: dict[str, Any] | None = None,
679684
include_dim_lengths: bool = True,
680685
):
681686
"""Produce a graphviz Digraph from a PyMC model.
@@ -704,6 +709,10 @@ def model_to_graphviz(
704709
the size of the saved figure.
705710
dpi : int, optional
706711
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
712+
graph_attrs : dict, optional
713+
A dictionary of top-level layout attributes for graphviz
714+
Check out graphviz documentation for more information on available attributes
715+
https://graphviz.org/doc/info/attrs.html
707716
node_formatters : dict, optional
708717
A dictionary mapping node types to functions that return a dictionary of node attributes.
709718
Check out graphviz documentation for more information on available
@@ -773,6 +782,7 @@ def model_to_graphviz(
773782
save=save,
774783
figsize=figsize,
775784
dpi=dpi,
785+
graph_attrs=graph_attrs,
776786
node_formatters=node_formatters,
777787
create_plate_label=create_plate_label_with_dim_length
778788
if include_dim_lengths

0 commit comments

Comments
 (0)