Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,9 @@
]


GraphAttrMapping = dict[Any, Any]


def make_graph(
name: str,
plates: list[Plate],
Expand All @@ -439,6 +442,7 @@
figsize=None,
dpi=300,
node_formatters: NodeTypeFormatterMapping | None = None,
graph_attrs: GraphAttrMapping | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I like the custom type, it makes it seem more complex than it really is. It's not much shorter than dict[str, Any]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no strong feelings about this other than it was to be more consistent with the rest of the file

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

up to you; not a blocker

create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
):
"""Make graphviz Digraph of PyMC model.
Expand All @@ -460,6 +464,9 @@
node_formatters = update_node_formatters(node_formatters)

graph = graphviz.Digraph(name)
if graph_attrs is not None:
graph.attr(**graph_attrs)

Check warning on line 468 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L468

Added line #L468 was not covered by tests

for plate in plates:
if plate.dim_info:
# must be preceded by 'cluster' to get a box around it
Expand Down Expand Up @@ -676,6 +683,7 @@
figsize: tuple[int, int] | None = None,
dpi: int = 300,
node_formatters: NodeTypeFormatterMapping | None = None,
graph_attrs: GraphAttrMapping | None = None,
include_dim_lengths: bool = True,
):
"""Produce a graphviz Digraph from a PyMC model.
Expand Down Expand Up @@ -704,6 +712,10 @@
the size of the saved figure.
dpi : int, optional
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
graph_attrs : dict, optional
A dictionary of top-level layout attributes for graphviz
Check out graphviz documentation for more information on available attributes
https://graphviz.org/doc/info/attrs.html
node_formatters : dict, optional
A dictionary mapping node types to functions that return a dictionary of node attributes.
Check out graphviz documentation for more information on available
Expand Down Expand Up @@ -773,6 +785,7 @@
save=save,
figsize=figsize,
dpi=dpi,
graph_attrs=graph_attrs,
node_formatters=node_formatters,
create_plate_label=create_plate_label_with_dim_length
if include_dim_lengths
Expand Down
Loading