@@ -22,20 +22,22 @@ use super::{
2222
2323use hashbrown:: { HashMap , HashSet } ;
2424use indexmap:: IndexSet ;
25- use petgraph:: algo;
26- use petgraph:: algo:: condensation;
27- use petgraph:: graph:: DiGraph ;
25+ use petgraph:: graph:: { DiGraph , IndexType } ;
2826use petgraph:: stable_graph:: NodeIndex ;
2927use petgraph:: unionfind:: UnionFind ;
3028use petgraph:: visit:: { EdgeRef , IntoEdgeReferences , NodeCount , NodeIndexable , Visitable } ;
29+ use petgraph:: { algo, Graph } ;
3130use pyo3:: exceptions:: PyValueError ;
3231use pyo3:: prelude:: * ;
3332use pyo3:: types:: PyDict ;
33+ use pyo3:: BoundObject ;
34+ use pyo3:: IntoPyObject ;
3435use pyo3:: Python ;
3536use rayon:: prelude:: * ;
3637
3738use ndarray:: prelude:: * ;
3839use numpy:: { IntoPyArray , PyArray2 } ;
40+ use petgraph:: prelude:: StableGraph ;
3941
4042use crate :: iterators:: {
4143 AllPairsMultiplePathMapping , BiconnectedComponents , Chains , EdgeList , NodeIndices ,
@@ -192,6 +194,153 @@ pub fn is_strongly_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
192194 Ok ( algo:: kosaraju_scc ( & graph. graph ) . len ( ) == 1 )
193195}
194196
197+ /// Compute the condensation of a graph (directed or undirected).
198+ ///
199+ /// For directed graphs, this returns the condensation (quotient graph) where each node
200+ /// represents a strongly connected component (SCC) of the input graph. For undirected graphs,
201+ /// each node represents a connected component.
202+ ///
203+ /// The returned graph has a node attribute 'node_map' which is a list mapping each original
204+ /// node index to the index of the condensed node it belongs to.
205+ ///
206+ /// :param graph: The input graph (PyDiGraph or PyGraph)
207+ /// :param sccs: (Optional, directed only) List of SCCs to use instead of computing them
208+ /// :returns: The condensed graph (PyDiGraph or PyGraph) with a 'node_map' attribute
209+ /// :rtype: PyDiGraph or PyGraph
210+ fn condensation_inner < ' py , N , E , Ty , Ix > (
211+ py : Python < ' py > ,
212+ g : Graph < N , E , Ty , Ix > ,
213+ make_acyclic : bool ,
214+ sccs : Option < Vec < Vec < usize > > > ,
215+ ) -> PyResult < ( StablePyGraph < Ty > , Vec < Option < usize > > ) >
216+ where
217+ Ty : EdgeType ,
218+ Ix : IndexType ,
219+ N : IntoPyObject < ' py , Target = PyAny > + Clone ,
220+ E : IntoPyObject < ' py , Target = PyAny > + Clone ,
221+ {
222+ // For directed graphs, use SCCs; for undirected, use connected components
223+ let components: Vec < Vec < NodeIndex < Ix > > > = if Ty :: is_directed ( ) {
224+ if let Some ( sccs) = sccs {
225+ sccs. into_iter ( )
226+ . map ( |row| row. into_iter ( ) . map ( NodeIndex :: new) . collect ( ) )
227+ . collect ( )
228+ } else {
229+ algo:: kosaraju_scc ( & g)
230+ }
231+ } else {
232+ connectivity:: connected_components ( & g)
233+ . into_iter ( )
234+ . map ( |set| set. into_iter ( ) . collect ( ) )
235+ . collect ( )
236+ } ;
237+
238+ // Convert all NodeIndex<Ix> to NodeIndex<usize> for the output graph
239+ let components_usize: Vec < Vec < NodeIndex < usize > > > = components
240+ . iter ( )
241+ . map ( |comp| comp. iter ( ) . map ( |ix| NodeIndex :: new ( ix. index ( ) ) ) . collect ( ) )
242+ . collect ( ) ;
243+
244+ let mut condensed: StableGraph < Vec < N > , E , Ty , u32 > =
245+ StableGraph :: with_capacity ( components_usize. len ( ) , g. edge_count ( ) ) ;
246+
247+ // Build a map from old indices to new ones.
248+ let mut node_map = vec ! [ None ; g. node_count( ) ] ;
249+ for comp in components_usize. iter ( ) {
250+ let new_nix = condensed. add_node ( Vec :: new ( ) ) ;
251+ for nix in comp {
252+ node_map[ nix. index ( ) ] = Some ( new_nix. index ( ) ) ;
253+ }
254+ }
255+
256+ // Consume nodes and edges of the old graph and insert them into the new one.
257+ let ( nodes, edges) = g. into_nodes_edges ( ) ;
258+ for ( nix, node) in nodes. into_iter ( ) . enumerate ( ) {
259+ if let Some ( Some ( idx) ) = node_map. get ( nix) . copied ( ) {
260+ condensed[ NodeIndex :: new ( idx) ] . push ( node. weight ) ;
261+ }
262+ }
263+ for edge in edges {
264+ let ( source, target) = match (
265+ node_map. get ( edge. source ( ) . index ( ) ) ,
266+ node_map. get ( edge. target ( ) . index ( ) ) ,
267+ ) {
268+ ( Some ( Some ( s) ) , Some ( Some ( t) ) ) => ( NodeIndex :: new ( * s) , NodeIndex :: new ( * t) ) ,
269+ _ => continue ,
270+ } ;
271+
272+ if make_acyclic && Ty :: is_directed ( ) {
273+ if source != target {
274+ condensed. update_edge ( source, target, edge. weight ) ;
275+ }
276+ } else {
277+ condensed. add_edge ( source, target, edge. weight ) ;
278+ }
279+ }
280+
281+ let mapped = condensed. map (
282+ |_, w| match w. clone ( ) . into_pyobject ( py) {
283+ Ok ( bound) => bound. unbind ( ) ,
284+ Err ( _) => PyValueError :: new_err ( "Node conversion failed" )
285+ . into_pyobject ( py)
286+ . unwrap ( )
287+ . unbind ( )
288+ . into ( ) ,
289+ } ,
290+ |_, w| match w. clone ( ) . into_pyobject ( py) {
291+ Ok ( bound) => bound. unbind ( ) ,
292+ Err ( _) => PyValueError :: new_err ( "Edge conversion failed" )
293+ . into_pyobject ( py)
294+ . unwrap ( )
295+ . unbind ( )
296+ . into ( ) ,
297+ } ,
298+ ) ;
299+ Ok ( ( mapped, node_map) )
300+ }
301+
302+ #[ pyfunction]
303+ #[ pyo3( text_signature = "(graph, /, sccs=None)" , signature=( graph, sccs=None ) ) ]
304+ pub fn digraph_condensation (
305+ py : Python ,
306+ graph : digraph:: PyDiGraph ,
307+ sccs : Option < Vec < Vec < usize > > > ,
308+ ) -> PyResult < digraph:: PyDiGraph > {
309+ let g = graph. graph . clone ( ) ;
310+ let ( condensed, node_map) = condensation_inner ( py, g. into ( ) , true , sccs) ?;
311+
312+ let mut attrs = HashMap :: new ( ) ;
313+ attrs. insert ( "node_map" , node_map. clone ( ) ) ;
314+
315+ let result = digraph:: PyDiGraph {
316+ graph : condensed,
317+ cycle_state : algo:: DfsSpace :: default ( ) ,
318+ check_cycle : false ,
319+ node_removed : false ,
320+ multigraph : true ,
321+ attrs : attrs. into_pyobject ( py) ?. into ( ) ,
322+ } ;
323+ Ok ( result)
324+ }
325+
326+ #[ pyfunction]
327+ #[ pyo3( text_signature = "(graph, /)" ) ]
328+ pub fn graph_condensation ( py : Python , graph : graph:: PyGraph ) -> PyResult < graph:: PyGraph > {
329+ let g = graph. graph . clone ( ) ;
330+ let ( condensed, node_map) = condensation_inner ( py, g. into ( ) , false , None ) ?;
331+
332+ let mut attrs = HashMap :: new ( ) ;
333+ attrs. insert ( "node_map" , node_map. clone ( ) ) ;
334+
335+ let result = graph:: PyGraph {
336+ graph : condensed,
337+ node_removed : false ,
338+ multigraph : graph. multigraph ,
339+ attrs : attrs. into_pyobject ( py) ?. into ( ) ,
340+ } ;
341+ Ok ( result)
342+ }
343+
195344/// Return the first cycle encountered during DFS of a given PyDiGraph,
196345/// empty list is returned if no cycle is found
197346///
@@ -480,7 +629,7 @@ pub fn is_semi_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
480629 temp_graph. add_edge ( node_map[ source. index ( ) ] , node_map[ target. index ( ) ] , ( ) ) ;
481630 }
482631
483- let condensed = condensation ( temp_graph, true ) ;
632+ let condensed = algo :: condensation ( temp_graph, true ) ;
484633 let n = condensed. node_count ( ) ;
485634 let weight_fn =
486635 |_: petgraph:: graph:: EdgeReference < ( ) > | Ok :: < usize , std:: convert:: Infallible > ( 1usize ) ;
0 commit comments