Skip to content

Commit fdc17eb

Browse files
danielleodigieraynelfss
authored andcommitted
Adding Transitive Reduction Function (#923)
* Implementing filter_nodes and filter_edges funcs * Running fmt and clippy * Fixed issue where errors were not being propagated up to Python. Created tests for filter_edges and filter_nodes for both PyGraph and PyDiGraph. Created release notes for the functions. * Ran fmt, clippy, and tox * Fixing release notes * Fixing release notes again * Fixing release notes again again * Fixed release notes * Fixed release notes. Changed Vec allocation. Expanded on documentation. * ran cargo fmt and clippy * working on adding different parallel edge behavior * Fixing docs for filter functions * Working on graph_adjacency_matrix * Implementing changes to graph_adjacency_matrix and digraph_adjacency_matrix * working on release notes * Fixed release notes and docs * Ran cargo fmt * Ran cargo clippy * Fixed digraph_adjacency_matrix, passes tests * Removed mpl_draw from r elease notes * Changed if-else blocks in adjacency_matrix functions to match blocks. Wrote tests. * Fixed tests to pass lint * Added transitive reduction function to dag algo module * Fixed issue with graph that have nodes removed. Function now returns index_map for cases where there were nodes removed. Added tests. * Changing graph.nodes_removed to be false again. Return graph does not have removed nodes * Adding requested changes: - Fixing Docs - Fixing Maps to only have capacity of node_count - Fixing tests * Adding function to DAG algorithsm index
1 parent f6666dd commit fdc17eb

File tree

5 files changed

+206
-1
lines changed

5 files changed

+206
-1
lines changed

docs/source/api/algorithm_functions/dag_algorithms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ DAG Algorithms
1212
rustworkx.dag_weighted_longest_path_length
1313
rustworkx.is_directed_acyclic_graph
1414
rustworkx.layers
15+
rustworkx.transitive_reduction
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
---
2+
features:
3+
- |
4+
Added a new function, :func:`~.transitive_reduction` which returns the transtive reduction
5+
of a given :class:`~rustworkx.PyDiGraph` and a dictionary with the mapping of indices from the given graph to the returned graph.
6+
The given graph must be a Directed Acyclic Graph (DAG).
7+
For example:
8+
9+
.. jupyter-execute::
10+
11+
from rustworkx import PyDiGraph
12+
from rustworkx import transitive_reduction
13+
14+
graph = PyDiGraph()
15+
a = graph.add_node("a")
16+
b = graph.add_node("b")
17+
c = graph.add_node("c")
18+
d = graph.add_node("d")
19+
e = graph.add_node("e")
20+
21+
graph.add_edges_from([
22+
(a, b, 1),
23+
(a, d, 1),
24+
(a, c, 1),
25+
(a, e, 1),
26+
(b, d, 1),
27+
(c, d, 1),
28+
(c, e, 1),
29+
(d, e, 1)
30+
])
31+
32+
tr, _ = transitive_reduction(graph)
33+
list(tr.edge_list())
34+
35+
Ref: https://en.wikipedia.org/wiki/Transitive_reduction
36+

src/dag_algo/mod.rs

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@
1212

1313
mod longest_path;
1414

15+
use super::DictMap;
1516
use hashbrown::{HashMap, HashSet};
17+
use indexmap::IndexSet;
18+
use rustworkx_core::dictmap::InitWithHasher;
1619
use std::cmp::Ordering;
1720
use std::collections::BinaryHeap;
1821

1922
use super::iterators::NodeIndices;
20-
use crate::{digraph, DAGHasCycle, InvalidNode};
23+
use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph};
24+
25+
use rustworkx_core::traversal::dfs_edges;
2126

2227
use pyo3::exceptions::PyValueError;
2328
use pyo3::prelude::*;
@@ -637,3 +642,89 @@ pub fn collect_bicolor_runs(
637642

638643
Ok(block_list)
639644
}
645+
646+
/// Returns the transitive reduction of a directed acyclic graph
647+
///
648+
/// The transitive reduction of :math:`G = (V,E)` is a graph :math:`G\prime = (V,E\prime)`
649+
/// such that for all :math:`v` and :math:`w` in :math:`V` there is an edge :math:`(v, w)` in
650+
/// :math:`E\prime` if and only if :math:`(v, w)` is in :math:`E`
651+
/// and there is no path from :math:`v` to :math:`w` in :math:`G` with length greater than 1.
652+
///
653+
/// :param PyDiGraph graph: A directed acyclic graph
654+
///
655+
/// :returns: a directed acyclic graph representing the transitive reduction, and
656+
/// a map containing the index of a node in the original graph mapped to its
657+
/// equivalent in the resulting graph.
658+
/// :rtype: Tuple[PyGraph, dict]
659+
///
660+
/// :raises PyValueError: if ``graph`` is not a DAG
661+
662+
#[pyfunction]
663+
#[pyo3(text_signature = "(graph, /)")]
664+
pub fn transitive_reduction(
665+
graph: &digraph::PyDiGraph,
666+
py: Python,
667+
) -> PyResult<(digraph::PyDiGraph, DictMap<usize, usize>)> {
668+
let g = &graph.graph;
669+
let mut index_map = DictMap::with_capacity(g.node_count());
670+
if !is_directed_acyclic_graph(graph) {
671+
return Err(PyValueError::new_err(
672+
"Directed Acyclic Graph required for transitive_reduction",
673+
));
674+
}
675+
let mut tr = StablePyGraph::<Directed>::with_capacity(g.node_count(), 0);
676+
let mut descendants = DictMap::new();
677+
let mut check_count = HashMap::with_capacity(g.node_count());
678+
679+
for node in g.node_indices() {
680+
let i = node.index();
681+
index_map.insert(
682+
node,
683+
tr.add_node(graph.get_node_data(i).unwrap().clone_ref(py)),
684+
);
685+
check_count.insert(node, graph.in_degree(i));
686+
}
687+
688+
for u in g.node_indices() {
689+
let mut u_nbrs: IndexSet<NodeIndex> = g.neighbors(u).collect();
690+
for v in g.neighbors(u) {
691+
if u_nbrs.contains(&v) {
692+
if !descendants.contains_key(&v) {
693+
let dfs = dfs_edges(&g, Some(v));
694+
descendants.insert(v, dfs);
695+
}
696+
for desc in &descendants[&v] {
697+
u_nbrs.remove(&NodeIndex::new(desc.1));
698+
}
699+
}
700+
*check_count.get_mut(&v).unwrap() -= 1;
701+
if check_count[&v] == 0 {
702+
descendants.remove(&v);
703+
}
704+
}
705+
for v in u_nbrs {
706+
tr.add_edge(
707+
*index_map.get(&u).unwrap(),
708+
*index_map.get(&v).unwrap(),
709+
graph
710+
.get_edge_data(u.index(), v.index())
711+
.unwrap()
712+
.clone_ref(py),
713+
);
714+
}
715+
}
716+
return Ok((
717+
digraph::PyDiGraph {
718+
graph: tr,
719+
node_removed: false,
720+
multigraph: graph.multigraph,
721+
attrs: py.None(),
722+
cycle_state: algo::DfsSpace::default(),
723+
check_cycle: graph.check_cycle,
724+
},
725+
index_map
726+
.iter()
727+
.map(|(k, v)| (k.index(), v.index()))
728+
.collect::<DictMap<usize, usize>>(),
729+
));
730+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
359359
m.add_wrapped(wrap_pyfunction!(dag_longest_path_length))?;
360360
m.add_wrapped(wrap_pyfunction!(dag_weighted_longest_path))?;
361361
m.add_wrapped(wrap_pyfunction!(dag_weighted_longest_path_length))?;
362+
m.add_wrapped(wrap_pyfunction!(transitive_reduction))?;
362363
m.add_wrapped(wrap_pyfunction!(number_connected_components))?;
363364
m.add_wrapped(wrap_pyfunction!(connected_components))?;
364365
m.add_wrapped(wrap_pyfunction!(is_connected))?;
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
2+
# not use this file except in compliance with the License. You may obtain
3+
# a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10+
# License for the specific language governing permissions and limitations
11+
# under the License.
12+
13+
import unittest
14+
15+
import rustworkx
16+
17+
18+
class TestTransitiveReduction(unittest.TestCase):
19+
def test_tr1(self):
20+
graph = rustworkx.PyDiGraph()
21+
a = graph.add_node("a")
22+
b = graph.add_node("b")
23+
c = graph.add_node("c")
24+
d = graph.add_node("d")
25+
e = graph.add_node("e")
26+
graph.add_edges_from(
27+
[(a, b, 1), (a, d, 1), (a, c, 1), (a, e, 1), (b, d, 1), (c, d, 1), (c, e, 1), (d, e, 1)]
28+
)
29+
tr, _ = rustworkx.transitive_reduction(graph)
30+
self.assertCountEqual(list(tr.edge_list()), [(0, 2), (0, 1), (1, 3), (2, 3), (3, 4)])
31+
32+
def test_tr2(self):
33+
graph2 = rustworkx.PyDiGraph()
34+
a = graph2.add_node("a")
35+
b = graph2.add_node("b")
36+
c = graph2.add_node("c")
37+
graph2.add_edges_from(
38+
[
39+
(a, b, 1),
40+
(b, c, 1),
41+
(a, c, 1),
42+
]
43+
)
44+
tr2, _ = rustworkx.transitive_reduction(graph2)
45+
self.assertCountEqual(list(tr2.edge_list()), [(0, 1), (1, 2)])
46+
47+
def test_tr3(self):
48+
graph3 = rustworkx.PyDiGraph()
49+
graph3.add_nodes_from([0, 1, 2, 3])
50+
graph3.add_edges_from([(0, 1, 1), (0, 2, 1), (0, 3, 1), (1, 2, 1), (1, 3, 1)])
51+
tr3, _ = rustworkx.transitive_reduction(graph3)
52+
self.assertCountEqual(list(tr3.edge_list()), [(0, 1), (1, 2), (1, 3)])
53+
54+
def test_tr_with_deletion(self):
55+
graph = rustworkx.PyDiGraph()
56+
a = graph.add_node("a")
57+
b = graph.add_node("b")
58+
c = graph.add_node("c")
59+
d = graph.add_node("d")
60+
e = graph.add_node("e")
61+
62+
graph.add_edges_from(
63+
[(a, b, 1), (a, d, 1), (a, c, 1), (a, e, 1), (b, d, 1), (c, d, 1), (c, e, 1), (d, e, 1)]
64+
)
65+
66+
graph.remove_node(3)
67+
68+
tr, index_map = rustworkx.transitive_reduction(graph)
69+
70+
self.assertCountEqual(list(tr.edge_list()), [(0, 1), (0, 2), (2, 3)])
71+
self.assertEqual(index_map[4], 3)
72+
73+
def test_tr_error(self):
74+
digraph = rustworkx.generators.directed_cycle_graph(1000)
75+
with self.assertRaises(ValueError):
76+
rustworkx.transitive_reduction(digraph)

0 commit comments

Comments
 (0)