Skip to content

Commit 830668b

Browse files
authored
Allow PyGraph and PyDiGraph to be annotated as generic classes at runtime (#1348)
* Allow PyGraph and PyDiGraph to be annotated as generic classes at runtime * Black * Add release notes * Add __class_getitem__ to stubs
1 parent abef7e2 commit 830668b

File tree

6 files changed

+101
-5
lines changed

6 files changed

+101
-5
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
---
2+
features:
3+
- |
4+
The :class:`PyDiGraph() <rustworkx.PyDiGraph>` and :class:`PyGraph() <rustworkx.PyGraph>`
5+
classes now have better support for
6+
`PEP 560 <https://peps.python.org/pep-0560/>`__. Building off of the previous
7+
releases which introduced type annotations, the following code snippet is now valid:
8+
9+
.. jupyter-execute::
10+
11+
import rustworkx as rx
12+
13+
graph: rx.PyGraph[int, int] = rx.PyGraph()
14+
15+
16+
Previously, users had to rely on post-poned evaluation of type annotations from
17+
`PEP 563 <https://peps.python.org/pep-0563/>`__ for annotations to work.
18+
19+
Refer to `issue 1345 <https://github.com/Qiskit/rustworkx/issues/1345>`__ for
20+
more information.
21+

rustworkx/rustworkx.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# For implementation details, see __init__.py and src/lib.rs
1111

1212
from .visit import BFSVisitor, DFSVisitor, DijkstraVisitor
13+
from types import GenericAlias
1314
from typing import (
1415
Callable,
1516
final,
@@ -1308,6 +1309,8 @@ class PyGraph(Generic[_S, _T]):
13081309
) -> None: ...
13091310
def __delitem__(self, idx: int, /) -> None: ...
13101311
def __getitem__(self, idx: int, /) -> _S: ...
1312+
@classmethod
1313+
def __class_getitem__(cls, key: Any, /) -> GenericAlias: ...
13111314
def __getnewargs_ex__(self) -> tuple[tuple[Any, ...], dict[str, Any]]: ...
13121315
def __getstate__(self) -> Any: ...
13131316
def __len__(self) -> int: ...
@@ -1509,6 +1512,8 @@ class PyDiGraph(Generic[_S, _T]):
15091512
def reverse(self) -> None: ...
15101513
def __delitem__(self, idx: int, /) -> None: ...
15111514
def __getitem__(self, idx: int, /) -> _S: ...
1515+
@classmethod
1516+
def __class_getitem__(cls, key: Any, /) -> GenericAlias: ...
15121517
def __getnewargs_ex__(self) -> tuple[tuple[Any, ...], dict[str, Any]]: ...
15131518
def __getstate__(self) -> Any: ...
15141519
def __len__(self) -> int: ...

src/digraph.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use smallvec::SmallVec;
3131
use pyo3::exceptions::PyIndexError;
3232
use pyo3::gc::PyVisit;
3333
use pyo3::prelude::*;
34-
use pyo3::types::{IntoPyDict, PyBool, PyDict, PyList, PyString, PyTuple};
34+
use pyo3::types::{IntoPyDict, PyBool, PyDict, PyList, PyString, PyTuple, PyType};
3535
use pyo3::PyTraverseError;
3636
use pyo3::Python;
3737

@@ -55,8 +55,8 @@ use super::iterators::{
5555
EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList,
5656
};
5757
use super::{
58-
find_node_by_weight, weight_callable, DAGHasCycle, DAGWouldCycle, IsNan, NoEdgeBetweenNodes,
59-
NoSuitableNeighbors, NodesRemoved, StablePyGraph,
58+
find_node_by_weight, generic_class_getitem, weight_callable, DAGHasCycle, DAGWouldCycle, IsNan,
59+
NoEdgeBetweenNodes, NoSuitableNeighbors, NodesRemoved, StablePyGraph,
6060
};
6161

6262
use super::dag_algo::is_directed_acyclic_graph;
@@ -3143,6 +3143,15 @@ impl PyDiGraph {
31433143
}
31443144
}
31453145

3146+
#[classmethod]
3147+
#[pyo3(signature = (key, /))]
3148+
pub fn __class_getitem__(
3149+
cls: &Bound<'_, PyType>,
3150+
key: &Bound<'_, PyAny>,
3151+
) -> PyResult<PyObject> {
3152+
generic_class_getitem(cls, key)
3153+
}
3154+
31463155
// Functions to enable Python Garbage Collection
31473156

31483157
// Function for PyTypeObject.tp_traverse [1][2] used to tell Python what

src/graph.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use rustworkx_core::graph_ext::*;
2626
use pyo3::exceptions::PyIndexError;
2727
use pyo3::gc::PyVisit;
2828
use pyo3::prelude::*;
29-
use pyo3::types::{IntoPyDict, PyBool, PyDict, PyList, PyString, PyTuple};
29+
use pyo3::types::{IntoPyDict, PyBool, PyDict, PyList, PyString, PyTuple, PyType};
3030
use pyo3::PyTraverseError;
3131
use pyo3::Python;
3232

@@ -40,7 +40,8 @@ use crate::iterators::NodeMap;
4040
use super::dot_utils::build_dot;
4141
use super::iterators::{EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList};
4242
use super::{
43-
find_node_by_weight, weight_callable, IsNan, NoEdgeBetweenNodes, NodesRemoved, StablePyGraph,
43+
find_node_by_weight, generic_class_getitem, weight_callable, IsNan, NoEdgeBetweenNodes,
44+
NodesRemoved, StablePyGraph,
4445
};
4546

4647
use crate::RxPyResult;
@@ -1994,6 +1995,15 @@ impl PyGraph {
19941995
}
19951996
}
19961997

1998+
#[classmethod]
1999+
#[pyo3(signature = (key, /))]
2000+
pub fn __class_getitem__(
2001+
cls: &Bound<'_, PyType>,
2002+
key: &Bound<'_, PyAny>,
2003+
) -> PyResult<PyObject> {
2004+
generic_class_getitem(cls, key)
2005+
}
2006+
19972007
// Functions to enable Python Garbage Collection
19982008

19992009
// Function for PyTypeObject.tp_traverse [1][2] used to tell Python what

src/lib.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,19 @@ fn find_node_by_weight<Ty: EdgeType>(
376376
Ok(index)
377377
}
378378

379+
fn generic_class_getitem(
380+
cls: &Bound<'_, pyo3::types::PyType>,
381+
key: &Bound<'_, PyAny>,
382+
) -> PyResult<PyObject> {
383+
Python::with_gil(|py| -> PyResult<PyObject> {
384+
let types_mod = py.import_bound("types")?;
385+
let types_generic_alias = types_mod.getattr("GenericAlias")?;
386+
let args = (cls, key);
387+
let generic_alias = types_generic_alias.call1(args)?;
388+
Ok(generic_alias.into())
389+
})
390+
}
391+
379392
// The provided node is invalid.
380393
create_exception!(rustworkx, InvalidNode, PyException);
381394
// Performing this operation would result in trying to add a cycle to a DAG.

tests/test_annotations.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
2+
# not use this file except in compliance with the License. You may obtain
3+
# a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10+
# License for the specific language governing permissions and limitations
11+
# under the License.
12+
13+
import rustworkx as rx
14+
import types
15+
import unittest
16+
17+
18+
class TestAnnotationSubscriptions(unittest.TestCase):
19+
def test_digraph(self):
20+
graph: rx.PyDiGraph[int, int] = rx.PyDiGraph()
21+
self.assertIsInstance(
22+
graph.__class_getitem__((int, int)),
23+
types.GenericAlias,
24+
)
25+
26+
def test_graph(self):
27+
graph: rx.PyGraph[int, int] = rx.PyGraph()
28+
self.assertIsInstance(
29+
graph.__class_getitem__((int, int)),
30+
types.GenericAlias,
31+
)
32+
33+
def test_dag(self):
34+
graph: rx.PyDAG[int, int] = rx.PyDAG()
35+
self.assertIsInstance(
36+
graph.__class_getitem__((int, int)),
37+
types.GenericAlias,
38+
)

0 commit comments

Comments
 (0)