Skip to content

Commit 5829af6

Browse files
raynelfssmtreinish
andauthored
Add layers function to rustworkx-core (#1194)
* Initial: Add `layers` function to `rustworkx-core` - Modify the `layers` python interface to use the `rustworkx-core` equivalent. * Docs: Add proper docstring to `layers` * Fix: Wrong import in docstring * Fix: Return an `Iterator` instance from `layers` in `rustworkx-core`. * Test: Add tests to `layers` - Move `layers` to `dag_algo.rs`. - Add check for cycles, if a cycle is found throw an error. - Refactor `LayersIndexError` to `LayersError`. - Move `LayersError` to `err.rs`. - Other small tweaks and fixes. * Format: Fix lint. * Docs: Fix docs test. * Docs: Add release note * Docs: Fix release note * Fix: Return NodeId instead of usize * Docs: Add suggestions for release note. * Fix: Return true Iterator for `layers` - Use panic exceptions for specific cases. - Other tweaks and fixes. * Fix: Node check only in the first layer * Fix: Remove result handling for layers - Use `panic!` when a programming error is made. - Verify cycles by checkng repeating layers, call `panic!` if one is found. - Adapt python side function to use check for nodes to avoid panic. - Adapt tests. * Remove: `LayersError` as it will no longer be needed. - Small fix in docstring. * Fix: Revert result handling in `layers` - Add result handling in the python version of the function. - Use indices to keep track of cycles. - Revert deletion of `LayersError`. - Update tests. - Other tweaks and fixes. * Docs: Fix release note and docstring - Fix docstring test and regular test. - Add extra check for missing nodes. * Fix: Explicit warning for invalid first index - Remove calls to `to_owned()`. --------- Co-authored-by: Matthew Treinish <mtreinish@kortar.org>
1 parent ed3cb8f commit 5829af6

File tree

4 files changed

+312
-87
lines changed

4 files changed

+312
-87
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
features:
3+
- |
4+
Added a function ``~rustworkx_core.dag_algo::layers`` in rustworkx-core to
5+
get the layers of a directed acyclic graph. This is equivalent to the
6+
:func:`.layers` function that existed in the Python API but now exposes it
7+
for Rust users too.
8+
fix:
9+
- |
10+
When calling ``~rustworkx_core.dag_algo::layers``if the provided graph has
11+
a cycle, the function will throw a ``DAGHasCycle`` error instance.

rustworkx-core/src/dag_algo.rs

Lines changed: 249 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,21 @@ use std::fmt::{Display, Formatter};
1717
use std::hash::Hash;
1818

1919
use hashbrown::{HashMap, HashSet};
20+
use std::fmt::Debug;
21+
use std::mem::swap;
2022

2123
use petgraph::algo;
2224
use petgraph::data::DataMap;
2325
use petgraph::visit::{
2426
EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers,
25-
NodeCount, Visitable,
27+
NodeCount, NodeIndexable, Visitable,
2628
};
2729
use petgraph::Directed;
2830

2931
use num_traits::{Num, Zero};
3032

33+
use crate::err::LayersError;
34+
3135
/// Return a pair of [`petgraph::Direction`] values corresponding to the "forwards" and "backwards"
3236
/// direction of graph traversal, based on whether the graph is being traved forwards (following
3337
/// the edges) or backward (reversing along edges). The order of returns is (forwards, backwards).
@@ -333,6 +337,155 @@ where
333337
Ok(Some((path, path_weight)))
334338
}
335339

340+
/// Return an iterator of graph layers
341+
///
342+
/// A layer is a subgraph whose nodes are disjoint, i.e.,
343+
/// a layer has depth 1. The layers are constructed using a greedy algorithm.
344+
///
345+
/// Arguments:
346+
///
347+
/// * `graph` - The graph to get the layers from
348+
/// * `first_layer` - A list of node ids for the first layer. This
349+
/// will be the first layer in the output
350+
///
351+
/// Will `panic!` if a provided node is not in the graph.
352+
/// ```
353+
/// use rustworkx_core::petgraph::prelude::*;
354+
/// use rustworkx_core::dag_algo::layers;
355+
/// use rustworkx_core::dictmap::*;
356+
///
357+
/// let edge_list = vec![
358+
/// (0, 1),
359+
/// (1, 2),
360+
/// (2, 3),
361+
/// (3, 4),
362+
/// ];
363+
///
364+
/// let graph = DiGraph::<u32, u32>::from_edges(&edge_list);
365+
/// let layers: Vec<Vec<NodeIndex>> = layers(&graph, vec![0.into(),]).map(|layer| layer.unwrap()).collect();
366+
/// let expected_layers: Vec<Vec<NodeIndex>> = vec![
367+
/// vec![0.into(),],
368+
/// vec![1.into(),],
369+
/// vec![2.into(),],
370+
/// vec![3.into(),],
371+
/// vec![4.into()]
372+
/// ];
373+
/// assert_eq!(layers, expected_layers)
374+
/// ```
375+
pub fn layers<G>(
376+
graph: G,
377+
first_layer: Vec<G::NodeId>,
378+
) -> impl Iterator<Item = Result<Vec<G::NodeId>, LayersError>>
379+
where
380+
G: NodeIndexable // Used in from_index and to_index.
381+
+ IntoNodeIdentifiers // Used for .node_identifiers
382+
+ IntoNeighborsDirected // Used for .neighbors_directed
383+
+ IntoEdgesDirected, // Used for .edged_directed
384+
<G as GraphBase>::NodeId: Debug + Copy + Eq + Hash,
385+
{
386+
LayersIter {
387+
graph,
388+
cur_layer: first_layer,
389+
next_layer: vec![],
390+
predecessor_count: HashMap::new(),
391+
first_iter: true,
392+
cycle_check: HashSet::default(),
393+
}
394+
}
395+
396+
#[derive(Debug, Clone)]
397+
struct LayersIter<G, N> {
398+
graph: G,
399+
cur_layer: Vec<N>,
400+
next_layer: Vec<N>,
401+
predecessor_count: HashMap<N, usize>,
402+
first_iter: bool,
403+
cycle_check: HashSet<N>, // TODO: Figure out why some cycles cannot be detected
404+
}
405+
406+
impl<G, N> Iterator for LayersIter<G, N>
407+
where
408+
G: NodeIndexable // Used in from_index and to_index.
409+
+ IntoNodeIdentifiers // Used for .node_identifiers
410+
+ IntoNeighborsDirected // Used for .neighbors_directed
411+
+ IntoEdgesDirected // Used for .edged_directed
412+
+ GraphBase<NodeId = N>,
413+
N: Debug + Copy + Eq + Hash,
414+
{
415+
type Item = Result<Vec<N>, LayersError>;
416+
fn next(&mut self) -> Option<Self::Item> {
417+
if self.first_iter {
418+
self.first_iter = false;
419+
for node in &self.cur_layer {
420+
if self.graph.to_index(*node) >= self.graph.node_bound() {
421+
panic!("Node {:#?} is not present in the graph.", node);
422+
}
423+
if self.cycle_check.contains(node) {
424+
return Some(Err(LayersError(format!(
425+
"An invalid first layer was provided: {:#?} appears more than once.",
426+
node
427+
))));
428+
}
429+
self.cycle_check.insert(*node);
430+
}
431+
Some(Ok(self.cur_layer.clone()))
432+
} else if self.cur_layer.is_empty() {
433+
None
434+
} else {
435+
for node in &self.cur_layer {
436+
if self.graph.to_index(*node) >= self.graph.node_bound() {
437+
panic!("Node {:#?} is not present in the graph.", node);
438+
}
439+
let children = self
440+
.graph
441+
.neighbors_directed(*node, petgraph::Direction::Outgoing);
442+
let mut used_indices: HashSet<G::NodeId> = HashSet::new();
443+
for succ in children {
444+
// Skip duplicate successors
445+
if used_indices.contains(&succ) {
446+
continue;
447+
}
448+
used_indices.insert(succ);
449+
let mut multiplicity: usize = 0;
450+
let raw_edges: G::EdgesDirected = self
451+
.graph
452+
.edges_directed(*node, petgraph::Direction::Outgoing);
453+
for edge in raw_edges {
454+
if edge.target() == succ {
455+
multiplicity += 1;
456+
}
457+
}
458+
self.predecessor_count
459+
.entry(succ)
460+
.and_modify(|e| *e -= multiplicity)
461+
.or_insert(
462+
// Get the number of incoming edges to the successor
463+
self.graph
464+
.edges_directed(succ, petgraph::Direction::Incoming)
465+
.count()
466+
- multiplicity,
467+
);
468+
if *self.predecessor_count.get(&succ).unwrap() == 0 {
469+
if self.cycle_check.contains(&succ) {
470+
return Some(Err(LayersError("The provided graph contains a cycle or an invalid first layer was provided.".to_string())));
471+
}
472+
self.next_layer.push(succ);
473+
self.cycle_check.insert(succ);
474+
self.predecessor_count.remove(&succ);
475+
}
476+
}
477+
}
478+
swap(&mut self.cur_layer, &mut self.next_layer);
479+
self.next_layer.clear();
480+
if self.cur_layer.is_empty() {
481+
None
482+
} else {
483+
Some(Ok(self.cur_layer.clone()))
484+
}
485+
}
486+
}
487+
}
488+
336489
/// Collect runs that match a filter function given edge colors.
337490
///
338491
/// A bicolor run is a list of groups of nodes connected by edges of exactly
@@ -931,6 +1084,101 @@ mod test_lexicographical_topological_sort {
9311084
}
9321085
}
9331086

1087+
#[cfg(test)]
1088+
mod test_layers {
1089+
use super::*;
1090+
use petgraph::{
1091+
graph::{DiGraph, NodeIndex},
1092+
stable_graph::StableDiGraph,
1093+
};
1094+
1095+
#[test]
1096+
fn test_empty_graph() {
1097+
let graph: DiGraph<(), ()> = DiGraph::new();
1098+
let result: Vec<Vec<NodeIndex>> = layers(&graph, vec![]).flatten().collect();
1099+
assert_eq!(result, vec![vec![]]);
1100+
}
1101+
1102+
#[test]
1103+
fn test_empty_stable_graph() {
1104+
let graph: StableDiGraph<(), ()> = StableDiGraph::new();
1105+
let result: Vec<Vec<NodeIndex>> = layers(&graph, vec![]).flatten().collect();
1106+
assert_eq!(result, vec![vec![]]);
1107+
}
1108+
1109+
#[test]
1110+
fn test_simple_layer() {
1111+
let mut graph: DiGraph<String, ()> = DiGraph::new();
1112+
let mut nodes: Vec<NodeIndex> = Vec::new();
1113+
nodes.push(graph.add_node("a".to_string()));
1114+
for i in 0..5 {
1115+
nodes.push(graph.add_node(i.to_string()));
1116+
}
1117+
nodes.push(graph.add_node("A parent".to_string()));
1118+
for (source, target) in [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 3)] {
1119+
graph.add_edge(nodes[source], nodes[target], ());
1120+
}
1121+
let expected: Vec<Vec<NodeIndex>> = vec![
1122+
vec![0.into(), 6.into()],
1123+
vec![5.into(), 4.into(), 2.into(), 1.into(), 3.into()],
1124+
];
1125+
let result: Vec<Vec<NodeIndex>> =
1126+
layers(&graph, vec![0.into(), 6.into()]).flatten().collect();
1127+
assert_eq!(result, expected);
1128+
}
1129+
1130+
#[test]
1131+
#[should_panic]
1132+
fn test_missing_node() {
1133+
let edge_list = vec![(0, 1), (1, 2), (2, 3), (3, 4)];
1134+
let graph = DiGraph::<u32, u32>::from_edges(&edge_list);
1135+
layers(&graph, vec![4.into(), 5.into()]).for_each(|layer| match layer {
1136+
Err(e) => panic!("{}", e.0),
1137+
Ok(layer) => drop(layer),
1138+
});
1139+
}
1140+
1141+
#[test]
1142+
fn test_dag_with_multiple_paths() {
1143+
let mut graph: DiGraph<(), ()> = DiGraph::new();
1144+
let n0 = graph.add_node(());
1145+
let n1 = graph.add_node(());
1146+
let n2 = graph.add_node(());
1147+
let n3 = graph.add_node(());
1148+
let n4 = graph.add_node(());
1149+
let n5 = graph.add_node(());
1150+
graph.add_edge(n0, n1, ());
1151+
graph.add_edge(n0, n2, ());
1152+
graph.add_edge(n1, n2, ());
1153+
graph.add_edge(n1, n3, ());
1154+
graph.add_edge(n2, n3, ());
1155+
graph.add_edge(n3, n4, ());
1156+
graph.add_edge(n2, n5, ());
1157+
graph.add_edge(n4, n5, ());
1158+
1159+
let result: Vec<Vec<NodeIndex>> = layers(&graph, vec![0.into()]).flatten().collect();
1160+
assert_eq!(
1161+
result,
1162+
vec![vec![n0], vec![n1], vec![n2], vec![n3], vec![n4], vec![n5]]
1163+
);
1164+
}
1165+
1166+
#[test]
1167+
#[should_panic]
1168+
fn test_graph_with_cycle() {
1169+
let mut graph: DiGraph<(), i32> = DiGraph::new();
1170+
let n0 = graph.add_node(());
1171+
let n1 = graph.add_node(());
1172+
graph.add_edge(n0, n1, 1);
1173+
graph.add_edge(n1, n0, 1);
1174+
1175+
layers(&graph, vec![0.into()]).for_each(|layer| match layer {
1176+
Err(e) => panic!("{}", e.0),
1177+
Ok(layer) => drop(layer),
1178+
});
1179+
}
1180+
}
1181+
9341182
// Tests for collect_bicolor_runs
9351183
#[cfg(test)]
9361184
mod test_collect_bicolor_runs {

rustworkx-core/src/err.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,15 @@ fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result {
5454
fn fmt_merge_error<E: Error>(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result {
5555
write!(f, "The merge callback failed with: {:?}", inner)
5656
}
57+
58+
/// Error returned by Layers function when an index is not part of the graph.
59+
#[derive(Debug, PartialEq, Eq)]
60+
pub struct LayersError(pub String);
61+
62+
impl Error for LayersError {}
63+
64+
impl Display for LayersError {
65+
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
66+
write!(f, "{}", self.0)
67+
}
68+
}

0 commit comments

Comments
 (0)