Skip to content

Commit 62181fb

Browse files
Implement condensation graph generation (#1337)
* Implement condensation tentatively * Update test_strongly_connected.py * Update src/connectivity/mod.rs thanks @IvanIsCoding Co-authored-by: Ivan Carvalho <[email protected]> * Update rustworkx/rustworkx.pyi thanks @IvanIsCoding Co-authored-by: Ivan Carvalho <[email protected]> * Update mod.rs Update mod.rs * Update mod.rs Update mod.rs Update mod.rs Update mod.rs * Update test_strongly_connected.py * Create condensation-undirected-support-apr2025.yaml * Use pyobject and bound Update mod.rs ok ok wip Update mod.rs Reformat Reformat * Update rustworkx.pyi * Update test_strongly_connected.py * Update test_strongly_connected.py * Replace MAX with None Co-authored-by: Ivan Carvalho <[email protected]> * Separate digraph and graph function Update lib.rs * Add stub definition Reformat * Update __init__.py * Replace MAX with None (2) Reformat Update __init__.py * Reformat --------- Co-authored-by: Ivan Carvalho <[email protected]>
1 parent 30f2907 commit 62181fb

File tree

7 files changed

+264
-4
lines changed

7 files changed

+264
-4
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Added a new condensation() function that works for both directed and undirected graphs. For directed graphs, it returns the condensation (quotient graph) where each node is a strongly connected component (SCC). For undirected graphs, each node is a connected component. The returned graph has a 'node_map' attribute mapping each original node index to the index of the condensed node it belongs to.

rustworkx/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,3 +2224,26 @@ def all_shortest_paths(
22242224
is provided.
22252225
"""
22262226
raise TypeError(f"Invalid Input Type {type(graph)} for graph")
2227+
2228+
2229+
@_rustworkx_dispatch
2230+
def condensation(graph, /, sccs=None):
2231+
"""Return the condensation of a directed or undirected graph
2232+
The condensation of a directed graph is a directed acyclic graph (DAG) in which
2233+
each node represents a strongly connected component (SCC) of the original graph.
2234+
The edges of the DAG represent the connections between these components.
2235+
The condensation of an undirected graph is a directed graph in which each node
2236+
represents a connected component of the original graph. The edges of the DAG
2237+
represent the connections between these components.
2238+
2239+
The condensation is computed using Tarjan's algorithm.
2240+
2241+
:param graph: The input graph to condense. This can be a
2242+
:class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`.
2243+
:param sccs: An optional list of strongly connected components (SCCs) to use.
2244+
If not specified, the function will compute the SCCs internally.
2245+
If the input graph is undirected, this parameter is ignored.
2246+
:returns: A PyGraph or PyDiGraph object representing the condensation of the input graph.
2247+
:rtype: PyGraph or PyDiGraph
2248+
"""
2249+
raise TypeError(f"Invalid Input Type {type(graph)} for graph")

rustworkx/__init__.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ from .rustworkx import number_strongly_connected_components as number_strongly_c
8686
from .rustworkx import number_weakly_connected_components as number_weakly_connected_components
8787
from .rustworkx import node_connected_component as node_connected_component
8888
from .rustworkx import strongly_connected_components as strongly_connected_components
89+
from .rustworkx import digraph_condensation as digraph_condensation
90+
from .rustworkx import graph_condensation as graph_condensation
8991
from .rustworkx import weakly_connected_components as weakly_connected_components
9092
from .rustworkx import digraph_adjacency_matrix as digraph_adjacency_matrix
9193
from .rustworkx import graph_adjacency_matrix as graph_adjacency_matrix
@@ -644,3 +646,6 @@ def longest_simple_path(graph: PyGraph[_S, _T] | PyDiGraph[_S, _T]) -> NodeIndic
644646
def isolates(graph: PyGraph[_S, _T] | PyDiGraph[_S, _T]) -> NodeIndices: ...
645647
def two_color(graph: PyGraph[_S, _T] | PyDiGraph[_S, _T]) -> dict[int, int]: ...
646648
def is_bipartite(graph: PyGraph[_S, _T] | PyDiGraph[_S, _T]) -> bool: ...
649+
def condensation(
650+
graph: PyDiGraph | PyGraph, /, sccs: list[int] | None = ...
651+
) -> PyDiGraph | PyGraph: ...

rustworkx/rustworkx.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ def number_strongly_connected_components(graph: PyDiGraph, /) -> int: ...
219219
def number_weakly_connected_components(graph: PyDiGraph, /) -> int: ...
220220
def node_connected_component(graph: PyGraph, node: int, /) -> set[int]: ...
221221
def strongly_connected_components(graph: PyDiGraph, /) -> list[list[int]]: ...
222+
def digraph_condensation(graph: PyDiGraph, /, sccs: list[int] | None = ...) -> PyDiGraph: ...
223+
def graph_condensation(graph: PyDiGraph, /) -> PyGraph: ...
222224
def weakly_connected_components(graph: PyDiGraph, /) -> list[set[int]]: ...
223225
def digraph_adjacency_matrix(
224226
graph: PyDiGraph[_S, _T],

src/connectivity/mod.rs

Lines changed: 153 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,22 @@ use super::{
2222

2323
use hashbrown::{HashMap, HashSet};
2424
use indexmap::IndexSet;
25-
use petgraph::algo;
26-
use petgraph::algo::condensation;
27-
use petgraph::graph::DiGraph;
25+
use petgraph::graph::{DiGraph, IndexType};
2826
use petgraph::stable_graph::NodeIndex;
2927
use petgraph::unionfind::UnionFind;
3028
use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeCount, NodeIndexable, Visitable};
29+
use petgraph::{algo, Graph};
3130
use pyo3::exceptions::PyValueError;
3231
use pyo3::prelude::*;
3332
use pyo3::types::PyDict;
33+
use pyo3::BoundObject;
34+
use pyo3::IntoPyObject;
3435
use pyo3::Python;
3536
use rayon::prelude::*;
3637

3738
use ndarray::prelude::*;
3839
use numpy::{IntoPyArray, PyArray2};
40+
use petgraph::prelude::StableGraph;
3941

4042
use crate::iterators::{
4143
AllPairsMultiplePathMapping, BiconnectedComponents, Chains, EdgeList, NodeIndices,
@@ -192,6 +194,153 @@ pub fn is_strongly_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
192194
Ok(algo::kosaraju_scc(&graph.graph).len() == 1)
193195
}
194196

197+
/// Compute the condensation of a graph (directed or undirected).
198+
///
199+
/// For directed graphs, this returns the condensation (quotient graph) where each node
200+
/// represents a strongly connected component (SCC) of the input graph. For undirected graphs,
201+
/// each node represents a connected component.
202+
///
203+
/// The returned graph has a node attribute 'node_map' which is a list mapping each original
204+
/// node index to the index of the condensed node it belongs to.
205+
///
206+
/// :param graph: The input graph (PyDiGraph or PyGraph)
207+
/// :param sccs: (Optional, directed only) List of SCCs to use instead of computing them
208+
/// :returns: The condensed graph (PyDiGraph or PyGraph) with a 'node_map' attribute
209+
/// :rtype: PyDiGraph or PyGraph
210+
fn condensation_inner<'py, N, E, Ty, Ix>(
211+
py: Python<'py>,
212+
g: Graph<N, E, Ty, Ix>,
213+
make_acyclic: bool,
214+
sccs: Option<Vec<Vec<usize>>>,
215+
) -> PyResult<(StablePyGraph<Ty>, Vec<Option<usize>>)>
216+
where
217+
Ty: EdgeType,
218+
Ix: IndexType,
219+
N: IntoPyObject<'py, Target = PyAny> + Clone,
220+
E: IntoPyObject<'py, Target = PyAny> + Clone,
221+
{
222+
// For directed graphs, use SCCs; for undirected, use connected components
223+
let components: Vec<Vec<NodeIndex<Ix>>> = if Ty::is_directed() {
224+
if let Some(sccs) = sccs {
225+
sccs.into_iter()
226+
.map(|row| row.into_iter().map(NodeIndex::new).collect())
227+
.collect()
228+
} else {
229+
algo::kosaraju_scc(&g)
230+
}
231+
} else {
232+
connectivity::connected_components(&g)
233+
.into_iter()
234+
.map(|set| set.into_iter().collect())
235+
.collect()
236+
};
237+
238+
// Convert all NodeIndex<Ix> to NodeIndex<usize> for the output graph
239+
let components_usize: Vec<Vec<NodeIndex<usize>>> = components
240+
.iter()
241+
.map(|comp| comp.iter().map(|ix| NodeIndex::new(ix.index())).collect())
242+
.collect();
243+
244+
let mut condensed: StableGraph<Vec<N>, E, Ty, u32> =
245+
StableGraph::with_capacity(components_usize.len(), g.edge_count());
246+
247+
// Build a map from old indices to new ones.
248+
let mut node_map = vec![None; g.node_count()];
249+
for comp in components_usize.iter() {
250+
let new_nix = condensed.add_node(Vec::new());
251+
for nix in comp {
252+
node_map[nix.index()] = Some(new_nix.index());
253+
}
254+
}
255+
256+
// Consume nodes and edges of the old graph and insert them into the new one.
257+
let (nodes, edges) = g.into_nodes_edges();
258+
for (nix, node) in nodes.into_iter().enumerate() {
259+
if let Some(Some(idx)) = node_map.get(nix).copied() {
260+
condensed[NodeIndex::new(idx)].push(node.weight);
261+
}
262+
}
263+
for edge in edges {
264+
let (source, target) = match (
265+
node_map.get(edge.source().index()),
266+
node_map.get(edge.target().index()),
267+
) {
268+
(Some(Some(s)), Some(Some(t))) => (NodeIndex::new(*s), NodeIndex::new(*t)),
269+
_ => continue,
270+
};
271+
272+
if make_acyclic && Ty::is_directed() {
273+
if source != target {
274+
condensed.update_edge(source, target, edge.weight);
275+
}
276+
} else {
277+
condensed.add_edge(source, target, edge.weight);
278+
}
279+
}
280+
281+
let mapped = condensed.map(
282+
|_, w| match w.clone().into_pyobject(py) {
283+
Ok(bound) => bound.unbind(),
284+
Err(_) => PyValueError::new_err("Node conversion failed")
285+
.into_pyobject(py)
286+
.unwrap()
287+
.unbind()
288+
.into(),
289+
},
290+
|_, w| match w.clone().into_pyobject(py) {
291+
Ok(bound) => bound.unbind(),
292+
Err(_) => PyValueError::new_err("Edge conversion failed")
293+
.into_pyobject(py)
294+
.unwrap()
295+
.unbind()
296+
.into(),
297+
},
298+
);
299+
Ok((mapped, node_map))
300+
}
301+
302+
#[pyfunction]
303+
#[pyo3(text_signature = "(graph, /, sccs=None)", signature=(graph, sccs=None))]
304+
pub fn digraph_condensation(
305+
py: Python,
306+
graph: digraph::PyDiGraph,
307+
sccs: Option<Vec<Vec<usize>>>,
308+
) -> PyResult<digraph::PyDiGraph> {
309+
let g = graph.graph.clone();
310+
let (condensed, node_map) = condensation_inner(py, g.into(), true, sccs)?;
311+
312+
let mut attrs = HashMap::new();
313+
attrs.insert("node_map", node_map.clone());
314+
315+
let result = digraph::PyDiGraph {
316+
graph: condensed,
317+
cycle_state: algo::DfsSpace::default(),
318+
check_cycle: false,
319+
node_removed: false,
320+
multigraph: true,
321+
attrs: attrs.into_pyobject(py)?.into(),
322+
};
323+
Ok(result)
324+
}
325+
326+
#[pyfunction]
327+
#[pyo3(text_signature = "(graph, /)")]
328+
pub fn graph_condensation(py: Python, graph: graph::PyGraph) -> PyResult<graph::PyGraph> {
329+
let g = graph.graph.clone();
330+
let (condensed, node_map) = condensation_inner(py, g.into(), false, None)?;
331+
332+
let mut attrs = HashMap::new();
333+
attrs.insert("node_map", node_map.clone());
334+
335+
let result = graph::PyGraph {
336+
graph: condensed,
337+
node_removed: false,
338+
multigraph: graph.multigraph,
339+
attrs: attrs.into_pyobject(py)?.into(),
340+
};
341+
Ok(result)
342+
}
343+
195344
/// Return the first cycle encountered during DFS of a given PyDiGraph,
196345
/// empty list is returned if no cycle is found
197346
///
@@ -480,7 +629,7 @@ pub fn is_semi_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
480629
temp_graph.add_edge(node_map[source.index()], node_map[target.index()], ());
481630
}
482631

483-
let condensed = condensation(temp_graph, true);
632+
let condensed = algo::condensation(temp_graph, true);
484633
let n = condensed.node_count();
485634
let weight_fn =
486635
|_: petgraph::graph::EdgeReference<()>| Ok::<usize, std::convert::Infallible>(1usize);

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,8 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
615615
m.add_wrapped(wrap_pyfunction!(number_strongly_connected_components))?;
616616
m.add_wrapped(wrap_pyfunction!(strongly_connected_components))?;
617617
m.add_wrapped(wrap_pyfunction!(is_strongly_connected))?;
618+
m.add_wrapped(wrap_pyfunction!(digraph_condensation))?;
619+
m.add_wrapped(wrap_pyfunction!(graph_condensation))?;
618620
m.add_wrapped(wrap_pyfunction!(digraph_dfs_edges))?;
619621
m.add_wrapped(wrap_pyfunction!(graph_dfs_edges))?;
620622
m.add_wrapped(wrap_pyfunction!(digraph_find_cycle))?;

tests/digraph/test_strongly_connected.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,78 @@ def test_is_strongly_connected_null_graph(self):
100100
graph = rustworkx.PyDiGraph()
101101
with self.assertRaises(rustworkx.NullGraph):
102102
rustworkx.is_strongly_connected(graph)
103+
104+
105+
class TestCondensation(unittest.TestCase):
106+
def setUp(self):
107+
# Set up the graph
108+
self.graph = rustworkx.PyDiGraph()
109+
self.node_a = self.graph.add_node("a")
110+
self.node_b = self.graph.add_node("b")
111+
self.node_c = self.graph.add_node("c")
112+
self.node_d = self.graph.add_node("d")
113+
self.node_e = self.graph.add_node("e")
114+
self.node_f = self.graph.add_node("f")
115+
self.node_g = self.graph.add_node("g")
116+
self.node_h = self.graph.add_node("h")
117+
118+
# Add edges
119+
self.graph.add_edge(self.node_a, self.node_b, "a->b")
120+
self.graph.add_edge(self.node_b, self.node_c, "b->c")
121+
self.graph.add_edge(self.node_c, self.node_d, "c->d")
122+
self.graph.add_edge(self.node_d, self.node_a, "d->a") # Cycle: a -> b -> c -> d -> a
123+
124+
self.graph.add_edge(self.node_b, self.node_e, "b->e")
125+
126+
self.graph.add_edge(self.node_e, self.node_f, "e->f")
127+
self.graph.add_edge(self.node_f, self.node_g, "f->g")
128+
self.graph.add_edge(self.node_g, self.node_h, "g->h")
129+
self.graph.add_edge(self.node_h, self.node_e, "h->e") # Cycle: e -> f -> g -> h -> e
130+
131+
def test_condensation(self):
132+
# Call the condensation function
133+
condensed_graph = rustworkx.condensation(self.graph)
134+
135+
# Check the number of nodes (two cycles should be condensed into one node each)
136+
self.assertEqual(
137+
len(condensed_graph.node_indices()), 2
138+
) # [SCC(a, b, c, d), SCC(e, f, g, h)]
139+
140+
# Check the number of edges
141+
self.assertEqual(
142+
len(condensed_graph.edge_indices()), 1
143+
) # Edge: [SCC(a, b, c, d)] -> [SCC(e, f, g, h)]
144+
145+
# Check the contents of the condensed nodes
146+
nodes = list(condensed_graph.nodes())
147+
scc1 = nodes[0]
148+
scc2 = nodes[1]
149+
self.assertTrue(set(scc1) == {"a", "b", "c", "d"} or set(scc2) == {"a", "b", "c", "d"})
150+
self.assertTrue(set(scc1) == {"e", "f", "g", "h"} or set(scc2) == {"e", "f", "g", "h"})
151+
152+
# Check the contents of the edge
153+
weight = condensed_graph.edges()[0]
154+
self.assertIn("b->e", weight) # Ensure the correct edge remains in the condensed graph
155+
156+
def test_condensation_with_sccs_argument(self):
157+
# Compute SCCs manually
158+
sccs = rustworkx.strongly_connected_components(self.graph)
159+
# Call condensation with explicit sccs argument
160+
condensed_graph = rustworkx.condensation(self.graph, sccs=sccs)
161+
condensed_graph.attrs["node_map"]
162+
163+
# Check the number of nodes (should match SCC count)
164+
self.assertEqual(len(condensed_graph.node_indices()), len(sccs))
165+
166+
# Check the number of edges
167+
self.assertEqual(len(condensed_graph.edge_indices()), 1)
168+
169+
# Check the contents of the condensed nodes
170+
nodes = list(condensed_graph.nodes())
171+
scc_sets = [set(n) for n in nodes]
172+
self.assertIn(set(["a", "b", "c", "d"]), scc_sets)
173+
self.assertIn(set(["e", "f", "g", "h"]), scc_sets)
174+
175+
# Check the contents of the edge
176+
weight = condensed_graph.edges()[0]
177+
self.assertIn("b->e", weight)

0 commit comments

Comments
 (0)