Skip to content
Open
4 changes: 4 additions & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,8 @@ class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC):
def __getitem__(self, index: int) -> _T_co: ...
@overload
def __getitem__(self: Self, index: slice) -> Self: ...
@classmethod
def __class_getitem__(cls, key: Any, /) -> GenericAlias: ...
def __getstate__(self) -> Any: ...
def __hash__(self) -> int: ...
def __len__(self) -> int: ...
Expand All @@ -1091,6 +1093,8 @@ class _RustworkxCustomHashMapIter(Generic[_S, _T_co], Mapping[_S, _T_co], ABC):
def __contains__(self, other: object) -> bool: ...
def __eq__(self, other: object) -> bool: ...
def __getitem__(self, index: _S) -> _T_co: ...
@classmethod
def __class_getitem__(cls, key: Any, /) -> GenericAlias: ...
def __getstate__(self) -> Any: ...
def __hash__(self) -> int: ...
def __iter__(self) -> Iterator[_S]: ...
Expand Down
40 changes: 39 additions & 1 deletion src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

use super::generic_class_getitem;
use num_bigint::BigUint;
use rustworkx_core::dictmap::*;

use ndarray::prelude::*;
use numpy::IntoPyArray;
use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError, PyValueError};
use pyo3::exceptions::{
PyIndexError, PyKeyError, PyNotImplementedError, PyTypeError, PyValueError,
};
use pyo3::gc::PyVisit;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
Expand Down Expand Up @@ -618,6 +621,24 @@ macro_rules! custom_vec_iter_impl {
}
}

#[classmethod]
#[pyo3(signature = (key, /))]
pub fn __class_getitem__(
cls: &Bound<'_, pyo3::types::PyType>,
key: &Bound<'_, PyAny>,
) -> PyResult<PyObject> {
let class_string = stringify!($name);
const ALLOWED_GENERIC_CLASSES: [&str; 3] =
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of all changes this is the one I am least confident about because it feels like a hack

["BFSSuccessors", "BFSPredecessors", "WeightedEdgeList"];
if !ALLOWED_GENERIC_CLASSES.contains(&class_string) {
return Err(PyTypeError::new_err(format!(
"type 'rustworkx.{}' is not subscriptable",
class_string
)));
}
generic_class_getitem(cls, key)
}

#[pyo3(signature = (dtype=None, copy=None))]
fn __array__(
&self,
Expand Down Expand Up @@ -1290,6 +1311,23 @@ macro_rules! custom_hash_map_iter_impl {
}
}

#[classmethod]
#[pyo3(signature = (key, /))]
pub fn __class_getitem__(
cls: &Bound<'_, pyo3::types::PyType>,
key: &Bound<'_, PyAny>,
) -> PyResult<PyObject> {
let class_string = stringify!($name);
const ALLOWED_GENERIC_CLASSES: [&str; 1] = ["EdgeIndexMap"];
if !ALLOWED_GENERIC_CLASSES.contains(&class_string) {
return Err(PyTypeError::new_err(format!(
"type 'rustworkx.{}' is not subscriptable",
class_string
)));
}
generic_class_getitem(cls, key)
}

fn __iter__(slf: PyRef<Self>) -> $nameKeys {
$nameKeys {
$keys: slf.$data.keys().copied().collect(),
Expand Down
37 changes: 37 additions & 0 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import types
import unittest

from typing import Optional

class TestAnnotationSubscriptions(unittest.TestCase):
def test_digraph(self):
Expand All @@ -36,3 +37,39 @@ def test_dag(self):
graph.__class_getitem__((int, int)),
types.GenericAlias,
)

def test_custom_vector_allowed(self):
graph: rx.PyGraph[Optional[int], Optional[int]] = rx.generators.path_graph(5)
we_list: rx.WeightedEdgeList[Optional[int]] = graph.weighted_edge_list()
self.assertIsInstance(
we_list.__class_getitem__((Optional[int],)),
types.GenericAlias,
)

def test_custom_vector_not_allowed(self):
graph: rx.PyGraph[Optional[int], Optional[int]] = rx.generators.path_graph(5)
edge_list: rx.EdgeList = graph.edge_list()
with self.assertRaises(TypeError):
self.assertIsInstance(
edge_list.__class_getitem__((Optional[int],)),
types.GenericAlias,
)

def test_custom_hashmap_allowed(self):
graph: rx.PyGraph[Optional[int], Optional[int]] = rx.generators.path_graph(5)
ei_map: rx.WeightedEdgeList[Optional[int]] = graph.edge_index_map()
self.assertIsInstance(
ei_map.__class_getitem__((Optional[int],)),
types.GenericAlias,
)

def test_custom_hashmap_not_allowed(self):
graph: rx.PyGraph[Optional[int], Optional[int]] = rx.generators.path_graph(5)
all_pairs_pm: rx.AllPairsPathMapping = rx.all_pairs_dijkstra_shortest_paths(
graph, lambda _: 1.0
)
with self.assertRaises(TypeError):
self.assertIsInstance(
all_pairs_pm.__class_getitem__((Optional[int],)),
types.GenericAlias,
)
Loading