Skip to content

Commit f182dc2

Browse files
authored
Fix issues with rustworkx.visit annotations (#1353)
1 parent 830668b commit f182dc2

File tree

4 files changed

+33
-22
lines changed

4 files changed

+33
-22
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed a bug in the discoverability of the type hints for the `rustworkx.visit` module.
5+
Classes declared in the module are also now properly annotated as accepting generic types.
6+
Refer to `issue 1352 <https://github.com/Qiskit/rustworkx/issues/1352>`__ for
7+
more information.

rustworkx/__init__.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ else:
2525
# rustworkx module we need to explicitly re-export every inner function from
2626
# rustworkx.rustworkx (the root rust module) in the form:
2727
# `from .rustworkx import foo as foo` so that mypy will treat `rustworkx.foo`
28-
# as a valid path
29-
import rustworkx.visit as visit
28+
# as a valid path.
29+
from . import visit as visit
3030

3131
from .rustworkx import DAGHasCycle as DAGHasCycle
3232
from .rustworkx import DAGWouldCycle as DAGWouldCycle

rustworkx/visit.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
# copyright notice, and modified files need to carry a notice indicating
77
# that they have been altered from the originals.
88

9+
from typing import TypeVar, Generic
10+
11+
_T = TypeVar("_T")
12+
913

1014
class StopSearch(Exception):
1115
"""Stop graph traversal"""
@@ -19,7 +23,7 @@ class PruneSearch(Exception):
1923
pass
2024

2125

22-
class BFSVisitor:
26+
class BFSVisitor(Generic[_T]):
2327
"""A visitor object that is invoked at the event-points inside the
2428
:func:`~rustworkx.bfs_search` algorithm. By default, it performs no
2529
action, and should be used as a base class in order to be useful.
@@ -68,7 +72,7 @@ def black_target_edge(self, e):
6872
return
6973

7074

71-
class DFSVisitor:
75+
class DFSVisitor(Generic[_T]):
7276
"""A visitor object that is invoked at the event-points inside the
7377
:func:`~rustworkx.dfs_search` algorithm. By default, it performs no
7478
action, and should be used as a base class in order to be useful.
@@ -119,7 +123,7 @@ def forward_or_cross_edge(self, e):
119123
return
120124

121125

122-
class DijkstraVisitor:
126+
class DijkstraVisitor(Generic[_T]):
123127
"""A visitor object that is invoked at the event-points inside the
124128
:func:`~rustworkx.dijkstra_search` algorithm. By default, it performs no
125129
action, and should be used as a base class in order to be useful.

rustworkx/visit.pyi

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,31 @@
99
# This file contains only type annotations for PyO3 functions and classes
1010
# For implementation details, see visit.py
1111

12-
from typing import Generic, TypeVar
12+
from typing import Any, Generic, TypeVar
1313

1414
class StopSearch(Exception): ...
1515
class PruneSearch(Exception): ...
1616

1717
_T = TypeVar("_T")
1818

1919
class BFSVisitor(Generic[_T]):
20-
def discover_vertex(self, v: int): ...
21-
def finish_vertex(self, v: int): ...
22-
def tree_edge(self, e: tuple[int, int, _T]): ...
23-
def non_tree_edge(self, e: tuple[int, int, _T]): ...
24-
def gray_target_edge(self, e: tuple[int, int, _T]): ...
25-
def black_target_edge(self, e: tuple[int, int, _T]): ...
20+
def discover_vertex(self, v: int) -> Any: ...
21+
def finish_vertex(self, v: int) -> Any: ...
22+
def tree_edge(self, e: tuple[int, int, _T]) -> Any: ...
23+
def non_tree_edge(self, e: tuple[int, int, _T]) -> Any: ...
24+
def gray_target_edge(self, e: tuple[int, int, _T]) -> Any: ...
25+
def black_target_edge(self, e: tuple[int, int, _T]) -> Any: ...
2626

2727
class DFSVisitor(Generic[_T]):
28-
def discover_vertex(self, v: int, t: int): ...
29-
def finish_vertex(self, v: int, t: int): ...
30-
def tree_edge(self, e: tuple[int, int, _T]): ...
31-
def back_edge(self, e: tuple[int, int, _T]): ...
32-
def forward_or_cross_edge(self, e: tuple[int, int, _T]): ...
28+
def discover_vertex(self, v: int, t: int) -> Any: ...
29+
def finish_vertex(self, v: int, t: int) -> Any: ...
30+
def tree_edge(self, e: tuple[int, int, _T]) -> Any: ...
31+
def back_edge(self, e: tuple[int, int, _T]) -> Any: ...
32+
def forward_or_cross_edge(self, e: tuple[int, int, _T]) -> Any: ...
3333

3434
class DijkstraVisitor(Generic[_T]):
35-
def discover_vertex(self, v: int, score: float): ...
36-
def finish_vertex(self, v: int): ...
37-
def examine_edge(self, edge: tuple[int, int, _T]): ...
38-
def edge_relaxed(self, edge: tuple[int, int, _T]): ...
39-
def edge_not_relaxed(self, edge: tuple[int, int, _T]): ...
35+
def discover_vertex(self, v: int, score: float) -> Any: ...
36+
def finish_vertex(self, v: int) -> Any: ...
37+
def examine_edge(self, edge: tuple[int, int, _T]) -> Any: ...
38+
def edge_relaxed(self, edge: tuple[int, int, _T]) -> Any: ...
39+
def edge_not_relaxed(self, edge: tuple[int, int, _T]) -> Any: ...

0 commit comments

Comments
 (0)