Skip to content

Commit 30897c5

Browse files
authored
enhance visualization func type hint (#1443)
* enhance graphviz_draw type hint * enhance mpl_draw type hint * fix black lint * fix black lint Signed-off-by: ZhengYu, Xu <[email protected]> * make Method and ImageType private to satisfy mypy.stubtest * allow image_type and method to accept str * remove type alias * format --------- Signed-off-by: ZhengYu, Xu <[email protected]>
1 parent 3fbc10d commit 30897c5

File tree

2 files changed

+114
-5
lines changed

2 files changed

+114
-5
lines changed

rustworkx/visualization/graphviz.pyi

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,126 @@ import typing
1010
from rustworkx.rustworkx import PyGraph, PyDiGraph
1111

1212
if typing.TYPE_CHECKING:
13-
from PIL.Image import Image # type: ignore
13+
from PIL.Image import Image
1414

1515
_S = typing.TypeVar("_S")
1616
_T = typing.TypeVar("_T")
1717

18+
@typing.overload
1819
def graphviz_draw(
1920
graph: PyDiGraph[_S, _T] | PyGraph[_S, _T],
2021
node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ...,
2122
edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ...,
2223
graph_attr: dict[str, str] | None = ...,
23-
filename: str | None = ...,
24+
filename: None = ...,
25+
image_type: (
26+
typing.Literal[
27+
"canon",
28+
"cmap",
29+
"cmapx",
30+
"cmapx_np",
31+
"dia",
32+
"dot",
33+
"fig",
34+
"gd",
35+
"gd2",
36+
"gif",
37+
"hpgl",
38+
"imap",
39+
"imap_np",
40+
"ismap",
41+
"jpe",
42+
"jpeg",
43+
"jpg",
44+
"mif",
45+
"mp",
46+
"pcl",
47+
"pdf",
48+
"pic",
49+
"plain",
50+
"plain-ext",
51+
"png",
52+
"ps",
53+
"ps2",
54+
"svg",
55+
"svgz",
56+
"vml",
57+
"vmlzvrml",
58+
"vtx",
59+
"wbmp",
60+
"xdor",
61+
"xlib",
62+
]
63+
| None
64+
) = ...,
65+
method: typing.Literal["twopi", "neato", "circo", "fdp", "sfdp", "dot"] | None = ...,
66+
) -> Image: ...
67+
@typing.overload
68+
def graphviz_draw(
69+
graph: PyDiGraph[_S, _T] | PyGraph[_S, _T],
70+
node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ...,
71+
edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ...,
72+
graph_attr: dict[str, str] | None = ...,
73+
filename: None = ...,
74+
image_type: str | None = ...,
75+
method: str | None = ...,
76+
) -> Image: ...
77+
@typing.overload
78+
def graphviz_draw(
79+
graph: PyDiGraph[_S, _T] | PyGraph[_S, _T],
80+
node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ...,
81+
edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ...,
82+
graph_attr: dict[str, str] | None = ...,
83+
filename: str = ...,
84+
image_type: (
85+
typing.Literal[
86+
"canon",
87+
"cmap",
88+
"cmapx",
89+
"cmapx_np",
90+
"dia",
91+
"dot",
92+
"fig",
93+
"gd",
94+
"gd2",
95+
"gif",
96+
"hpgl",
97+
"imap",
98+
"imap_np",
99+
"ismap",
100+
"jpe",
101+
"jpeg",
102+
"jpg",
103+
"mif",
104+
"mp",
105+
"pcl",
106+
"pdf",
107+
"pic",
108+
"plain",
109+
"plain-ext",
110+
"png",
111+
"ps",
112+
"ps2",
113+
"svg",
114+
"svgz",
115+
"vml",
116+
"vmlzvrml",
117+
"vtx",
118+
"wbmp",
119+
"xdor",
120+
"xlib",
121+
]
122+
| None
123+
) = ...,
124+
method: typing.Literal["twopi", "neato", "circo", "fdp", "sfdp", "dot"] | None = ...,
125+
) -> None: ...
126+
@typing.overload
127+
def graphviz_draw(
128+
graph: PyDiGraph[_S, _T] | PyGraph[_S, _T],
129+
node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ...,
130+
edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ...,
131+
graph_attr: dict[str, str] | None = ...,
132+
filename: str = ...,
24133
image_type: str | None = ...,
25134
method: str | None = ...,
26-
) -> Image | None: ...
135+
) -> None: ...

rustworkx/visualization/matplotlib.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if typing.TYPE_CHECKING:
2020
_S = typing.TypeVar("_S")
2121
_T = typing.TypeVar("_T")
2222

23-
class _DrawKwargs(typing.TypedDict, total=False):
23+
class _DrawKwargs(typing.TypedDict, typing.Generic[_S, _T], total=False):
2424
arrowstyle: str
2525
arrow_size: int
2626
node_list: list[int]
@@ -67,5 +67,5 @@ def mpl_draw(
6767
ax: Axes | None = ...,
6868
arrows: bool = ...,
6969
with_labels: bool = ...,
70-
**kwds: typing_extensions.Unpack[_DrawKwargs],
70+
**kwds: typing_extensions.Unpack[_DrawKwargs[_S, _T]],
7171
) -> Figure | None: ...

0 commit comments

Comments
 (0)