diff --git a/rustworkx/visualization/graphviz.pyi b/rustworkx/visualization/graphviz.pyi index f1a9a41e3f..c5902b132d 100644 --- a/rustworkx/visualization/graphviz.pyi +++ b/rustworkx/visualization/graphviz.pyi @@ -10,17 +10,126 @@ import typing from rustworkx.rustworkx import PyGraph, PyDiGraph if typing.TYPE_CHECKING: - from PIL.Image import Image # type: ignore + from PIL.Image import Image _S = typing.TypeVar("_S") _T = typing.TypeVar("_T") +@typing.overload def graphviz_draw( graph: PyDiGraph[_S, _T] | PyGraph[_S, _T], node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ..., edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ..., graph_attr: dict[str, str] | None = ..., - filename: str | None = ..., + filename: None = ..., + image_type: ( + typing.Literal[ + "canon", + "cmap", + "cmapx", + "cmapx_np", + "dia", + "dot", + "fig", + "gd", + "gd2", + "gif", + "hpgl", + "imap", + "imap_np", + "ismap", + "jpe", + "jpeg", + "jpg", + "mif", + "mp", + "pcl", + "pdf", + "pic", + "plain", + "plain-ext", + "png", + "ps", + "ps2", + "svg", + "svgz", + "vml", + "vmlzvrml", + "vtx", + "wbmp", + "xdor", + "xlib", + ] + | None + ) = ..., + method: typing.Literal["twopi", "neato", "circo", "fdp", "sfdp", "dot"] | None = ..., +) -> Image: ... +@typing.overload +def graphviz_draw( + graph: PyDiGraph[_S, _T] | PyGraph[_S, _T], + node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ..., + edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ..., + graph_attr: dict[str, str] | None = ..., + filename: None = ..., + image_type: str | None = ..., + method: str | None = ..., +) -> Image: ... +@typing.overload +def graphviz_draw( + graph: PyDiGraph[_S, _T] | PyGraph[_S, _T], + node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ..., + edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ..., + graph_attr: dict[str, str] | None = ..., + filename: str = ..., + image_type: ( + typing.Literal[ + "canon", + "cmap", + "cmapx", + "cmapx_np", + "dia", + "dot", + "fig", + "gd", + "gd2", + "gif", + "hpgl", + "imap", + "imap_np", + "ismap", + "jpe", + "jpeg", + "jpg", + "mif", + "mp", + "pcl", + "pdf", + "pic", + "plain", + "plain-ext", + "png", + "ps", + "ps2", + "svg", + "svgz", + "vml", + "vmlzvrml", + "vtx", + "wbmp", + "xdor", + "xlib", + ] + | None + ) = ..., + method: typing.Literal["twopi", "neato", "circo", "fdp", "sfdp", "dot"] | None = ..., +) -> None: ... +@typing.overload +def graphviz_draw( + graph: PyDiGraph[_S, _T] | PyGraph[_S, _T], + node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ..., + edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ..., + graph_attr: dict[str, str] | None = ..., + filename: str = ..., image_type: str | None = ..., method: str | None = ..., -) -> Image | None: ... +) -> None: ... diff --git a/rustworkx/visualization/matplotlib.pyi b/rustworkx/visualization/matplotlib.pyi index 3f8da1fd17..38759fb6a1 100644 --- a/rustworkx/visualization/matplotlib.pyi +++ b/rustworkx/visualization/matplotlib.pyi @@ -20,7 +20,7 @@ if typing.TYPE_CHECKING: _S = typing.TypeVar("_S") _T = typing.TypeVar("_T") -class _DrawKwargs(typing.TypedDict, total=False): +class _DrawKwargs(typing.TypedDict, typing.Generic[_S, _T], total=False): arrowstyle: str arrow_size: int node_list: list[int] @@ -67,5 +67,5 @@ def mpl_draw( ax: Axes | None = ..., arrows: bool = ..., with_labels: bool = ..., - **kwds: typing_extensions.Unpack[_DrawKwargs], + **kwds: typing_extensions.Unpack[_DrawKwargs[_S, _T]], ) -> Figure | None: ...