Skip to content

Commit 44ed29e

Browse files
IvanIsCodingSILIZ4
authored andcommitted
Fix bfs_search and other search methods panicking with invalid sources (Qiskit#1388)
* Validate source nodes in search functions * Add tests for graph searches * Add tests for digraph searches * Black * Add release notes * Address clippy
1 parent f8b0e52 commit 44ed29e

File tree

8 files changed

+66
-1
lines changed

8 files changed

+66
-1
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed a panic when passing invalid source nodes to search methods
5+
such as :func:`~rustworkx.bfs_search`. See
6+
`#1386 <https://github.com/Qiskit/rustworkx/issues/1386>`__ for more details.

src/traversal/mod.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,30 @@ use std::convert::TryFrom;
3030

3131
use hashbrown::HashSet;
3232

33-
use pyo3::exceptions::PyTypeError;
33+
use pyo3::exceptions::{PyIndexError, PyTypeError};
3434
use pyo3::prelude::*;
3535
use pyo3::Python;
3636

3737
use petgraph::graph::NodeIndex;
38+
use petgraph::EdgeType;
3839

3940
use crate::iterators::EdgeList;
41+
use crate::StablePyGraph;
42+
43+
fn validate_source_nodes<Ty: EdgeType>(
44+
graph: &StablePyGraph<Ty>,
45+
starts: &[NodeIndex],
46+
) -> PyResult<()> {
47+
for index in starts.iter() {
48+
if !graph.contains_node(*index) {
49+
return Err(PyIndexError::new_err(format!(
50+
"Node source index \"{}\" out of graph bound",
51+
index.index()
52+
)));
53+
}
54+
}
55+
Ok(())
56+
}
4057

4158
/// Get an edge list of the tree edges from a depth-first traversal
4259
///
@@ -386,6 +403,8 @@ pub fn digraph_bfs_search(
386403
None => graph.graph.node_indices().collect(),
387404
};
388405

406+
validate_source_nodes(&graph.graph, &starts)?;
407+
389408
breadth_first_search(&graph.graph, starts, |event| {
390409
bfs_handler(py, &visitor, event)
391410
})?;
@@ -530,6 +549,8 @@ pub fn graph_bfs_search(
530549
None => graph.graph.node_indices().collect(),
531550
};
532551

552+
validate_source_nodes(&graph.graph, &starts)?;
553+
533554
breadth_first_search(&graph.graph, starts, |event| {
534555
bfs_handler(py, &visitor, event)
535556
})?;
@@ -644,6 +665,8 @@ pub fn digraph_dfs_search(
644665
None => graph.graph.node_indices().collect(),
645666
};
646667

668+
validate_source_nodes(&graph.graph, &starts)?;
669+
647670
depth_first_search(&graph.graph, starts, |event| {
648671
dfs_handler(py, &visitor, event)
649672
})?;
@@ -758,6 +781,8 @@ pub fn graph_dfs_search(
758781
None => graph.graph.node_indices().collect(),
759782
};
760783

784+
validate_source_nodes(&graph.graph, &starts)?;
785+
761786
depth_first_search(&graph.graph, starts, |event| {
762787
dfs_handler(py, &visitor, event)
763788
})?;
@@ -895,6 +920,8 @@ pub fn digraph_dijkstra_search(
895920
None => graph.graph.node_indices().collect(),
896921
};
897922

923+
validate_source_nodes(&graph.graph, &starts)?;
924+
898925
let edge_cost_fn = CostFn::try_from((weight_fn, 1.0))?;
899926
dijkstra_search(
900927
&graph.graph,
@@ -1036,6 +1063,8 @@ pub fn graph_dijkstra_search(
10361063
None => graph.graph.node_indices().collect(),
10371064
};
10381065

1066+
validate_source_nodes(&graph.graph, &starts)?;
1067+
10391068
let edge_cost_fn = CostFn::try_from((weight_fn, 1.0))?;
10401069
dijkstra_search(
10411070
&graph.graph,

tests/digraph/test_bfs_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,8 @@ def gray_target_edge(self, _):
160160

161161
vis = PruneGrayTargetEdge()
162162
rustworkx.digraph_bfs_search(self.graph, [0], vis)
163+
164+
def test_invalid_source(self):
165+
graph = rustworkx.PyDiGraph()
166+
with self.assertRaises(IndexError):
167+
rustworkx.bfs_search(graph, [1], rustworkx.visit.BFSVisitor())

tests/digraph/test_dfs_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,8 @@ def reconstruct_path(self):
104104
except rustworkx.visit.StopSearch:
105105
pass
106106
self.assertEqual(vis.reconstruct_path(), [0, 2, 5, 3])
107+
108+
def test_invalid_source(self):
109+
graph = rustworkx.PyDiGraph()
110+
with self.assertRaises(IndexError):
111+
rustworkx.dfs_search(graph, [1], rustworkx.visit.DFSVisitor())

tests/digraph/test_dijkstra_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,8 @@ def edge_not_relaxed(self, _):
187187

188188
vis = PruneEdgeNotRelaxed()
189189
rustworkx.digraph_dijkstra_search(self.graph, [0], float, vis)
190+
191+
def test_invalid_source(self):
192+
graph = rustworkx.PyDiGraph()
193+
with self.assertRaises(IndexError):
194+
rustworkx.dijkstra_search(graph, [1], float, rustworkx.visit.DijkstraVisitor())

tests/graph/test_bfs_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,8 @@ def gray_target_edge(self, _):
160160

161161
vis = PruneGrayTargetEdge()
162162
rustworkx.graph_bfs_search(self.graph, [0], vis)
163+
164+
def test_invalid_source(self):
165+
graph = rustworkx.PyGraph()
166+
with self.assertRaises(IndexError):
167+
rustworkx.bfs_search(graph, [1], rustworkx.visit.BFSVisitor())

tests/graph/test_dfs_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,8 @@ def reconstruct_path(self):
104104
except rustworkx.visit.StopSearch:
105105
pass
106106
self.assertEqual(vis.reconstruct_path(), [0, 2, 5, 3])
107+
108+
def test_invalid_source(self):
109+
graph = rustworkx.PyGraph()
110+
with self.assertRaises(IndexError):
111+
rustworkx.dfs_search(graph, [1], rustworkx.visit.DFSVisitor())

tests/graph/test_dijkstra_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,8 @@ def edge_not_relaxed(self, _):
187187

188188
vis = PruneEdgeNotRelaxed()
189189
rustworkx.graph_dijkstra_search(self.graph, [0], float, vis)
190+
191+
def test_invalid_source(self):
192+
graph = rustworkx.PyGraph()
193+
with self.assertRaises(IndexError):
194+
rustworkx.dijkstra_search(graph, [1], float, rustworkx.visit.DijkstraVisitor())

0 commit comments

Comments
 (0)