Skip to content

Commit 7598155

Browse files
authored
Merge branch 'main' into remove-39-retworkx
2 parents 1dd621c + 8145325 commit 7598155

File tree

10 files changed

+227
-4
lines changed

10 files changed

+227
-4
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
---
3+
features:
4+
- |
5+
Added a new :func:`~rustworkx.bfs_layers` (and it's per type variants
6+
:func:`~rustworkx.graph_bfs_layers` and :func:`~rustworkx.digraph_bfs_layers`)
7+
that performs a breadth-first search traversal and returns the nodes organized
8+
by their BFS layers/levels. Each layer contains all nodes at the same distance
9+
from the source nodes. This is useful for analyzing graph structure and
10+
implementing algorithms that need to process nodes level by level.
11+
For example:
12+
13+
.. jupyter-execute::
14+
15+
import rustworkx
16+
17+
graph = rustworkx.PyDiGraph()
18+
graph.extend_from_edge_list([(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)])
19+
20+
# Print the layers of the graph in BFS-order relative to 0 (source)
21+
layers = rustworkx.bfs_layers(graph, sources=[0])
22+
print('BFS layers:', layers)

rustworkx-core/src/traversal/bfs_visit.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
// License for the specific language governing permissions and limitations
1111
// under the License.
1212

13+
use super::try_control;
1314
use petgraph::visit::{ControlFlow, EdgeRef, IntoEdges, VisitMap, Visitable};
1415
use std::collections::VecDeque;
1516

16-
use super::try_control;
17-
1817
/// A breadth first search (BFS) visitor event.
1918
#[derive(Copy, Clone, Debug)]
2019
pub enum BfsEvent<N, E> {
@@ -248,3 +247,37 @@ where
248247

249248
C::continuing()
250249
}
250+
251+
pub fn bfs_layers<G, I>(graph: G, sources: I) -> Vec<Vec<G::NodeId>>
252+
where
253+
G: IntoEdges + Visitable,
254+
I: IntoIterator<Item = G::NodeId>,
255+
G::NodeId: Copy + std::hash::Hash + Eq,
256+
{
257+
let mut visited = hashbrown::HashSet::new();
258+
let mut current_layer: Vec<G::NodeId> = sources.into_iter().collect();
259+
260+
for &node in &current_layer {
261+
visited.insert(node);
262+
}
263+
264+
let mut layers: Vec<Vec<G::NodeId>> = Vec::new();
265+
266+
while !current_layer.is_empty() {
267+
layers.push(current_layer.clone());
268+
269+
let mut next_layer = Vec::new();
270+
for &node in &current_layer {
271+
for edge in graph.edges(node) {
272+
let child = edge.target();
273+
if visited.insert(child) {
274+
next_layer.push(child);
275+
}
276+
}
277+
}
278+
279+
current_layer = next_layer;
280+
}
281+
282+
layers
283+
}

rustworkx-core/src/traversal/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use petgraph::visit::Reversed;
2424
use petgraph::visit::VisitMap;
2525
use petgraph::visit::Visitable;
2626

27-
pub use bfs_visit::{breadth_first_search, BfsEvent};
27+
pub use bfs_visit::{bfs_layers, breadth_first_search, BfsEvent};
2828
pub use dfs_edges::dfs_edges;
2929
pub use dfs_visit::{depth_first_search, DfsEvent};
3030
pub use dijkstra_visit::{dijkstra_search, DijkstraEvent};

rustworkx/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2311,3 +2311,21 @@ def write_graphml(graph, path, /, keys=None, compression=None):
23112311
:raises RuntimeError: when an error is encountered while writing the GraphML file.
23122312
"""
23132313
raise TypeError(f"Invalid Input Type {type(graph)} for graph")
2314+
2315+
2316+
@_rustworkx_dispatch
2317+
def bfs_layers(graph, sources=None):
2318+
"""Return the BFS layers of a graph as a list of lists.
2319+
2320+
:param graph: The input graph to use. Can either be a
2321+
:class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`
2322+
:param sources: An optional list of node indices to use as the starting
2323+
nodes for the BFS traversal. If not specified, all nodes in the graph
2324+
will be used as sources.
2325+
:type sources: list[int]
2326+
2327+
:returns: A list of lists where each inner list contains the node indices
2328+
at that BFS layer/level from the source nodes
2329+
:rtype: list[list[int]]
2330+
"""
2331+
raise TypeError(f"Invalid Input Type {type(graph)} for graph")

rustworkx/__init__.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ from .rustworkx import digraph_transitivity as digraph_transitivity
255255
from .rustworkx import graph_transitivity as graph_transitivity
256256
from .rustworkx import digraph_bfs_search as digraph_bfs_search
257257
from .rustworkx import graph_bfs_search as graph_bfs_search
258+
from .rustworkx import digraph_bfs_layers as digraph_bfs_layers
259+
from .rustworkx import graph_bfs_layers as graph_bfs_layers
258260
from .rustworkx import digraph_dfs_search as digraph_dfs_search
259261
from .rustworkx import graph_dfs_search as graph_dfs_search
260262
from .rustworkx import digraph_dijkstra_search as digraph_dijkstra_search
@@ -674,3 +676,7 @@ def write_graphml(
674676
keys: list[GraphMLKey] | None = ...,
675677
compression: str | None = ...,
676678
) -> None: ...
679+
def bfs_layers(
680+
graph: PyGraph | PyDiGraph,
681+
sources: Sequence[int] | None = ...,
682+
) -> list[list[int]]: ...

rustworkx/rustworkx.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,14 @@ def graph_dfs_search(
10821082
source: Sequence[int] | None = ...,
10831083
visitor: _DFSVisitor | None = ...,
10841084
) -> None: ...
1085+
def digraph_bfs_layers(
1086+
digraph: PyDiGraph,
1087+
sources: Sequence[int] | None = ...,
1088+
) -> list[list[int]]: ...
1089+
def graph_bfs_layers(
1090+
graph: PyGraph,
1091+
sources: Sequence[int] | None = ...,
1092+
) -> list[list[int]]: ...
10851093
def digraph_dijkstra_search(
10861094
graph: PyDiGraph,
10871095
source: Sequence[int] | None = ...,

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,8 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
664664
m.add_wrapped(wrap_pyfunction!(steiner_tree::steiner_tree))?;
665665
m.add_wrapped(wrap_pyfunction!(digraph_dfs_search))?;
666666
m.add_wrapped(wrap_pyfunction!(graph_dfs_search))?;
667+
m.add_wrapped(wrap_pyfunction!(digraph_bfs_layers))?;
668+
m.add_wrapped(wrap_pyfunction!(graph_bfs_layers))?;
667669
m.add_wrapped(wrap_pyfunction!(articulation_points))?;
668670
m.add_wrapped(wrap_pyfunction!(bridges))?;
669671
m.add_wrapped(wrap_pyfunction!(biconnected_components))?;

src/traversal/mod.rs

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use dfs_visit::{dfs_handler, PyDfsVisitor};
1919
use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor};
2020

2121
use rustworkx_core::traversal::{
22-
ancestors as core_ancestors, bfs_predecessors as core_bfs_predecessors,
22+
ancestors as core_ancestors, bfs_layers, bfs_predecessors as core_bfs_predecessors,
2323
bfs_successors as core_bfs_successors, breadth_first_search, depth_first_search,
2424
descendants as core_descendants, dfs_edges, dijkstra_search,
2525
};
@@ -32,6 +32,8 @@ use hashbrown::HashSet;
3232

3333
use pyo3::exceptions::{PyIndexError, PyTypeError};
3434
use pyo3::prelude::*;
35+
use pyo3::types::PyList;
36+
use pyo3::IntoPyObjectExt;
3537
use pyo3::Python;
3638

3739
use petgraph::graph::NodeIndex;
@@ -1116,3 +1118,77 @@ pub fn graph_dijkstra_search(
11161118

11171119
Ok(())
11181120
}
1121+
1122+
/// Return the BFS layers of a PyGraph as a list of lists.
1123+
///
1124+
/// :param graph: The input PyGraph to use for BFS traversal
1125+
/// :type graph: PyGraph
1126+
/// :param sources: An optional list of node indices to use as the starting
1127+
/// nodes for the BFS traversal. If not specified, all nodes in the graph
1128+
/// will be used as sources.
1129+
/// :type sources: list[int] or None
1130+
///
1131+
/// :returns: A list of lists where each inner list contains the node indices
1132+
/// at that BFS layer/level from the source nodes
1133+
/// :rtype: list[list[int]]
1134+
#[pyfunction]
1135+
#[pyo3(signature = (graph, sources=None))]
1136+
pub fn graph_bfs_layers(
1137+
py: Python,
1138+
graph: &graph::PyGraph,
1139+
sources: Option<Vec<usize>>,
1140+
) -> PyResult<PyObject> {
1141+
let starts: Vec<NodeIndex> = match sources {
1142+
Some(v) => v.into_iter().map(NodeIndex::new).collect(),
1143+
None => graph.graph.node_indices().collect(),
1144+
};
1145+
1146+
validate_source_nodes(&graph.graph, &starts)?;
1147+
1148+
let layers = bfs_layers(&graph.graph, starts);
1149+
1150+
let py_layers = PyList::empty(py);
1151+
for layer in layers {
1152+
let ids: Vec<usize> = layer.into_iter().map(|n| n.index()).collect();
1153+
let sublist = PyList::new(py, &ids)?;
1154+
py_layers.append(sublist)?;
1155+
}
1156+
py_layers.into_py_any(py)
1157+
}
1158+
1159+
/// Return the BFS layers of a PyDiGraph as a list of lists.
1160+
///
1161+
/// :param graph: The input PyDiGraph to use for BFS traversal
1162+
/// :type graph: PyDiGraph
1163+
/// :param sources: An optional list of node indices to use as the starting
1164+
/// nodes for the BFS traversal. If not specified, all nodes in the graph
1165+
/// will be used as sources.
1166+
/// :type sources: list[int] or None
1167+
///
1168+
/// :returns: A list of lists where each inner list contains the node indices
1169+
/// at that BFS layer/level from the source nodes
1170+
/// :rtype: list[list[int]]
1171+
#[pyfunction]
1172+
#[pyo3(signature = (digraph, sources=None))]
1173+
pub fn digraph_bfs_layers(
1174+
py: Python,
1175+
digraph: &digraph::PyDiGraph,
1176+
sources: Option<Vec<usize>>,
1177+
) -> PyResult<PyObject> {
1178+
let starts: Vec<NodeIndex> = match sources {
1179+
Some(v) => v.into_iter().map(NodeIndex::new).collect(),
1180+
None => digraph.graph.node_indices().collect(),
1181+
};
1182+
1183+
validate_source_nodes(&digraph.graph, &starts)?;
1184+
1185+
let layers = bfs_layers(&digraph.graph, starts);
1186+
1187+
let py_layers = PyList::empty(py);
1188+
for layer in layers {
1189+
let ids: Vec<usize> = layer.into_iter().map(|n| n.index()).collect();
1190+
let sublist = PyList::new(py, &ids)?;
1191+
py_layers.append(sublist)?;
1192+
}
1193+
py_layers.into_py_any(py)
1194+
}

tests/digraph/test_bfs_layer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import unittest
2+
import rustworkx
3+
4+
5+
class TestDiGraphBfsLayers(unittest.TestCase):
6+
def setUp(self):
7+
self.graph = rustworkx.generators.path_graph(6).to_directed()
8+
9+
def test_simple_chain(self):
10+
layers = rustworkx.bfs_layers(self.graph, [3])
11+
self.assertEqual([sorted(layer) for layer in layers], [[3], [2, 4], [1, 5], [0]])
12+
13+
def test_multiple_sources(self):
14+
layers = rustworkx.bfs_layers(self.graph, [0, 3])
15+
self.assertEqual(sorted(layers[0]), [0, 3])
16+
17+
def test_disconnected_digraph(self):
18+
g = rustworkx.PyDiGraph()
19+
g.extend_from_edge_list([(0, 1), (2, 3)])
20+
layers = rustworkx.bfs_layers(g, [2])
21+
self.assertEqual(layers, [[2], [3]])
22+
23+
def test_no_sources_defaults(self):
24+
layers = rustworkx.bfs_layers(self.graph, None)
25+
self.assertTrue(any(0 in layer for layer in layers))
26+
27+
def test_invalid_source(self):
28+
with self.assertRaises(IndexError):
29+
rustworkx.bfs_layers(self.graph, [42])

tests/graph/test_bfs_layer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import unittest
2+
import rustworkx
3+
4+
5+
class TestGraphBfsLayers(unittest.TestCase):
6+
def setUp(self):
7+
self.graph = rustworkx.generators.path_graph(5)
8+
9+
def test_simple_path(self):
10+
layers = rustworkx.bfs_layers(self.graph, [0])
11+
self.assertEqual(layers, [[0], [1], [2], [3], [4]])
12+
13+
def test_multiple_sources(self):
14+
layers = rustworkx.bfs_layers(self.graph, [0, 4])
15+
self.assertEqual(layers, [[0, 4], [1, 3], [2]])
16+
17+
def test_disconnected_graph(self):
18+
g = rustworkx.PyGraph()
19+
g.extend_from_edge_list([(0, 1), (2, 3)])
20+
layers = rustworkx.bfs_layers(g, [0])
21+
self.assertEqual(layers, [[0], [1]])
22+
23+
def test_no_sources_default_all_nodes(self):
24+
layers = rustworkx.bfs_layers(self.graph, None)
25+
self.assertTrue(all(isinstance(layer, list) for layer in layers))
26+
27+
def test_invalid_source(self):
28+
with self.assertRaises(IndexError):
29+
rustworkx.bfs_layers(self.graph, [99])

0 commit comments

Comments
 (0)