@@ -439,6 +439,7 @@ def make_graph(
439439 figsize = None ,
440440 dpi = 300 ,
441441 node_formatters : NodeTypeFormatterMapping | None = None ,
442+ graph_attrs : dict [str , Any ] | None ,
442443 create_plate_label : PlateLabelFunc = create_plate_label_with_dim_length ,
443444):
444445 """Make graphviz Digraph of PyMC model.
@@ -460,6 +461,9 @@ def make_graph(
460461 node_formatters = update_node_formatters (node_formatters )
461462
462463 graph = graphviz .Digraph (name )
464+ if graph_attrs is not None :
465+ graph .attr (** graph_attrs )
466+
463467 for plate in plates :
464468 if plate .dim_info :
465469 # must be preceded by 'cluster' to get a box around it
@@ -676,6 +680,7 @@ def model_to_graphviz(
676680 figsize : tuple [int , int ] | None = None ,
677681 dpi : int = 300 ,
678682 node_formatters : NodeTypeFormatterMapping | None = None ,
683+ graph_attrs : dict [str , Any ] | None = None ,
679684 include_dim_lengths : bool = True ,
680685):
681686 """Produce a graphviz Digraph from a PyMC model.
@@ -704,6 +709,10 @@ def model_to_graphviz(
704709 the size of the saved figure.
705710 dpi : int, optional
706711 Dots per inch. It only affects the resolution of the saved figure. The default is 300.
712+ graph_attrs : dict, optional
713+ A dictionary of top-level layout attributes for graphviz
714+ Check out graphviz documentation for more information on available attributes
715+ https://graphviz.org/doc/info/attrs.html
707716 node_formatters : dict, optional
708717 A dictionary mapping node types to functions that return a dictionary of node attributes.
709718 Check out graphviz documentation for more information on available
@@ -773,6 +782,7 @@ def model_to_graphviz(
773782 save = save ,
774783 figsize = figsize ,
775784 dpi = dpi ,
785+ graph_attrs = graph_attrs ,
776786 node_formatters = node_formatters ,
777787 create_plate_label = create_plate_label_with_dim_length
778788 if include_dim_lengths
0 commit comments