diff --git a/pymc/model_graph.py b/pymc/model_graph.py index f31a33770..3a241948a 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -439,6 +439,7 @@ def make_graph( figsize=None, dpi=300, node_formatters: NodeTypeFormatterMapping | None = None, + graph_attr: dict[str, Any] | None = None, create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length, ): """Make graphviz Digraph of PyMC model. @@ -459,7 +460,7 @@ def make_graph( node_formatters = node_formatters or {} node_formatters = update_node_formatters(node_formatters) - graph = graphviz.Digraph(name) + graph = graphviz.Digraph(name, graph_attr=graph_attr) for plate in plates: if plate.dim_info: # must be preceded by 'cluster' to get a box around it @@ -676,6 +677,7 @@ def model_to_graphviz( figsize: tuple[int, int] | None = None, dpi: int = 300, node_formatters: NodeTypeFormatterMapping | None = None, + graph_attr: dict[str, Any] | None = None, include_dim_lengths: bool = True, ): """Produce a graphviz Digraph from a PyMC model. @@ -704,6 +706,10 @@ def model_to_graphviz( the size of the saved figure. dpi : int, optional Dots per inch. It only affects the resolution of the saved figure. The default is 300. + graph_attr : dict, optional + A dictionary of top-level layout attributes for graphviz + Check out graphviz documentation for more information on available attributes + https://graphviz.org/doc/info/attrs.html node_formatters : dict, optional A dictionary mapping node types to functions that return a dictionary of node attributes. Check out graphviz documentation for more information on available @@ -773,6 +779,7 @@ def model_to_graphviz( save=save, figsize=figsize, dpi=dpi, + graph_attr=graph_attr, node_formatters=node_formatters, create_plate_label=create_plate_label_with_dim_length if include_dim_lengths