Skip to content
61 changes: 56 additions & 5 deletions rustworkx/visualization/graphviz.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,68 @@ import typing
from rustworkx.rustworkx import PyGraph, PyDiGraph

if typing.TYPE_CHECKING:
from PIL.Image import Image # type: ignore
from typing_extensions import TypeAlias
from PIL.Image import Image

_Method: TypeAlias = typing.Literal["twopi", "neato", "circo", "fdp", "sfdp", "dot"]
_ImageType: TypeAlias = typing.Literal[
"canon",
"cmap",
"cmapx",
"cmapx_np",
"dia",
"dot",
"fig",
"gd",
"gd2",
"gif",
"hpgl",
"imap",
"imap_np",
"ismap",
"jpe",
"jpeg",
"jpg",
"mif",
"mp",
"pcl",
"pdf",
"pic",
"plain",
"plain-ext",
"png",
"ps",
"ps2",
"svg",
"svgz",
"vml",
"vmlzvrml",
"vtx",
"wbmp",
"xdor",
"xlib",
]

_S = typing.TypeVar("_S")
_T = typing.TypeVar("_T")

@typing.overload
def graphviz_draw(
graph: PyDiGraph[_S, _T] | PyGraph[_S, _T],
node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ...,
edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ...,
graph_attr: dict[str, str] | None = ...,
filename: None = ...,
image_type: _ImageType | None = ...,
method: _Method | None = ...,
) -> Image: ...
@typing.overload
def graphviz_draw(
graph: PyDiGraph[_S, _T] | PyGraph[_S, _T],
node_attr_fn: typing.Callable[[_S], dict[str, str]] | None = ...,
edge_attr_fn: typing.Callable[[_T], dict[str, str]] | None = ...,
graph_attr: dict[str, str] | None = ...,
filename: str | None = ...,
image_type: str | None = ...,
method: str | None = ...,
) -> Image | None: ...
filename: str = ...,
image_type: _ImageType | None = ...,
method: _Method | None = ...,
) -> None: ...
4 changes: 2 additions & 2 deletions rustworkx/visualization/matplotlib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ if typing.TYPE_CHECKING:
_S = typing.TypeVar("_S")
_T = typing.TypeVar("_T")

class _DrawKwargs(typing.TypedDict, total=False):
class _DrawKwargs(typing.TypedDict, typing.Generic[_S, _T], total=False):
arrowstyle: str
arrow_size: int
node_list: list[int]
Expand Down Expand Up @@ -67,5 +67,5 @@ def mpl_draw(
ax: Axes | None = ...,
arrows: bool = ...,
with_labels: bool = ...,
**kwds: typing_extensions.Unpack[_DrawKwargs],
**kwds: typing_extensions.Unpack[_DrawKwargs[_S, _T]],
) -> Figure | None: ...