Skip to content

Commit 9c03098

Browse files
IvanIsCodingSILIZ4
authored andcommitted
Fix panic for ancestors and descendants when the source node is invalid (Qiskit#1389)
* Raise IndexError instead of panicking for ancestors and descendants * Add tests * Black * Add release notes
1 parent 44ed29e commit 9c03098

File tree

3 files changed

+34
-6
lines changed

3 files changed

+34
-6
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
fixes:
2+
- |
3+
Fixed a panic when passing an invalid source node to
4+
:func:`~rustworkx.ancenstors` and :func:`~rustworkx.descendants`. See
5+
`#1381 <https://github.com/Qiskit/rustworkx/issues/1381>`__ for more information.

src/traversal/mod.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,18 @@ pub fn bfs_predecessors(
236236
/// :rtype: set
237237
#[pyfunction]
238238
#[pyo3(text_signature = "(graph, node, /)")]
239-
pub fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
240-
core_ancestors(&graph.graph, NodeIndex::new(node))
239+
pub fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> PyResult<HashSet<usize>> {
240+
let index = NodeIndex::new(node);
241+
if !graph.graph.contains_node(index) {
242+
return Err(PyIndexError::new_err(format!(
243+
"Node source index \"{}\" out of graph bound",
244+
node
245+
)));
246+
}
247+
Ok(core_ancestors(&graph.graph, index)
241248
.map(|x| x.index())
242249
.filter(|x| *x != node)
243-
.collect()
250+
.collect())
244251
}
245252

246253
/// Return the descendants of a node in a graph.
@@ -257,12 +264,18 @@ pub fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
257264
/// :rtype: set
258265
#[pyfunction]
259266
#[pyo3(text_signature = "(graph, node, /)")]
260-
pub fn descendants(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
267+
pub fn descendants(graph: &digraph::PyDiGraph, node: usize) -> PyResult<HashSet<usize>> {
261268
let index = NodeIndex::new(node);
262-
core_descendants(&graph.graph, index)
269+
if !graph.graph.contains_node(index) {
270+
return Err(PyIndexError::new_err(format!(
271+
"Node source index \"{}\" out of graph bound",
272+
node
273+
)));
274+
}
275+
Ok(core_descendants(&graph.graph, index)
263276
.map(|x| x.index())
264277
.filter(|x| *x != node)
265-
.collect()
278+
.collect())
266279
}
267280

268281
/// Breadth-first traversal of a directed graph with several source vertices.

tests/digraph/test_ancestors_descendants.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def test_ancestors_no_descendants(self):
3939
res = rustworkx.ancestors(dag, node_b)
4040
self.assertEqual({node_a}, res)
4141

42+
def test_invalid_source(self):
43+
graph = rustworkx.generators.directed_path_graph(5)
44+
with self.assertRaises(IndexError):
45+
rustworkx.ancestors(graph, 10)
46+
4247

4348
class TestDescendants(unittest.TestCase):
4449
def test_descendants(self):
@@ -62,3 +67,8 @@ def test_descendants_no_ancestors(self):
6267
node_c = dag.add_child(node_b, "c", {"b": 1})
6368
res = rustworkx.descendants(dag, node_b)
6469
self.assertEqual({node_c}, res)
70+
71+
def test_invalid_source(self):
72+
graph = rustworkx.generators.directed_path_graph(5)
73+
with self.assertRaises(IndexError):
74+
rustworkx.descendants(graph, 10)

0 commit comments

Comments
 (0)