diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 2bc3c8a770..1304a700c7 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -1077,6 +1077,8 @@ class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC): def __getitem__(self, index: int) -> _T_co: ... @overload def __getitem__(self: Self, index: slice) -> Self: ... + @classmethod + def __class_getitem__(cls, key: Any, /) -> GenericAlias: ... def __getstate__(self) -> Any: ... def __hash__(self) -> int: ... def __len__(self) -> int: ... @@ -1096,6 +1098,8 @@ class _RustworkxCustomHashMapIter(Generic[_S, _T_co], Mapping[_S, _T_co], ABC): def __contains__(self, other: object) -> bool: ... def __eq__(self, other: object) -> bool: ... def __getitem__(self, index: _S) -> _T_co: ... + @classmethod + def __class_getitem__(cls, key: Any, /) -> GenericAlias: ... def __getstate__(self) -> Any: ... def __hash__(self) -> int: ... def __iter__(self) -> Iterator[_S]: ... diff --git a/src/iterators.rs b/src/iterators.rs index a15cef94c4..c5efa44c78 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -41,12 +41,15 @@ use std::collections::hash_map::DefaultHasher; use std::hash::Hasher; +use super::generic_class_getitem; use num_bigint::BigUint; use rustworkx_core::dictmap::*; use ndarray::prelude::*; use numpy::IntoPyArray; -use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError, PyValueError}; +use pyo3::exceptions::{ + PyIndexError, PyKeyError, PyNotImplementedError, PyTypeError, PyValueError, +}; use pyo3::gc::PyVisit; use pyo3::prelude::*; use pyo3::types::IntoPyDict; @@ -622,6 +625,24 @@ macro_rules! custom_vec_iter_impl { } } + #[classmethod] + #[pyo3(signature = (key, /))] + pub fn __class_getitem__( + cls: &Bound<'_, pyo3::types::PyType>, + key: &Bound<'_, PyAny>, + ) -> PyResult { + let class_string = stringify!($name); + const ALLOWED_GENERIC_CLASSES: [&str; 3] = + ["BFSSuccessors", "BFSPredecessors", "WeightedEdgeList"]; + if !ALLOWED_GENERIC_CLASSES.contains(&class_string) { + return Err(PyTypeError::new_err(format!( + "type 'rustworkx.{}' is not subscriptable", + class_string + ))); + } + generic_class_getitem(cls, key) + } + #[pyo3(signature = (dtype=None, copy=None))] fn __array__<'py>( &self, @@ -1292,6 +1313,23 @@ macro_rules! custom_hash_map_iter_impl { } } + #[classmethod] + #[pyo3(signature = (key, /))] + pub fn __class_getitem__( + cls: &Bound<'_, pyo3::types::PyType>, + key: &Bound<'_, PyAny>, + ) -> PyResult { + let class_string = stringify!($name); + const ALLOWED_GENERIC_CLASSES: [&str; 1] = ["EdgeIndexMap"]; + if !ALLOWED_GENERIC_CLASSES.contains(&class_string) { + return Err(PyTypeError::new_err(format!( + "type 'rustworkx.{}' is not subscriptable", + class_string + ))); + } + generic_class_getitem(cls, key) + } + fn __iter__(slf: PyRef) -> $nameKeys { $nameKeys { $keys: slf.$data.keys().copied().collect(), diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 4542c5eb85..ed4e393800 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -14,6 +14,8 @@ import types import unittest +from typing import Optional + class TestAnnotationSubscriptions(unittest.TestCase): def test_digraph(self): @@ -36,3 +38,39 @@ def test_dag(self): graph.__class_getitem__((int, int)), types.GenericAlias, ) + + def test_custom_vector_allowed(self): + graph: rx.PyGraph[Optional[int], Optional[int]] = rx.generators.path_graph(5) + we_list: rx.WeightedEdgeList[Optional[int]] = graph.weighted_edge_list() + self.assertIsInstance( + we_list.__class_getitem__((Optional[int],)), + types.GenericAlias, + ) + + def test_custom_vector_not_allowed(self): + graph: rx.PyGraph[Optional[int], Optional[int]] = rx.generators.path_graph(5) + edge_list: rx.EdgeList = graph.edge_list() + with self.assertRaises(TypeError): + self.assertIsInstance( + edge_list.__class_getitem__((Optional[int],)), + types.GenericAlias, + ) + + def test_custom_hashmap_allowed(self): + graph: rx.PyGraph[Optional[int], Optional[int]] = rx.generators.path_graph(5) + ei_map: rx.WeightedEdgeList[Optional[int]] = graph.edge_index_map() + self.assertIsInstance( + ei_map.__class_getitem__((Optional[int],)), + types.GenericAlias, + ) + + def test_custom_hashmap_not_allowed(self): + graph: rx.PyGraph[Optional[int], Optional[int]] = rx.generators.path_graph(5) + all_pairs_pm: rx.AllPairsPathMapping = rx.all_pairs_dijkstra_shortest_paths( + graph, lambda _: 1.0 + ) + with self.assertRaises(TypeError): + self.assertIsInstance( + all_pairs_pm.__class_getitem__((Optional[int],)), + types.GenericAlias, + )