Skip to content

Commit dcf2b76

Browse files
committed
Simplify code
1 parent 574a486 commit dcf2b76

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

pymc/model_graph.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,6 @@ def edges(
430430
]
431431

432432

433-
GraphAttrMapping = dict[Any, Any]
434-
435-
436433
def make_graph(
437434
name: str,
438435
plates: list[Plate],
@@ -442,7 +439,7 @@ def make_graph(
442439
figsize=None,
443440
dpi=300,
444441
node_formatters: NodeTypeFormatterMapping | None = None,
445-
graph_attrs: GraphAttrMapping | None = None,
442+
graph_attr: dict[str, Any] | None = None,
446443
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
447444
):
448445
"""Make graphviz Digraph of PyMC model.
@@ -463,7 +460,7 @@ def make_graph(
463460
node_formatters = node_formatters or {}
464461
node_formatters = update_node_formatters(node_formatters)
465462

466-
graph = graphviz.Digraph(name, graph_attr=graph_attrs)
463+
graph = graphviz.Digraph(name, graph_attr=graph_attr)
467464
for plate in plates:
468465
if plate.dim_info:
469466
# must be preceded by 'cluster' to get a box around it
@@ -680,7 +677,7 @@ def model_to_graphviz(
680677
figsize: tuple[int, int] | None = None,
681678
dpi: int = 300,
682679
node_formatters: NodeTypeFormatterMapping | None = None,
683-
graph_attrs: GraphAttrMapping | None = None,
680+
graph_attr: dict[str, Any] | None = None,
684681
include_dim_lengths: bool = True,
685682
):
686683
"""Produce a graphviz Digraph from a PyMC model.
@@ -709,7 +706,7 @@ def model_to_graphviz(
709706
the size of the saved figure.
710707
dpi : int, optional
711708
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
712-
graph_attrs : dict, optional
709+
graph_attr : dict, optional
713710
A dictionary of top-level layout attributes for graphviz
714711
Check out graphviz documentation for more information on available attributes
715712
https://graphviz.org/doc/info/attrs.html
@@ -782,7 +779,7 @@ def model_to_graphviz(
782779
save=save,
783780
figsize=figsize,
784781
dpi=dpi,
785-
graph_attrs=graph_attrs,
782+
graph_attr=graph_attr,
786783
node_formatters=node_formatters,
787784
create_plate_label=create_plate_label_with_dim_length
788785
if include_dim_lengths

0 commit comments

Comments
 (0)