Skip to content

Commit 8001f6e

Browse files
committed
modify the api layer, type annotations, the binder and the rust core
1 parent f2238fd commit 8001f6e

File tree

8 files changed

+183
-39
lines changed

8 files changed

+183
-39
lines changed

rustworkx-core/src/connectivity/minimum_cycle_basis.rs renamed to rustworkx-core/src/connectivity/minimal_cycle_basis.rs

Lines changed: 88 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use crate::connectivity::conn_components::connected_components;
22
use crate::dictmap::*;
3-
use crate::shortest_path::dijkstra;
3+
use crate::shortest_path::{astar, dijkstra};
44
use crate::Result;
55
use hashbrown::{HashMap, HashSet};
6-
use petgraph::algo::{astar, min_spanning_tree, Measure};
6+
use petgraph::algo::{min_spanning_tree, Measure};
77
use petgraph::csr::{DefaultIx, IndexType};
88
use petgraph::data::{DataMap, Element};
99
use petgraph::graph::Graph;
@@ -13,6 +13,7 @@ use petgraph::visit::{
1313
IntoNeighborsDirected, IntoNodeIdentifiers, IntoNodeReferences, NodeIndexable, Visitable,
1414
};
1515
use petgraph::Undirected;
16+
use std::cmp::Ordering;
1617
use std::convert::Infallible;
1718
use std::hash::Hash;
1819

@@ -39,7 +40,7 @@ where
3940
G::NodeId: Eq + Hash,
4041
G::EdgeWeight: Clone,
4142
F: FnMut(G::EdgeRef) -> Result<K, E>,
42-
K: Clone + PartialOrd + Copy + Measure + Default + Ord,
43+
K: Clone + PartialOrd + Copy + Measure + Default,
4344
{
4445
components
4546
.into_iter()
@@ -77,7 +78,7 @@ where
7778
})
7879
.collect()
7980
}
80-
pub fn minimum_cycle_basis<G, F, K, E>(graph: G, mut weight_fn: F) -> Result<Vec<Vec<NodeIndex>>, E>
81+
pub fn minimal_cycle_basis<G, F, K, E>(graph: G, mut weight_fn: F) -> Result<Vec<Vec<NodeIndex>>, E>
8182
where
8283
G: EdgeCount
8384
+ IntoNodeIdentifiers
@@ -88,10 +89,10 @@ where
8889
+ IntoNeighborsDirected
8990
+ Visitable
9091
+ IntoEdges,
91-
G::EdgeWeight: Clone + PartialOrd,
92+
G::EdgeWeight: Clone,
9293
G::NodeId: Eq + Hash,
9394
F: FnMut(G::EdgeRef) -> Result<K, E>,
94-
K: Clone + PartialOrd + Copy + Measure + Default + Ord,
95+
K: Clone + PartialOrd + Copy + Measure + Default,
9596
{
9697
let conn_components = connected_components(&graph);
9798
let mut min_cycle_basis = Vec::new();
@@ -136,7 +137,7 @@ where
136137
H::EdgeWeight: Clone + PartialOrd,
137138
H::NodeId: Eq + Hash,
138139
F: FnMut(H::EdgeRef) -> Result<K, E>,
139-
K: Clone + PartialOrd + Copy + Measure + Default + Ord,
140+
K: Clone + PartialOrd + Copy + Measure + Default,
140141
{
141142
let mut sub_cb: Vec<Vec<usize>> = Vec::new();
142143
let num_edges = subgraph.edge_count();
@@ -243,7 +244,7 @@ where
243244
H: IntoNodeReferences + IntoEdgeReferences + DataMap + NodeIndexable + EdgeIndexable,
244245
H::NodeId: Eq + Hash,
245246
F: FnMut(H::EdgeRef) -> Result<K, E>,
246-
K: Clone + PartialOrd + Copy + Measure + Default + Ord,
247+
K: Clone + PartialOrd + Copy + Measure + Default,
247248
{
248249
let mut gi = Graph::<_, _, petgraph::Undirected>::default();
249250
let mut subgraph_gi_map = HashMap::new();
@@ -290,31 +291,36 @@ where
290291
|edge| Ok(*edge.weight()),
291292
None,
292293
);
293-
// Find the shortest distance in the result and store it in the shortest_path_map
294294
let spl = result.unwrap()[&gi_lifted_nodeidx];
295295
shortest_path_map.insert(subnodeid, spl);
296296
}
297-
let min_start = shortest_path_map.iter().min_by_key(|x| x.1).unwrap().0;
297+
let min_start = shortest_path_map
298+
.iter()
299+
.min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
300+
.unwrap()
301+
.0;
298302
let min_start_node = subgraph_gi_map[min_start].0;
299303
let min_start_lifted_node = subgraph_gi_map[min_start].1;
300-
let result = astar(
304+
let result: Result<Option<(K, Vec<NodeIndex>)>> = astar(
301305
&gi,
302-
min_start_node,
303-
|finish| finish == min_start_lifted_node,
304-
|e| *e.weight(),
305-
|_| K::default(),
306+
min_start_node.clone(),
307+
|finish| Ok(finish == min_start_lifted_node.clone()),
308+
|e| Ok(*e.weight()),
309+
|_| Ok(K::default()),
306310
);
311+
307312
let mut min_path: Vec<usize> = Vec::new();
308313
match result {
309-
Some((_cost, path)) => {
314+
Ok(Some((_cost, path))) => {
310315
for node in path {
311316
if let Some(&subgraph_nodeid) = gi_subgraph_map.get(&node) {
312317
let subgraph_node = NodeIndexable::to_index(&subgraph, subgraph_nodeid);
313318
min_path.push(subgraph_node.index());
314319
}
315320
}
316321
}
317-
None => {}
322+
Ok(None) => {}
323+
Err(_) => {}
318324
}
319325
let edgelist = min_path
320326
.windows(2)
@@ -344,9 +350,9 @@ where
344350
}
345351

346352
#[cfg(test)]
347-
mod test_minimum_cycle_basis {
348-
use crate::connectivity::minimum_cycle_basis::minimum_cycle_basis;
349-
use petgraph::graph::Graph;
353+
mod test_minimal_cycle_basis {
354+
use crate::connectivity::minimal_cycle_basis::minimal_cycle_basis;
355+
use petgraph::graph::{Graph, NodeIndex};
350356
use petgraph::Undirected;
351357
use std::convert::Infallible;
352358

@@ -356,7 +362,7 @@ mod test_minimum_cycle_basis {
356362
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
357363
Ok(*edge.weight())
358364
};
359-
let output = minimum_cycle_basis(&graph, weight_fn).unwrap();
365+
let output = minimal_cycle_basis(&graph, weight_fn).unwrap();
360366
assert_eq!(output.len(), 0);
361367
}
362368

@@ -372,8 +378,7 @@ mod test_minimum_cycle_basis {
372378
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
373379
Ok(*edge.weight())
374380
};
375-
let cycles = minimum_cycle_basis(&graph, weight_fn);
376-
println!("Cycles {:?}", cycles.as_ref().unwrap());
381+
let cycles = minimal_cycle_basis(&graph, weight_fn);
377382
assert_eq!(cycles.unwrap().len(), 1);
378383
}
379384

@@ -393,10 +398,60 @@ mod test_minimum_cycle_basis {
393398
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
394399
Ok(*edge.weight())
395400
};
396-
let cycles = minimum_cycle_basis(&graph, weight_fn);
401+
let cycles = minimal_cycle_basis(&graph, weight_fn);
397402
assert_eq!(cycles.unwrap().len(), 2);
398403
}
399404

405+
#[test]
406+
fn test_non_trivial_graph() {
407+
let mut g = Graph::<&str, i32, Undirected>::new_undirected();
408+
let a = g.add_node("A");
409+
let b = g.add_node("B");
410+
let c = g.add_node("C");
411+
let d = g.add_node("D");
412+
let e = g.add_node("E");
413+
let f = g.add_node("F");
414+
415+
g.add_edge(a, b, 7);
416+
g.add_edge(c, a, 9);
417+
g.add_edge(a, d, 11);
418+
g.add_edge(b, c, 10);
419+
g.add_edge(d, c, 2);
420+
g.add_edge(d, e, 9);
421+
g.add_edge(b, f, 15);
422+
g.add_edge(c, f, 11);
423+
g.add_edge(e, f, 6);
424+
425+
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
426+
Ok(*edge.weight())
427+
};
428+
let output = minimal_cycle_basis(&g, weight_fn);
429+
let mut actual_output = output.unwrap();
430+
for cycle in &mut actual_output {
431+
cycle.sort();
432+
}
433+
actual_output.sort();
434+
435+
let expected_output: Vec<Vec<NodeIndex>> = vec![
436+
vec![
437+
NodeIndex::new(5),
438+
NodeIndex::new(2),
439+
NodeIndex::new(3),
440+
NodeIndex::new(4),
441+
],
442+
vec![NodeIndex::new(2), NodeIndex::new(5), NodeIndex::new(1)],
443+
vec![NodeIndex::new(0), NodeIndex::new(2), NodeIndex::new(1)],
444+
vec![NodeIndex::new(2), NodeIndex::new(3), NodeIndex::new(0)],
445+
];
446+
let mut sorted_expected_output = expected_output.clone();
447+
for cycle in &mut sorted_expected_output {
448+
cycle.sort();
449+
}
450+
sorted_expected_output.sort();
451+
452+
assert_eq!(actual_output, sorted_expected_output);
453+
}
454+
400455
#[test]
401456
fn test_weighted_diamond_graph() {
402457
let mut weighted_diamond = Graph::<(), i32, Undirected>::new_undirected();
@@ -412,20 +467,19 @@ mod test_minimum_cycle_basis {
412467
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
413468
Ok(*edge.weight())
414469
};
415-
let output = minimum_cycle_basis(&weighted_diamond, weight_fn);
416-
let expected_output: Vec<Vec<usize>> = vec![vec![0, 1, 3], vec![0, 1, 2, 3]];
470+
let output = minimal_cycle_basis(&weighted_diamond, weight_fn);
471+
let expected_output1: Vec<Vec<usize>> = vec![vec![0, 1, 3], vec![0, 1, 2, 3]];
472+
let expected_output2: Vec<Vec<usize>> = vec![vec![1, 2, 3], vec![0, 1, 2, 3]];
417473
for cycle in output.unwrap().iter() {
418-
println!("{:?}", cycle);
419474
let mut node_indices: Vec<usize> = Vec::new();
420475
for node in cycle.iter() {
421476
node_indices.push(node.index());
422477
}
423478
node_indices.sort();
424-
println!("Node indices {:?}", node_indices);
425-
if expected_output.contains(&node_indices) {
426-
println!("Found cycle {:?}", node_indices);
427-
}
428-
assert!(expected_output.contains(&node_indices));
479+
assert!(
480+
expected_output1.contains(&node_indices)
481+
|| expected_output2.contains(&node_indices)
482+
);
429483
}
430484
}
431485

@@ -444,7 +498,7 @@ mod test_minimum_cycle_basis {
444498
let weight_fn =
445499
|_edge: petgraph::graph::EdgeReference<()>| -> Result<i32, Infallible> { Ok(1) };
446500

447-
let output = minimum_cycle_basis(&unweighted_diamond, weight_fn);
501+
let output = minimal_cycle_basis(&unweighted_diamond, weight_fn);
448502
let expected_output: Vec<Vec<usize>> = vec![vec![0, 1, 3], vec![1, 2, 3]];
449503
for cycle in output.unwrap().iter() {
450504
let mut node_indices: Vec<usize> = Vec::new();
@@ -476,7 +530,7 @@ mod test_minimum_cycle_basis {
476530
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
477531
Ok(*edge.weight())
478532
};
479-
let output = minimum_cycle_basis(&complete_graph, weight_fn);
533+
let output = minimal_cycle_basis(&complete_graph, weight_fn);
480534
for cycle in output.unwrap().iter() {
481535
assert_eq!(cycle.len(), 3);
482536
}

rustworkx-core/src/connectivity/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ mod cycle_basis;
2121
mod find_cycle;
2222
mod isolates;
2323
mod min_cut;
24-
mod minimum_cycle_basis;
24+
mod minimal_cycle_basis;
2525

2626
pub use all_simple_paths::{
2727
all_simple_paths_multiple_targets, longest_simple_path_multiple_targets,
@@ -37,4 +37,4 @@ pub use cycle_basis::cycle_basis;
3737
pub use find_cycle::find_cycle;
3838
pub use isolates::isolates;
3939
pub use min_cut::stoer_wagner_min_cut;
40-
pub use minimum_cycle_basis::minimum_cycle_basis;
40+
pub use minimal_cycle_basis::minimal_cycle_basis;

rustworkx/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,32 @@ def all_pairs_dijkstra_path_lengths(graph, edge_cost_fn):
557557
raise TypeError("Invalid Input Type %s for graph" % type(graph))
558558

559559

560+
@_rustworkx_dispatch
561+
def minimum_cycle_basis(graph, edge_cost_fn):
562+
"""Find the minimum cycle basis of a graph.
563+
564+
This function will find the minimum cycle basis of a graph based on the
565+
following papers
566+
References:
567+
[1] Kavitha, Telikepalli, et al. "An O(m^2n) Algorithm for
568+
Minimum Cycle Basis of Graphs."
569+
http://link.springer.com/article/10.1007/s00453-007-9064-z
570+
[2] de Pina, J. 1995. Applications of shortest path methods.
571+
Ph.D. thesis, University of Amsterdam, Netherlands
572+
573+
:param graph: The input graph to use. Can either be a
574+
:class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`
575+
:param edge_cost_fn: A callable object that acts as a weight function for
576+
an edge. It will accept a single positional argument, the edge's weight
577+
object and will return a float which will be used to represent the
578+
weight/cost of the edge
579+
580+
:return: A list of cycles where each cycle is a list of node indices
581+
582+
:rtype: list
583+
"""
584+
raise TypeError("Invalid Input Type %s for graph" % type(graph))
585+
560586
@_rustworkx_dispatch
561587
def dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, goal=None):
562588
"""Compute the lengths of the shortest paths for a graph object using

rustworkx/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ from .rustworkx import graph_longest_simple_path as graph_longest_simple_path
8383
from .rustworkx import digraph_core_number as digraph_core_number
8484
from .rustworkx import graph_core_number as graph_core_number
8585
from .rustworkx import stoer_wagner_min_cut as stoer_wagner_min_cut
86-
from .rustworkx import minimum_cycle_basis as minimum_cycle_basis
86+
from .rustworkx import graph_minimum_cycle_basis as graph_minimum_cycle_basis
8787
from .rustworkx import simple_cycles as simple_cycles
8888
from .rustworkx import digraph_isolates as digraph_isolates
8989
from .rustworkx import graph_isolates as graph_isolates

rustworkx/rustworkx.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def stoer_wagner_min_cut(
247247
/,
248248
weight_fn: Callable[[_T], float] | None = ...,
249249
) -> tuple[float, NodeIndices] | None: ...
250-
def minimum_cycle_basis(
250+
def graph_minimum_cycle_basis(
251251
graph: PyGraph[_S, _T],
252252
/,
253253
weight_fn: Callable[[_T], float] | None = ...
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
use rustworkx_core::connectivity::minimal_cycle_basis;
2+
3+
use pyo3::exceptions::PyIndexError;
4+
use pyo3::prelude::*;
5+
use pyo3::Python;
6+
7+
use petgraph::graph::NodeIndex;
8+
use petgraph::prelude::*;
9+
use petgraph::visit::EdgeIndexable;
10+
use petgraph::EdgeType;
11+
12+
use crate::{CostFn, StablePyGraph};
13+
14+
pub fn minimum_cycle_basis_map<Ty: EdgeType + Sync>(
15+
py: Python,
16+
graph: &StablePyGraph<Ty>,
17+
edge_cost_fn: PyObject,
18+
) -> PyResult<Vec<Vec<NodeIndex>>> {
19+
if graph.node_count() == 0 {
20+
return Ok(vec![]);
21+
} else if graph.edge_count() == 0 {
22+
return Ok(vec![]);
23+
}
24+
let edge_cost_callable = CostFn::from(edge_cost_fn);
25+
let mut edge_weights: Vec<Option<f64>> = Vec::with_capacity(graph.edge_bound());
26+
for index in 0..=graph.edge_bound() {
27+
let raw_weight = graph.edge_weight(EdgeIndex::new(index));
28+
match raw_weight {
29+
Some(weight) => edge_weights.push(Some(edge_cost_callable.call(py, weight)?)),
30+
None => edge_weights.push(None),
31+
};
32+
}
33+
let edge_cost = |e: EdgeIndex| -> PyResult<f64> {
34+
match edge_weights[e.index()] {
35+
Some(weight) => Ok(weight),
36+
None => Err(PyIndexError::new_err("No edge found for index")),
37+
}
38+
};
39+
let cycle_basis = minimal_cycle_basis(graph, |e| edge_cost(e.id())).unwrap();
40+
Ok(cycle_basis)
41+
}

0 commit comments

Comments
 (0)