Skip to content

Commit 8e81911

Browse files
Always return a cycle in digraph_find_cycle if no node is specified and a cycle exists (#1181)
* Handle find arbitrary cycle case * Find node in cycle more smartly * Implement find_node_in_arbitrary_cycle * Switch to Tarjan SCC for single pass DFS * Improve cycle checking logic * More assert_cycle! * Cargo fmt * Add test case for no cycle and no source * assertCycle for existing unit tests * Add more tests * Add self loop Python test * Add release notes and fix test * Address PR comments * Use less traits * Update release notes --------- Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 12f8af5 commit 8e81911

File tree

3 files changed

+133
-36
lines changed

3 files changed

+133
-36
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed the behavior of :func:`~rustworkx.digraph_find_cycle` when
5+
no source node was provided. Previously, the function would start looking
6+
for a cycle at an arbitrary node which was not guaranteed to return a cycle.
7+
Now, the function will smartly choose a source node to start the search from
8+
such that if a cycle exists, it will be found.
9+
other:
10+
- |
11+
The `rustworkx-core` function `rustworkx_core::connectivity::find_cycle` now
12+
requires the `petgraph::visit::Visitable` trait.

rustworkx-core/src/connectivity/find_cycle.rs

Lines changed: 82 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
// under the License.
1212

1313
use hashbrown::{HashMap, HashSet};
14+
use petgraph::algo;
1415
use petgraph::visit::{
15-
EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount,
16+
EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable,
1617
};
1718
use petgraph::Direction::Outgoing;
1819
use std::hash::Hash;
@@ -57,22 +58,22 @@ where
5758
G: GraphBase,
5859
G: NodeCount,
5960
G: EdgeCount,
60-
for<'b> &'b G: GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected,
61+
for<'b> &'b G:
62+
GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable,
6163
G::NodeId: Eq + Hash,
6264
{
6365
// Find a cycle in the given graph and return it as a list of edges
64-
let mut graph_nodes: HashSet<G::NodeId> = graph.node_identifiers().collect();
6566
let mut cycle: Vec<(G::NodeId, G::NodeId)> = Vec::with_capacity(graph.edge_count());
66-
let temp_value: G::NodeId;
67-
// If source is not set get an arbitrary node from the set of graph
68-
// nodes we've not "examined"
67+
// If source is not set get a node in an arbitrary cycle if it exists,
68+
// otherwise return that there is no cycle
6969
let source_index = match source {
7070
Some(source_value) => source_value,
71-
None => {
72-
temp_value = *graph_nodes.iter().next().unwrap();
73-
graph_nodes.remove(&temp_value);
74-
temp_value
75-
}
71+
None => match find_node_in_arbitrary_cycle(&graph) {
72+
Some(node_in_cycle) => node_in_cycle,
73+
None => {
74+
return Vec::new();
75+
}
76+
},
7677
};
7778
// Stack (ie "pushdown list") of vertices already in the spanning tree
7879
let mut stack: Vec<G::NodeId> = vec![source_index];
@@ -119,11 +120,47 @@ where
119120
cycle
120121
}
121122

123+
fn find_node_in_arbitrary_cycle<G>(graph: &G) -> Option<G::NodeId>
124+
where
125+
G: GraphBase,
126+
G: NodeCount,
127+
G: EdgeCount,
128+
for<'b> &'b G:
129+
GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable,
130+
G::NodeId: Eq + Hash,
131+
{
132+
for scc in algo::kosaraju_scc(&graph) {
133+
if scc.len() > 1 {
134+
return Some(scc[0]);
135+
}
136+
}
137+
for node in graph.node_identifiers() {
138+
for neighbor in graph.neighbors_directed(node, Outgoing) {
139+
if neighbor == node {
140+
return Some(node);
141+
}
142+
}
143+
}
144+
None
145+
}
146+
122147
#[cfg(test)]
123148
mod tests {
124149
use crate::connectivity::find_cycle;
125150
use petgraph::prelude::*;
126151

152+
// Utility to assert cycles in the response
153+
macro_rules! assert_cycle {
154+
($g: expr, $cycle: expr) => {{
155+
for i in 0..$cycle.len() {
156+
let (s, t) = $cycle[i];
157+
assert!($g.contains_edge(s, t));
158+
let (next_s, _) = $cycle[(i + 1) % $cycle.len()];
159+
assert_eq!(t, next_s);
160+
}
161+
}};
162+
}
163+
127164
#[test]
128165
fn test_find_cycle_source() {
129166
let edge_list = vec![
@@ -141,20 +178,13 @@ mod tests {
141178
(8, 9),
142179
];
143180
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
144-
let mut res: Vec<(usize, usize)> = find_cycle(&graph, Some(NodeIndex::new(0)))
145-
.iter()
146-
.map(|(s, t)| (s.index(), t.index()))
147-
.collect();
148-
assert_eq!(res, [(0, 1), (1, 2), (2, 3), (3, 0)]);
149-
res = find_cycle(&graph, Some(NodeIndex::new(1)))
150-
.iter()
151-
.map(|(s, t)| (s.index(), t.index()))
152-
.collect();
153-
assert_eq!(res, [(1, 2), (2, 3), (3, 0), (0, 1)]);
154-
res = find_cycle(&graph, Some(NodeIndex::new(5)))
155-
.iter()
156-
.map(|(s, t)| (s.index(), t.index()))
157-
.collect();
181+
for i in [0, 1, 2, 3].iter() {
182+
let idx = NodeIndex::new(*i);
183+
let res = find_cycle(&graph, Some(idx));
184+
assert_cycle!(graph, res);
185+
assert_eq!(res[0].0, idx);
186+
}
187+
let res = find_cycle(&graph, Some(NodeIndex::new(5)));
158188
assert_eq!(res, []);
159189
}
160190

@@ -176,10 +206,32 @@ mod tests {
176206
];
177207
let mut graph = DiGraph::<i32, i32>::from_edges(edge_list);
178208
graph.add_edge(NodeIndex::new(1), NodeIndex::new(1), 0);
179-
let res: Vec<(usize, usize)> = find_cycle(&graph, Some(NodeIndex::new(0)))
180-
.iter()
181-
.map(|(s, t)| (s.index(), t.index()))
182-
.collect();
183-
assert_eq!(res, [(1, 1)]);
209+
let res = find_cycle(&graph, Some(NodeIndex::new(0)));
210+
assert_eq!(res[0].0, NodeIndex::new(1));
211+
assert_cycle!(graph, res);
212+
}
213+
214+
#[test]
215+
fn test_self_loop_no_source() {
216+
let edge_list = vec![(0, 1), (1, 2), (2, 3), (2, 2)];
217+
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
218+
let res = find_cycle(&graph, None);
219+
assert_cycle!(graph, res);
220+
}
221+
222+
#[test]
223+
fn test_cycle_no_source() {
224+
let edge_list = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 2)];
225+
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
226+
let res = find_cycle(&graph, None);
227+
assert_cycle!(graph, res);
228+
}
229+
230+
#[test]
231+
fn test_no_cycle_no_source() {
232+
let edge_list = vec![(0, 1), (1, 2), (2, 3)];
233+
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
234+
let res = find_cycle(&graph, None);
235+
assert_eq!(res, []);
184236
}
185237
}

tests/digraph/test_find_cycle.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import unittest
1414

1515
import rustworkx
16+
import rustworkx.generators
1617

1718

1819
class TestFindCycle(unittest.TestCase):
@@ -36,30 +37,38 @@ def setUp(self):
3637
]
3738
)
3839

40+
def assertCycle(self, first_node, graph, res):
41+
self.assertEqual(first_node, res[0][0])
42+
for i in range(len(res)):
43+
s, t = res[i]
44+
self.assertTrue(graph.has_edge(s, t))
45+
next_s, _ = res[(i + 1) % len(res)]
46+
self.assertEqual(t, next_s)
47+
3948
def test_find_cycle(self):
4049
graph = rustworkx.PyDiGraph()
4150
graph.add_nodes_from(list(range(6)))
4251
graph.add_edges_from_no_data(
4352
[(0, 1), (0, 3), (0, 5), (1, 2), (2, 3), (3, 4), (4, 5), (4, 0)]
4453
)
4554
res = rustworkx.digraph_find_cycle(graph, 0)
46-
self.assertEqual([(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)], res)
55+
self.assertCycle(0, graph, res)
4756

4857
def test_find_cycle_multiple_roots_same_cycles(self):
4958
res = rustworkx.digraph_find_cycle(self.graph, 0)
50-
self.assertEqual(res, [(0, 1), (1, 2), (2, 3), (3, 0)])
59+
self.assertCycle(0, self.graph, res)
5160
res = rustworkx.digraph_find_cycle(self.graph, 1)
52-
self.assertEqual(res, [(1, 2), (2, 3), (3, 0), (0, 1)])
61+
self.assertCycle(1, self.graph, res)
5362
res = rustworkx.digraph_find_cycle(self.graph, 5)
5463
self.assertEqual(res, [])
5564

5665
def test_find_cycle_disconnected_graphs(self):
5766
self.graph.add_nodes_from(["A", "B", "C"])
5867
self.graph.add_edges_from_no_data([(10, 11), (12, 10), (11, 12)])
5968
res = rustworkx.digraph_find_cycle(self.graph, 0)
60-
self.assertEqual(res, [(0, 1), (1, 2), (2, 3), (3, 0)])
69+
self.assertCycle(0, self.graph, res)
6170
res = rustworkx.digraph_find_cycle(self.graph, 10)
62-
self.assertEqual(res, [(10, 11), (11, 12), (12, 10)])
71+
self.assertCycle(10, self.graph, res)
6372

6473
def test_invalid_types(self):
6574
graph = rustworkx.PyGraph()
@@ -69,4 +78,28 @@ def test_invalid_types(self):
6978
def test_self_loop(self):
7079
self.graph.add_edge(1, 1, None)
7180
res = rustworkx.digraph_find_cycle(self.graph, 0)
72-
self.assertEqual([(1, 1)], res)
81+
self.assertCycle(1, self.graph, res)
82+
83+
def test_no_cycle_no_source(self):
84+
g = rustworkx.generators.directed_grid_graph(10, 10)
85+
res = rustworkx.digraph_find_cycle(g)
86+
self.assertEqual(res, [])
87+
88+
def test_cycle_no_source(self):
89+
g = rustworkx.generators.directed_path_graph(1000)
90+
a = g.add_node(1000)
91+
b = g.node_indices()[-2]
92+
g.add_edge(b, a, None)
93+
g.add_edge(a, b, None)
94+
res = rustworkx.digraph_find_cycle(g)
95+
self.assertEqual(len(res), 2)
96+
self.assertTrue(res[0] == res[1][::-1])
97+
98+
def test_cycle_self_loop(self):
99+
g = rustworkx.generators.directed_path_graph(1000)
100+
a = g.add_node(1000)
101+
b = g.node_indices()[-1]
102+
g.add_edge(b, a, None)
103+
g.add_edge(a, a, None)
104+
res = rustworkx.digraph_find_cycle(g)
105+
self.assertEqual(res, [(a, a)])

0 commit comments

Comments
 (0)