@@ -950,13 +950,37 @@ pub fn local_complement(
950950 Ok ( complement_graph)
951951}
952952
953+ /// Represents target nodes for path-finding operations.
954+ ///
955+ /// This enum allows functions to accept either a single target node
956+ /// or multiple target nodes, providing flexibility in path-finding algorithms.
957+ pub enum TargetNodes {
958+ Single ( NodeIndex ) ,
959+ Multiple ( HashSet < NodeIndex > ) ,
960+ }
961+
962+ impl < ' py > FromPyObject < ' py > for TargetNodes {
963+ fn extract_bound ( ob : & Bound < ' py , PyAny > ) -> PyResult < Self > {
964+ if let Ok ( int) = ob. extract :: < usize > ( ) {
965+ Ok ( Self :: Single ( NodeIndex :: new ( int) ) )
966+ } else {
967+ let mut target_set: HashSet < NodeIndex > = HashSet :: new ( ) ;
968+ for target in ob. try_iter ( ) ? {
969+ let target_index = NodeIndex :: new ( target?. extract :: < usize > ( ) ?) ;
970+ target_set. insert ( target_index) ;
971+ }
972+ Ok ( Self :: Multiple ( target_set) )
973+ }
974+ }
975+ }
976+
953977/// Return all simple paths between 2 nodes in a PyGraph object
954978///
955979/// A simple path is a path with no repeated nodes.
956980///
957981/// :param PyGraph graph: The graph to find the path in
958982/// :param int origin: The node index to find the paths from
959- /// :param int to: The node index to find the paths to
983+ /// :param int | iterable[int] to: The node index(es) to find the paths to
960984/// :param int min_depth: The minimum depth of the path to include in the output
961985/// list of paths. By default all paths are included regardless of depth,
962986/// setting to 0 will behave like the default.
@@ -971,7 +995,7 @@ pub fn local_complement(
971995pub fn graph_all_simple_paths (
972996 graph : & graph:: PyGraph ,
973997 origin : usize ,
974- to : usize ,
998+ to : TargetNodes ,
975999 min_depth : Option < usize > ,
9761000 cutoff : Option < usize > ,
9771001) -> PyResult < Vec < Vec < usize > > > {
@@ -981,27 +1005,48 @@ pub fn graph_all_simple_paths(
9811005 "The input index for 'from' is not a valid node index" ,
9821006 ) ) ;
9831007 }
984- let to_index = NodeIndex :: new ( to) ;
985- if !graph. graph . contains_node ( to_index) {
986- return Err ( InvalidNode :: new_err (
987- "The input index for 'to' is not a valid node index" ,
988- ) ) ;
989- }
9901008 let min_intermediate_nodes: usize = match min_depth {
9911009 Some ( 0 ) | None => 0 ,
9921010 Some ( depth) => depth - 2 ,
9931011 } ;
9941012 let cutoff_petgraph: Option < usize > = cutoff. map ( |depth| depth - 2 ) ;
995- let result: Vec < Vec < usize > > = algo:: all_simple_paths :: < Vec < _ > , _ , foldhash:: fast:: RandomState > (
996- & graph. graph ,
997- from_index,
998- to_index,
999- min_intermediate_nodes,
1000- cutoff_petgraph,
1001- )
1002- . map ( |v : Vec < NodeIndex > | v. into_iter ( ) . map ( |i| i. index ( ) ) . collect ( ) )
1003- . collect ( ) ;
1004- Ok ( result)
1013+
1014+ match to {
1015+ TargetNodes :: Single ( to_index) => {
1016+ if !graph. graph . contains_node ( to_index) {
1017+ return Err ( InvalidNode :: new_err (
1018+ "The input index for 'to' is not a valid node index" ,
1019+ ) ) ;
1020+ }
1021+
1022+ let result: Vec < Vec < usize > > =
1023+ algo:: all_simple_paths :: < Vec < _ > , _ , foldhash:: fast:: RandomState > (
1024+ & graph. graph ,
1025+ from_index,
1026+ to_index,
1027+ min_intermediate_nodes,
1028+ cutoff_petgraph,
1029+ )
1030+ . map ( |v : Vec < NodeIndex > | v. into_iter ( ) . map ( |i| i. index ( ) ) . collect ( ) )
1031+ . collect ( ) ;
1032+ Ok ( result)
1033+ }
1034+ TargetNodes :: Multiple ( target_set) => {
1035+ let result = connectivity:: all_simple_paths_multiple_targets (
1036+ & graph. graph ,
1037+ from_index,
1038+ & target_set,
1039+ min_intermediate_nodes,
1040+ cutoff_petgraph,
1041+ ) ;
1042+
1043+ Ok ( result
1044+ . into_values ( )
1045+ . flatten ( )
1046+ . map ( |path| path. into_iter ( ) . map ( |node| node. index ( ) ) . collect ( ) )
1047+ . collect ( ) )
1048+ }
1049+ }
10051050}
10061051
10071052/// Return all simple paths between 2 nodes in a PyDiGraph object
@@ -1010,7 +1055,7 @@ pub fn graph_all_simple_paths(
10101055///
10111056/// :param PyDiGraph graph: The graph to find the path in
10121057/// :param int origin: The node index to find the paths from
1013- /// :param int to: The node index to find the paths to
1058+ /// :param int | iterable[int] to: The node index(es) to find the paths to
10141059/// :param int min_depth: The minimum depth of the path to include in the output
10151060/// list of paths. By default all paths are included regardless of depth,
10161061/// setting to 0 will behave like the default.
@@ -1025,7 +1070,7 @@ pub fn graph_all_simple_paths(
10251070pub fn digraph_all_simple_paths (
10261071 graph : & digraph:: PyDiGraph ,
10271072 origin : usize ,
1028- to : usize ,
1073+ to : TargetNodes ,
10291074 min_depth : Option < usize > ,
10301075 cutoff : Option < usize > ,
10311076) -> PyResult < Vec < Vec < usize > > > {
@@ -1035,27 +1080,48 @@ pub fn digraph_all_simple_paths(
10351080 "The input index for 'from' is not a valid node index" ,
10361081 ) ) ;
10371082 }
1038- let to_index = NodeIndex :: new ( to) ;
1039- if !graph. graph . contains_node ( to_index) {
1040- return Err ( InvalidNode :: new_err (
1041- "The input index for 'to' is not a valid node index" ,
1042- ) ) ;
1043- }
10441083 let min_intermediate_nodes: usize = match min_depth {
10451084 Some ( 0 ) | None => 0 ,
10461085 Some ( depth) => depth - 2 ,
10471086 } ;
10481087 let cutoff_petgraph: Option < usize > = cutoff. map ( |depth| depth - 2 ) ;
1049- let result: Vec < Vec < usize > > = algo:: all_simple_paths :: < Vec < _ > , _ , foldhash:: fast:: RandomState > (
1050- & graph. graph ,
1051- from_index,
1052- to_index,
1053- min_intermediate_nodes,
1054- cutoff_petgraph,
1055- )
1056- . map ( |v : Vec < NodeIndex > | v. into_iter ( ) . map ( |i| i. index ( ) ) . collect ( ) )
1057- . collect ( ) ;
1058- Ok ( result)
1088+
1089+ match to {
1090+ TargetNodes :: Single ( to_index) => {
1091+ if !graph. graph . contains_node ( to_index) {
1092+ return Err ( InvalidNode :: new_err (
1093+ "The input index for 'to' is not a valid node index" ,
1094+ ) ) ;
1095+ }
1096+
1097+ let result: Vec < Vec < usize > > =
1098+ algo:: all_simple_paths :: < Vec < _ > , _ , foldhash:: fast:: RandomState > (
1099+ & graph. graph ,
1100+ from_index,
1101+ to_index,
1102+ min_intermediate_nodes,
1103+ cutoff_petgraph,
1104+ )
1105+ . map ( |v : Vec < NodeIndex > | v. into_iter ( ) . map ( |i| i. index ( ) ) . collect ( ) )
1106+ . collect ( ) ;
1107+ Ok ( result)
1108+ }
1109+ TargetNodes :: Multiple ( target_set) => {
1110+ let result = connectivity:: all_simple_paths_multiple_targets (
1111+ & graph. graph ,
1112+ from_index,
1113+ & target_set,
1114+ min_intermediate_nodes,
1115+ cutoff_petgraph,
1116+ ) ;
1117+
1118+ Ok ( result
1119+ . into_values ( )
1120+ . flatten ( )
1121+ . map ( |path| path. into_iter ( ) . map ( |node| node. index ( ) ) . collect ( ) )
1122+ . collect ( ) )
1123+ }
1124+ }
10591125}
10601126
10611127/// Return all the simple paths between all pairs of nodes in the graph
0 commit comments