Skip to content

Commit 4327583

Browse files
Move ancestors and descendants to rustworkx-core (#1208)
This commit adds an implementation of the ancestors and descendants functions to the rustworkx-core crate exposing the functions to rust users. The existing implementation in the rustworkx crate is removed and it is updated to call the rustworkx-core functions. These new functions will be more efficient as they're not using dijkstra's algorithm to find a path from nodes now and instead are just doing a BFS. The rustwork-core functions also return an iterator of nodes. Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent c40b169 commit 4327583

File tree

3 files changed

+159
-20
lines changed

3 files changed

+159
-20
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
---
2+
features:
3+
- |
4+
Added a new function ``ancestors()`` to the
5+
``rustworkx_core::traversal`` module. That is a generic Rust implementation
6+
for the core rust library that provides the
7+
:func:`.ancestors` function to Rust users.
8+
- |
9+
Added a new function ``descendants()`` to the
10+
``rustworkx_core::traversal`` module. That is a generic Rust implementation
11+
for the core rust library that provides the
12+
:func:`.descendants` function to Rust users.

rustworkx-core/src/traversal/mod.rs

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ mod dfs_edges;
1717
mod dfs_visit;
1818
mod dijkstra_visit;
1919

20+
use petgraph::prelude::*;
21+
use petgraph::visit::GraphRef;
22+
use petgraph::visit::IntoNeighborsDirected;
23+
use petgraph::visit::Reversed;
24+
use petgraph::visit::VisitMap;
25+
use petgraph::visit::Visitable;
26+
2027
pub use bfs_visit::{breadth_first_search, BfsEvent};
2128
pub use dfs_edges::dfs_edges;
2229
pub use dfs_visit::{depth_first_search, DfsEvent};
@@ -45,3 +52,133 @@ macro_rules! try_control {
4552
}
4653

4754
use try_control;
55+
56+
struct AncestryWalker<G, N, VM> {
57+
graph: G,
58+
walker: Bfs<N, VM>,
59+
}
60+
61+
impl<
62+
G: GraphRef + Visitable + IntoNeighborsDirected<NodeId = N>,
63+
N: Copy + Clone + PartialEq,
64+
VM: VisitMap<N>,
65+
> Iterator for AncestryWalker<G, N, VM>
66+
{
67+
type Item = N;
68+
fn next(&mut self) -> Option<Self::Item> {
69+
self.walker.next(self.graph)
70+
}
71+
}
72+
73+
/// Return the ancestors of a node in a graph.
74+
///
75+
/// `node` is included in the output
76+
///
77+
/// # Arguments:
78+
///
79+
/// * `node` - The node to find the ancestors of
80+
///
81+
/// # Returns
82+
///
83+
/// An iterator where each item is a node id for an ancestor of ``node``.
84+
/// This includes ``node`` in the returned ids.
85+
///
86+
/// # Example
87+
///
88+
/// ```rust
89+
/// use rustworkx_core::traversal::ancestors;
90+
/// use rustworkx_core::petgraph::stable_graph::{StableDiGraph, NodeIndex};
91+
///
92+
/// let graph: StableDiGraph<(), ()> = StableDiGraph::from_edges(&[
93+
/// (0, 1), (1, 2), (1, 3), (2, 4), (3, 4), (4, 5)
94+
/// ]);
95+
/// let ancestors: Vec<usize> = ancestors(&graph, NodeIndex::new(3)).map(|x| x.index()).collect();
96+
/// assert_eq!(vec![3_usize, 1, 0], ancestors);
97+
/// ```
98+
pub fn ancestors<G>(graph: G, node: G::NodeId) -> impl Iterator<Item = G::NodeId>
99+
where
100+
G: GraphRef + Visitable + IntoNeighborsDirected,
101+
{
102+
let reversed = Reversed(graph);
103+
AncestryWalker {
104+
graph: reversed,
105+
walker: Bfs::new(reversed, node),
106+
}
107+
}
108+
109+
/// Return the descendants of a node in a graph.
110+
///
111+
/// `node` is included in the output.
112+
/// # Arguments:
113+
///
114+
/// * `node` - The node to find the ancestors of
115+
///
116+
/// # Returns
117+
///
118+
/// An iterator where each item is a node id for an ancestor of ``node``.
119+
/// This includes ``node`` in the returned ids.
120+
///
121+
/// # Example
122+
///
123+
/// ```rust
124+
/// use rustworkx_core::traversal::descendants;
125+
/// use rustworkx_core::petgraph::stable_graph::{StableDiGraph, NodeIndex};
126+
///
127+
/// let graph: StableDiGraph<(), ()> = StableDiGraph::from_edges(&[
128+
/// (0, 1), (1, 2), (1, 3), (2, 4), (3, 4), (4, 5)
129+
/// ]);
130+
/// let descendants: Vec<usize> = descendants(&graph, NodeIndex::new(3)).map(|x| x.index()).collect();
131+
/// assert_eq!(vec![3_usize, 4, 5], descendants);
132+
/// ```
133+
pub fn descendants<G>(graph: G, node: G::NodeId) -> impl Iterator<Item = G::NodeId>
134+
where
135+
G: GraphRef + Visitable + IntoNeighborsDirected,
136+
{
137+
AncestryWalker {
138+
graph,
139+
walker: Bfs::new(graph, node),
140+
}
141+
}
142+
143+
#[cfg(test)]
144+
mod test_ancestry {
145+
use super::{ancestors, descendants};
146+
use crate::petgraph::graph::DiGraph;
147+
use crate::petgraph::stable_graph::{NodeIndex, StableDiGraph};
148+
149+
#[test]
150+
fn test_ancestors_digraph() {
151+
let graph: DiGraph<(), ()> =
152+
DiGraph::from_edges(&[(0, 1), (1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]);
153+
let ancestors: Vec<usize> = ancestors(&graph, NodeIndex::new(3))
154+
.map(|x| x.index())
155+
.collect();
156+
assert_eq!(vec![3_usize, 1, 0], ancestors);
157+
}
158+
159+
#[test]
160+
fn test_descendants() {
161+
let graph: DiGraph<(), ()> =
162+
DiGraph::from_edges(&[(0, 1), (1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]);
163+
let descendants: Vec<usize> = descendants(&graph, NodeIndex::new(3))
164+
.map(|x| x.index())
165+
.collect();
166+
assert_eq!(vec![3_usize, 4, 5], descendants);
167+
}
168+
169+
#[test]
170+
fn test_no_ancestors() {
171+
let mut graph: StableDiGraph<(), ()> = StableDiGraph::new();
172+
let index = graph.add_node(());
173+
let res = ancestors(&graph, index);
174+
assert_eq!(vec![index], res.collect::<Vec<NodeIndex>>())
175+
}
176+
177+
#[test]
178+
fn test_no_descendants() {
179+
let mut graph: StableDiGraph<(), ()> = StableDiGraph::new();
180+
let index = graph.add_node(());
181+
let res = descendants(&graph, index);
182+
assert_eq!(vec![index], res.collect::<Vec<NodeIndex>>())
183+
}
184+
}

src/traversal/mod.rs

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ use dfs_visit::{dfs_handler, PyDfsVisitor};
1919
use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor};
2020

2121
use rustworkx_core::traversal::{
22-
breadth_first_search, depth_first_search, dfs_edges, dijkstra_search,
22+
ancestors as core_ancestors, breadth_first_search, depth_first_search,
23+
descendants as core_descendants, dfs_edges, dijkstra_search,
2324
};
2425

2526
use super::{digraph, graph, iterators, CostFn};
@@ -32,7 +33,6 @@ use pyo3::exceptions::PyTypeError;
3233
use pyo3::prelude::*;
3334
use pyo3::Python;
3435

35-
use petgraph::algo;
3636
use petgraph::graph::NodeIndex;
3737
use petgraph::visit::{Bfs, NodeCount, Reversed};
3838

@@ -221,16 +221,10 @@ pub fn bfs_predecessors(
221221
#[pyfunction]
222222
#[pyo3(text_signature = "(graph, node, /)")]
223223
pub fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
224-
let index = NodeIndex::new(node);
225-
let mut out_set: HashSet<usize> = HashSet::new();
226-
let reverse_graph = Reversed(&graph.graph);
227-
let res = algo::dijkstra(reverse_graph, index, None, |_| 1);
228-
for n in res.keys() {
229-
let n_int = n.index();
230-
out_set.insert(n_int);
231-
}
232-
out_set.remove(&node);
233-
out_set
224+
core_ancestors(&graph.graph, NodeIndex::new(node))
225+
.map(|x| x.index())
226+
.filter(|x| *x != node)
227+
.collect()
234228
}
235229

236230
/// Return the descendants of a node in a graph.
@@ -249,14 +243,10 @@ pub fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
249243
#[pyo3(text_signature = "(graph, node, /)")]
250244
pub fn descendants(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
251245
let index = NodeIndex::new(node);
252-
let mut out_set: HashSet<usize> = HashSet::new();
253-
let res = algo::dijkstra(&graph.graph, index, None, |_| 1);
254-
for n in res.keys() {
255-
let n_int = n.index();
256-
out_set.insert(n_int);
257-
}
258-
out_set.remove(&node);
259-
out_set
246+
core_descendants(&graph.graph, index)
247+
.map(|x| x.index())
248+
.filter(|x| *x != node)
249+
.collect()
260250
}
261251

262252
/// Breadth-first traversal of a directed graph.

0 commit comments

Comments
 (0)