@@ -430,6 +430,9 @@ def edges(
430430 ]
431431
432432
433+ GraphAttrMapping = dict [Any , Any ]
434+
435+
433436def make_graph (
434437 name : str ,
435438 plates : list [Plate ],
@@ -439,6 +442,7 @@ def make_graph(
439442 figsize = None ,
440443 dpi = 300 ,
441444 node_formatters : NodeTypeFormatterMapping | None = None ,
445+ graph_attrs : GraphAttrMapping | None = None ,
442446 create_plate_label : PlateLabelFunc = create_plate_label_with_dim_length ,
443447):
444448 """Make graphviz Digraph of PyMC model.
@@ -460,6 +464,9 @@ def make_graph(
460464 node_formatters = update_node_formatters (node_formatters )
461465
462466 graph = graphviz .Digraph (name )
467+ if graph_attrs is not None :
468+ graph .attr (** graph_attrs )
469+
463470 for plate in plates :
464471 if plate .dim_info :
465472 # must be preceded by 'cluster' to get a box around it
@@ -676,6 +683,7 @@ def model_to_graphviz(
676683 figsize : tuple [int , int ] | None = None ,
677684 dpi : int = 300 ,
678685 node_formatters : NodeTypeFormatterMapping | None = None ,
686+ graph_attrs : GraphAttrMapping | None = None ,
679687 include_dim_lengths : bool = True ,
680688):
681689 """Produce a graphviz Digraph from a PyMC model.
@@ -704,6 +712,10 @@ def model_to_graphviz(
704712 the size of the saved figure.
705713 dpi : int, optional
706714 Dots per inch. It only affects the resolution of the saved figure. The default is 300.
715+ graph_attrs : dict, optional
716+ A dictionary of top-level layout attributes for graphviz
717+ Check out graphviz documentation for more information on available attributes
718+ https://graphviz.org/doc/info/attrs.html
707719 node_formatters : dict, optional
708720 A dictionary mapping node types to functions that return a dictionary of node attributes.
709721 Check out graphviz documentation for more information on available
@@ -773,6 +785,7 @@ def model_to_graphviz(
773785 save = save ,
774786 figsize = figsize ,
775787 dpi = dpi ,
788+ graph_attrs = graph_attrs ,
776789 node_formatters = node_formatters ,
777790 create_plate_label = create_plate_label_with_dim_length
778791 if include_dim_lengths
0 commit comments