Skip to content

Commit 7318a80

Browse files
mtreinishjakelishmanIvanIsCoding
authored
Add distance_matrix to rustworkx-core (#1439)
* Add distance_matrix to rustworkx-core This commit moves the distance matrix functionality to rustworkx-core. This is mostly a straightforward migration as the functionality was already written in a generic way. The only difference is how node holes are handled: this commit opted to split the function into 2, a version that assumes the graph has compact indices and one that doesn't. * Optimize the implementation of distance matrix This commit performs some optimizations on the internals of the distance_matrix() function. It avoids extra allocations and uses a fixedbitset for tracking instead of hashsets. Co-authored-by: Jake Lishman <[email protected]> * Deduplicate functions * Add release note * Remove stray debug print * Remove invalid classifier and capitalize keywords * Bump release version to 0.17.1 --------- Co-authored-by: Jake Lishman <[email protected]> Co-authored-by: Ivan Carvalho <[email protected]> Co-authored-by: Ivan Carvalho <[email protected]>
1 parent 3b01a17 commit 7318a80

File tree

9 files changed

+177
-98
lines changed

9 files changed

+177
-98
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ members = [
1717
]
1818

1919
[workspace.package]
20-
version = "0.17.0"
20+
version = "0.17.1"
2121
edition = "2021"
2222
rust-version = "1.79"
2323
authors = ["Matthew Treinish <[email protected]>"]
@@ -62,7 +62,7 @@ rayon.workspace = true
6262
serde = { version = "1.0", features = ["derive"] }
6363
serde_json = "1.0"
6464
smallvec = { version = "1.0", features = ["union"] }
65-
rustworkx-core = { path = "rustworkx-core", version = "=0.17.0" }
65+
rustworkx-core = { path = "rustworkx-core", version = "=0.17.1" }
6666
flate2 = "1.0.35"
6767

6868
[dependencies.pyo3]

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# The short X.Y version.
2727
version = '0.17'
2828
# The full version, including alpha/beta/rc tags.
29-
release = '0.17.0'
29+
release = '0.17.1'
3030

3131
extensions = ['sphinx.ext.autodoc',
3232
'sphinx.ext.autosummary',

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rustworkx"
3-
version = "0.17.0"
3+
version = "0.17.1"
44
description = "A High-Performance Graph Library for Python"
55
requires-python = ">=3.9"
66
dependencies = [
@@ -22,9 +22,8 @@ classifiers=[
2222
"Operating System :: MacOS :: MacOS X",
2323
"Operating System :: Microsoft :: Windows",
2424
"Operating System :: POSIX :: Linux",
25-
"Development Status :: 5 - Production/Stable",
2625
]
27-
keywords = ["Networks", "network", "graph", "Graph Theory", "DAG"]
26+
keywords = ["Networks", "Network", "Graph", "Graph Theory", "DAG"]
2827

2928
[tool.setuptools]
3029
packages = ["rustworkx", "rustworkx.visualization"]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
features:
3+
- |
4+
Added a new function ``rustworkx_core::shortest_path::distance_matrix``
5+
to rustworkx-core. This function is the equivalent of :func:`.distance_matrix`
6+
for the Python library, but as a generic Rust function for rustworkx-core.
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
2+
// not use this file except in compliance with the License. You may obtain
3+
// a copy of the License at
4+
//
5+
// http://www.apache.org/licenses/LICENSE-2.0
6+
//
7+
// Unless required by applicable law or agreed to in writing, software
8+
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9+
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10+
// License for the specific language governing permissions and limitations
11+
// under the License.
12+
13+
use std::hash::Hash;
14+
15+
use hashbrown::HashMap;
16+
17+
use fixedbitset::FixedBitSet;
18+
use ndarray::prelude::*;
19+
use petgraph::visit::{
20+
GraphProp, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, NodeIndexable,
21+
};
22+
use petgraph::{Incoming, Outgoing};
23+
use rayon::prelude::*;
24+
25+
/// Get the distance matrix for a graph
26+
///
27+
/// The generated distance matrix assumes the edge weight for all edges is
28+
/// 1.0 and returns a matrix.
29+
///
30+
/// This function is also multithreaded and will run in parallel if the number
31+
/// of nodes in the graph is above the value of `parallel_threshold`. If the function
32+
/// will be running in parallel the env var
33+
/// `RAYON_NUM_THREADS` can be used to adjust how many threads will be used.
34+
///
35+
/// # Arguments:
36+
///
37+
/// * graph - The graph object to compute the distance matrix for.
38+
/// * parallel_threshold - The threshold in number of nodes to run this function in parallel.
39+
/// If `graph` has fewer nodes than this the algorithm will run serially. A good default
40+
/// to use for this is 300.
41+
/// * as_undirected - If the input graph is directed and this is set to true the output
42+
/// matrix generated is computed as if every edge were bidirectional (edge direction is ignored).
43+
/// * null_value - The value to use for the absence of a path in the graph.
44+
///
45+
/// # Returns
46+
///
47+
/// A 2d ndarray [`Array`] of the distance matrix
48+
///
49+
/// # Example
50+
///
51+
/// ```rust
52+
/// use rustworkx_core::petgraph;
53+
/// use rustworkx_core::shortest_path::distance_matrix;
54+
/// use ndarray::{array, Array2};
55+
///
56+
/// let graph = petgraph::graph::UnGraph::<(), ()>::from_edges(&[
57+
/// (0, 1), (0, 6), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)
58+
/// ]);
59+
/// let distance_matrix = distance_matrix(&graph, 300, false, 0.);
60+
/// let expected: Array2<f64> = array![
61+
/// [0.0, 1.0, 2.0, 3.0, 3.0, 2.0, 1.0],
62+
/// [1.0, 0.0, 1.0, 2.0, 3.0, 3.0, 2.0],
63+
/// [2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 3.0],
64+
/// [3.0, 2.0, 1.0, 0.0, 1.0, 2.0, 3.0],
65+
/// [3.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0],
66+
/// [2.0, 3.0, 3.0, 2.0, 1.0, 0.0, 1.0],
67+
/// [1.0, 2.0, 3.0, 3.0, 2.0, 1.0, 0.0],
68+
/// ];
69+
/// assert_eq!(distance_matrix, expected)
70+
/// ```
71+
pub fn distance_matrix<G>(
72+
graph: G,
73+
parallel_threshold: usize,
74+
as_undirected: bool,
75+
null_value: f64,
76+
) -> Array2<f64>
77+
where
78+
G: Sync + IntoNeighborsDirected + NodeCount + NodeIndexable + IntoNodeIdentifiers + GraphProp,
79+
G::NodeId: Hash + Eq + Sync,
80+
{
81+
// `n` is both the side length of the output matrix and the size of every bitset below.
let n = graph.node_count();
82+
// When the node count differs from the node bound the graph's own indices cannot be
// used directly as matrix positions, so build an explicit node id -> position map
// (position = order of appearance in `node_identifiers()`). Otherwise the map stays
// empty and is never consulted.
let node_map: HashMap<G::NodeId, usize> = if n != graph.node_bound() {
83+
graph
84+
.node_identifiers()
85+
.enumerate()
86+
.map(|(i, v)| (v, i))
87+
.collect()
88+
} else {
89+
HashMap::new()
90+
};
91+
// Inverse of `node_map` (position -> node id); only populated in the same case.
let node_map_inv: Vec<G::NodeId> = if n != graph.node_bound() {
92+
graph.node_identifiers().collect()
93+
} else {
94+
Vec::new()
95+
};
96+
// Node id -> matrix position: lookup in `node_map` when it was built, otherwise the
// graph's own `to_index`.
let mut node_map_fn: Box<dyn FnMut(G::NodeId) -> usize> = if n != graph.node_bound() {
97+
Box::new(|n: G::NodeId| -> usize { node_map[&n] })
98+
} else {
99+
Box::new(|n: G::NodeId| -> usize { graph.to_index(n) })
100+
};
101+
// Matrix position -> node id: lookup in `node_map_inv` when it was built, otherwise
// the graph's own `from_index`.
let mut reverse_node_map: Box<dyn FnMut(usize) -> G::NodeId> = if n != graph.node_bound() {
102+
Box::new(|n: usize| -> G::NodeId { node_map_inv[n] })
103+
} else {
104+
Box::new(|n: usize| -> G::NodeId { graph.from_index(n) })
105+
};
106+
// Every entry starts as `null_value`; only entries reached by a traversal are overwritten.
let mut matrix = Array2::<f64>::from_elem((n, n), null_value);
107+
// Precompute, for each matrix position, the set of neighbor positions as a bitset so
// the traversals below never touch the graph again.
let neighbors = if as_undirected {
108+
(0..n)
109+
.map(|index| {
110+
graph
111+
// Combine incoming and outgoing neighbors so edge direction is ignored.
.neighbors_directed(reverse_node_map(index), Incoming)
112+
.chain(graph.neighbors_directed(reverse_node_map(index), Outgoing))
113+
.map(&mut node_map_fn)
114+
.collect::<FixedBitSet>()
115+
})
116+
.collect::<Vec<_>>()
117+
} else {
118+
(0..n)
119+
.map(|index| {
120+
graph
121+
.neighbors(reverse_node_map(index))
122+
.map(&mut node_map_fn)
123+
.collect::<FixedBitSet>()
124+
})
125+
.collect::<Vec<_>>()
126+
};
127+
// Level-by-level BFS from `start`, writing the level number of each reached position
// into `row`. `cur` is the current frontier, `next` the frontier being built, `seen`
// everything already assigned a distance.
let bfs_traversal = |start: usize, mut row: ArrayViewMut1<f64>| {
128+
let mut distance = 0.0;
129+
let mut seen = FixedBitSet::with_capacity(n);
130+
let mut next = FixedBitSet::with_capacity(n);
131+
let mut cur = FixedBitSet::with_capacity(n);
132+
// The start position is the level-0 frontier, so row[start] is set to 0.0.
cur.put(start);
133+
while !cur.is_clear() {
134+
next.clear();
135+
for found in cur.ones() {
136+
row[[found]] = distance;
137+
// Accumulate the neighbors of every frontier member as candidates.
next |= &neighbors[found];
138+
}
139+
// Drop candidates that already have a distance so each position is written once,
// at the first (smallest) level it is reached.
seen.union_with(&cur);
140+
next.difference_with(&seen);
141+
distance += 1.0;
142+
// `next` becomes the frontier; the old frontier buffer is reused and cleared at
// the top of the next iteration.
::std::mem::swap(&mut cur, &mut next);
143+
}
144+
};
145+
// Each row is filled by its own traversal; run them serially below the threshold.
if n < parallel_threshold {
146+
matrix
147+
.axis_iter_mut(Axis(0))
148+
.enumerate()
149+
.for_each(|(index, row)| bfs_traversal(index, row));
150+
} else {
151+
// Parallelize by row and iterate from each row index in BFS order
152+
matrix
153+
.axis_iter_mut(Axis(0))
154+
.into_par_iter()
155+
.enumerate()
156+
.for_each(|(index, row)| bfs_traversal(index, row));
157+
}
158+
matrix
159+
}

rustworkx-core/src/shortest_path/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ mod all_shortest_paths;
1919
mod astar;
2020
mod bellman_ford;
2121
mod dijkstra;
22+
mod distance_matrix;
2223
mod k_shortest_path;
2324
mod single_source_all_shortest_paths;
2425

2526
pub use all_shortest_paths::all_shortest_paths;
2627
pub use astar::astar;
2728
pub use bellman_ford::{bellman_ford, negative_cycle_finder};
2829
pub use dijkstra::dijkstra;
30+
pub use distance_matrix::distance_matrix;
2931
pub use k_shortest_path::k_shortest_path;
3032
pub use single_source_all_shortest_paths::single_source_all_shortest_paths;

src/shortest_path/distance_matrix.rs

Lines changed: 2 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -10,105 +10,18 @@
1010
// License for the specific language governing permissions and limitations
1111
// under the License.
1212

13-
use std::ops::Index;
14-
15-
use hashbrown::{HashMap, HashSet};
16-
1713
use ndarray::prelude::*;
18-
use petgraph::prelude::*;
1914
use petgraph::EdgeType;
20-
use rayon::prelude::*;
2115

22-
use crate::NodesRemoved;
2316
use crate::StablePyGraph;
2417

25-
#[inline]
26-
fn apply<I, M>(
27-
map_fn: &Option<M>,
28-
x: I,
29-
default: <M as Index<I>>::Output,
30-
) -> <M as Index<I>>::Output
31-
where
32-
M: Index<I>,
33-
<M as Index<I>>::Output: Sized + Copy,
34-
{
35-
match map_fn {
36-
Some(map) => map[x],
37-
None => default,
38-
}
39-
}
18+
use rustworkx_core::shortest_path;
4019

4120
pub fn compute_distance_matrix<Ty: EdgeType + Sync>(
4221
graph: &StablePyGraph<Ty>,
4322
parallel_threshold: usize,
4423
as_undirected: bool,
4524
null_value: f64,
4625
) -> Array2<f64> {
47-
let node_map: Option<HashMap<NodeIndex, usize>> = if graph.nodes_removed() {
48-
Some(
49-
graph
50-
.node_indices()
51-
.enumerate()
52-
.map(|(i, v)| (v, i))
53-
.collect(),
54-
)
55-
} else {
56-
None
57-
};
58-
59-
let node_map_inv: Option<Vec<NodeIndex>> = if graph.nodes_removed() {
60-
Some(graph.node_indices().collect())
61-
} else {
62-
None
63-
};
64-
65-
let n = graph.node_count();
66-
let mut matrix = Array2::<f64>::from_elem((n, n), null_value);
67-
let bfs_traversal = |index: usize, mut row: ArrayViewMut1<f64>| {
68-
let mut seen: HashMap<NodeIndex, usize> = HashMap::with_capacity(n);
69-
let start_index = apply(&node_map_inv, index, NodeIndex::new(index));
70-
let mut level = 0;
71-
let mut next_level: HashSet<NodeIndex> = HashSet::new();
72-
next_level.insert(start_index);
73-
while !next_level.is_empty() {
74-
let this_level = next_level;
75-
next_level = HashSet::new();
76-
let mut found: Vec<NodeIndex> = Vec::new();
77-
for v in this_level {
78-
if !seen.contains_key(&v) {
79-
seen.insert(v, level);
80-
found.push(v);
81-
row[[apply(&node_map, &v, v.index())]] = level as f64;
82-
}
83-
}
84-
if seen.len() == n {
85-
return;
86-
}
87-
for node in found {
88-
for v in graph.neighbors_directed(node, petgraph::Direction::Outgoing) {
89-
next_level.insert(v);
90-
}
91-
if graph.is_directed() && as_undirected {
92-
for v in graph.neighbors_directed(node, petgraph::Direction::Incoming) {
93-
next_level.insert(v);
94-
}
95-
}
96-
}
97-
level += 1
98-
}
99-
};
100-
if n < parallel_threshold {
101-
matrix
102-
.axis_iter_mut(Axis(0))
103-
.enumerate()
104-
.for_each(|(index, row)| bfs_traversal(index, row));
105-
} else {
106-
// Parallelize by row and iterate from each row index in BFS order
107-
matrix
108-
.axis_iter_mut(Axis(0))
109-
.into_par_iter()
110-
.enumerate()
111-
.for_each(|(index, row)| bfs_traversal(index, row));
112-
}
113-
matrix
26+
shortest_path::distance_matrix(graph, parallel_threshold, as_undirected, null_value)
11427
}

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)