Skip to content

Commit d5304e6

Browse files
committed
modified the binder and API layer
1 parent 349e56a commit d5304e6

File tree

6 files changed

+73
-24
lines changed

6 files changed

+73
-24
lines changed

rustworkx/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -570,15 +570,14 @@ def minimum_cycle_basis(graph, edge_cost_fn):
570570
[2] de Pina, J. 1995. Applications of shortest path methods.
571571
Ph.D. thesis, University of Amsterdam, Netherlands
572572
573-
:param graph: The input graph to use. Can either be a
573+
:param graph: The input graph to use. Can be either a
574574
:class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`
575575
:param edge_cost_fn: A callable object that acts as a weight function for
576576
an edge. It will accept a single positional argument, the edge's weight
577577
object and will return a float which will be used to represent the
578578
weight/cost of the edge
579579
580-
:return: A list of cycles where each cycle is a list of node indices
581-
580+
:returns: A list of cycles where each cycle is a list of node indices
582581
:rtype: list
583582
"""
584583
raise TypeError("Invalid Input Type %s for graph" % type(graph))

rustworkx/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ from .rustworkx import digraph_core_number as digraph_core_number
8484
from .rustworkx import graph_core_number as graph_core_number
8585
from .rustworkx import stoer_wagner_min_cut as stoer_wagner_min_cut
8686
from .rustworkx import graph_minimum_cycle_basis as graph_minimum_cycle_basis
87+
from .rustworkx import digraph_minimum_cycle_basis as digraph_minimum_cycle_basis
8788
from .rustworkx import simple_cycles as simple_cycles
8889
from .rustworkx import digraph_isolates as digraph_isolates
8990
from .rustworkx import graph_isolates as graph_isolates

rustworkx/rustworkx.pyi

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ from typing import (
1515
Callable,
1616
Iterable,
1717
Iterator,
18+
Union,
1819
final,
1920
Sequence,
2021
Any,
@@ -248,8 +249,15 @@ def stoer_wagner_min_cut(
248249
weight_fn: Callable[[_T], float] | None = ...,
249250
) -> tuple[float, NodeIndices] | None: ...
250251
def graph_minimum_cycle_basis(
251-
graph: PyGraph[_S, _T], /, weight_fn: Callable[[_T], float] | None = ...
252-
) -> list[list[NodeIndices]] | None: ...
252+
graph: PyGraph[_S, _T],
253+
edge_cost: Callable[[_T], float],
254+
/,
255+
) -> list[list[NodeIndices]]: ...
256+
def digraph_minimum_cycle_basis(
257+
graph: PyDiGraph[_S, _T],
258+
edge_cost: Callable[[_T], float],
259+
/,
260+
) -> list[list[NodeIndices]]: ...
253261
def simple_cycles(graph: PyDiGraph, /) -> Iterator[NodeIndices]: ...
254262
def graph_isolates(graph: PyGraph) -> NodeIndices: ...
255263
def digraph_isolates(graph: PyDiGraph) -> NodeIndices: ...

src/connectivity/minimum_cycle_basis.rs renamed to src/connectivity/min_cycle_basis.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,17 @@ use pyo3::exceptions::PyIndexError;
44
use pyo3::prelude::*;
55
use pyo3::Python;
66

7-
use petgraph::graph::NodeIndex;
7+
use crate::iterators::NodeIndices;
8+
use crate::{CostFn, StablePyGraph};
89
use petgraph::prelude::*;
910
use petgraph::visit::EdgeIndexable;
1011
use petgraph::EdgeType;
1112

12-
use crate::{CostFn, StablePyGraph};
13-
14-
pub fn minimum_cycle_basis_map<Ty: EdgeType + Sync>(
13+
pub fn minimum_cycle_basis<Ty: EdgeType + Sync>(
1514
py: Python,
1615
graph: &StablePyGraph<Ty>,
1716
edge_cost_fn: PyObject,
18-
) -> PyResult<Vec<Vec<NodeIndex>>> {
17+
) -> PyResult<Vec<Vec<NodeIndices>>> {
1918
if graph.node_count() == 0 || graph.edge_count() == 0 {
2019
return Ok(vec![]);
2120
}
@@ -35,5 +34,17 @@ pub fn minimum_cycle_basis_map<Ty: EdgeType + Sync>(
3534
}
3635
};
3736
let cycle_basis = minimal_cycle_basis(graph, |e| edge_cost(e.id())).unwrap();
38-
Ok(cycle_basis)
37+
// Convert the cycle basis to a list of lists of node indices
38+
let result: Vec<Vec<NodeIndices>> = cycle_basis
39+
.into_iter()
40+
.map(|cycle| {
41+
cycle
42+
.into_iter()
43+
.map(|node| NodeIndices {
44+
nodes: vec![node.index()],
45+
})
46+
.collect()
47+
})
48+
.collect();
49+
Ok(result)
3950
}

src/connectivity/mod.rs

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
mod all_pairs_all_simple_paths;
1616
mod johnson_simple_cycles;
17-
mod minimum_cycle_basis;
17+
mod min_cycle_basis;
1818
mod subgraphs;
1919

2020
use super::{
@@ -919,25 +919,54 @@ pub fn stoer_wagner_min_cut(
919919
}))
920920
}
921921

922+
/// Find a minimum cycle basis of an undirected graph.
923+
/// All weights must be nonnegative. If the input graph does not have
924+
/// any nodes or edges, this function returns ``None``.
925+
/// If the input graph does not any weight, this function will find the
926+
/// minimum cycle basis with the weight of 1.0 for all edges.
927+
///
928+
/// :param PyGraph: The undirected graph to be used
929+
/// :param Callable edge_cost_fn: An optional callable object (function, lambda, etc) which
930+
/// will be passed the edge object and expected to return a ``float``.
931+
/// Edges with ``NaN`` weights will be considered to have 1.0 weight.
932+
/// If ``edge_cost_fn`` is not specified a default value of ``1.0`` will be used for all edges.
933+
///
934+
/// :returns: A list of cycles, where each cycle is a list of node indices
935+
/// :rtype: list
922936
#[pyfunction]
923937
#[pyo3(text_signature = "(graph, edge_cost_fn, /)")]
924938
pub fn graph_minimum_cycle_basis(
925939
py: Python,
926940
graph: &graph::PyGraph,
927941
edge_cost_fn: PyObject,
928942
) -> PyResult<Vec<Vec<NodeIndices>>> {
929-
let basis = minimum_cycle_basis::minimum_cycle_basis_map(py, &graph.graph, edge_cost_fn);
930-
Ok(basis
931-
.into_iter()
932-
.map(|cycle| {
933-
cycle
934-
.into_iter()
935-
.map(|node| NodeIndices {
936-
nodes: node.iter().map(|nx| nx.index()).collect(),
937-
})
938-
.collect()
939-
})
940-
.collect())
943+
min_cycle_basis::minimum_cycle_basis(py, &graph.graph, edge_cost_fn)
944+
}
945+
946+
/// Find a minimum cycle basis of a directed graph (which is not of interest in the context
947+
/// of minimum cycle basis). This function will return the minimum cycle basis of the
948+
/// underlying undirected graph of the input directed graph.
949+
/// All weights must be nonnegative. If the input graph does not have
950+
/// any nodes or edges, this function returns ``None``.
951+
/// If the input graph does not any weight, this function will find the
952+
/// minimum cycle basis with the weight of 1.0 for all edges.
953+
///
954+
/// :param PyDiGraph: The directed graph to be used
955+
/// :param Callable edge_cost_fn: An optional callable object (function, lambda, etc) which
956+
/// will be passed the edge object and expected to return a ``float``.
957+
/// Edges with ``NaN`` weights will be considered to have 1.0 weight.
958+
/// If ``edge_cost_fn`` is not specified a default value of ``1.0`` will be used for all edges.
959+
///
960+
/// :returns: A list of cycles, where each cycle is a list of node indices
961+
/// :rtype: list
962+
#[pyfunction]
963+
#[pyo3(text_signature = "(graph, edge_cost_fn, /)")]
964+
pub fn digraph_minimum_cycle_basis(
965+
py: Python,
966+
graph: &digraph::PyDiGraph,
967+
edge_cost_fn: PyObject,
968+
) -> PyResult<Vec<Vec<NodeIndices>>> {
969+
min_cycle_basis::minimum_cycle_basis(py, &graph.graph, edge_cost_fn)
941970
}
942971

943972
/// Return the articulation points of an undirected graph.

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
571571
m.add_wrapped(wrap_pyfunction!(metric_closure))?;
572572
m.add_wrapped(wrap_pyfunction!(stoer_wagner_min_cut))?;
573573
m.add_wrapped(wrap_pyfunction!(graph_minimum_cycle_basis))?;
574+
m.add_wrapped(wrap_pyfunction!(digraph_minimum_cycle_basis))?;
574575
m.add_wrapped(wrap_pyfunction!(steiner_tree::steiner_tree))?;
575576
m.add_wrapped(wrap_pyfunction!(digraph_dfs_search))?;
576577
m.add_wrapped(wrap_pyfunction!(graph_dfs_search))?;

0 commit comments

Comments
 (0)