Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 132 additions & 37 deletions rustworkx-core/src/token_swapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ use crate::traversal::dfs_edges;

type Swap = (NodeIndex, NodeIndex);
type Edge = (NodeIndex, NodeIndex);
type SwapList<N> = Vec<(N, N)>;
type SwapResult<N> = Result<SwapList<N>, MapNotPossible>;

/// Error returned by token swapper if the request mapping
/// is impossible
Expand All @@ -48,27 +50,43 @@ impl fmt::Display for MapNotPossible {
}
}

struct TokenSwapper<G: GraphBase>
// Define a trait for mapping providers
pub trait MappingProvider<G: GraphBase> {
fn get(&self, key: &G::NodeId) -> Option<&G::NodeId>;
fn iter(&self) -> Box<dyn Iterator<Item = (&G::NodeId, &G::NodeId)> + '_>;
}

// Implement the trait for HashMap
impl<G: GraphBase> MappingProvider<G> for HashMap<G::NodeId, G::NodeId>
where
G::NodeId: Eq + Hash,
{
fn get(&self, key: &G::NodeId) -> Option<&G::NodeId> {
HashMap::get(self, key)
}

fn iter(&self) -> Box<dyn Iterator<Item = (&G::NodeId, &G::NodeId)> + '_> {
Box::new(HashMap::iter(self))
}
}

struct TokenSwapper<G: GraphBase, M: MappingProvider<G>>
where
G::NodeId: Eq + Hash,
{
// The input graph
graph: G,
// The user-supplied mapping to use for swapping tokens
mapping: HashMap<G::NodeId, G::NodeId>,
mapping: M,
// Number of trials
trials: usize,
// Seed for random selection of a node for a trial
seed: Option<u64>,
// Threshold for how many nodes will trigger parallel iterator
parallel_threshold: usize,
// Map of NodeId to NodeIndex
node_map: HashMap<G::NodeId, NodeIndex>,
// Map of NodeIndex to NodeId
rev_node_map: HashMap<NodeIndex, G::NodeId>,
}

impl<G> TokenSwapper<G>
impl<G, M> TokenSwapper<G, M>
where
G: NodeCount
+ EdgeCount
Expand All @@ -80,10 +98,11 @@ where
+ Send
+ Sync,
G::NodeId: Hash + Eq + Send + Sync,
M: MappingProvider<G> + Send + Sync,
{
fn new(
graph: G,
mapping: HashMap<G::NodeId, G::NodeId>,
mapping: M,
trials: Option<usize>,
seed: Option<u64>,
parallel_threshold: Option<usize>,
Expand All @@ -94,12 +113,10 @@ where
trials: trials.unwrap_or(4),
seed,
parallel_threshold: parallel_threshold.unwrap_or(50),
node_map: HashMap::with_capacity(graph.node_count()),
rev_node_map: HashMap::with_capacity(graph.node_count()),
}
}

fn map(&mut self) -> Result<Vec<Swap>, MapNotPossible> {
fn map(&self) -> SwapResult<G::NodeId> {
let num_nodes = self.graph.node_bound();
let num_edges = self.graph.edge_count();

Expand All @@ -124,21 +141,17 @@ where
count += 1;
}

// Create maps between NodeId and NodeIndex
for node in self.graph.node_identifiers() {
self.node_map
.insert(node, NodeIndex::new(self.graph.to_index(node)));
self.rev_node_map
.insert(NodeIndex::new(self.graph.to_index(node)), node);
}
// sub will become same as digraph but with no self edges in add_token_edges
let mut sub_digraph = digraph.clone();

// The mapping in HashMap form using NodeIndex
let mut tokens: HashMap<NodeIndex, NodeIndex> = self
.mapping
.iter()
.map(|(k, v)| (self.node_map[k], self.node_map[v]))
.map(|(k, v)| {
let k_idx = NodeIndex::new(self.graph.to_index(*k));
let v_idx = NodeIndex::new(self.graph.to_index(*v));
(k_idx, v_idx)
})
.collect();

// todo_nodes are all the mapping entries where left != right
Expand All @@ -150,12 +163,8 @@ where

// Add initial edges to the digraph/sub_digraph
for node in self.graph.node_identifiers() {
self.add_token_edges(
self.node_map[&node],
&mut digraph,
&mut sub_digraph,
&mut tokens,
)?;
let node_idx = NodeIndex::new(self.graph.to_index(node));
self.add_token_edges(node_idx, &mut digraph, &mut sub_digraph, &mut tokens)?;
}
// First collect the self.trial number of random numbers
// into a Vec based on the given seed
Expand Down Expand Up @@ -205,15 +214,15 @@ where
digraph.update_edge(node, node, ());
return Ok(());
}
let id_node = self.rev_node_map[&node];
let id_token = self.rev_node_map[&tokens[&node]];
let id_node = self.graph.from_index(node.index());
let id_token = self.graph.from_index(tokens[&node].index());

if self.graph.neighbors(id_node).next().is_none() {
return Err(MapNotPossible {});
}

for id_neighbor in self.graph.neighbors(id_node) {
let neighbor = self.node_map[&id_neighbor];
let neighbor = NodeIndex::new(self.graph.to_index(id_neighbor));
let dist_neighbor: DictMap<G::NodeId, usize> = dijkstra(
&self.graph,
id_neighbor,
Expand Down Expand Up @@ -255,7 +264,7 @@ where
mut tokens: HashMap<NodeIndex, NodeIndex>,
mut todo_nodes: Vec<NodeIndex>,
trial_seed: u64,
) -> Result<Vec<Swap>, MapNotPossible> {
) -> SwapResult<G::NodeId> {
// Create a random trial list of swaps to move tokens to optimal positions
let mut steps = 0;
let mut swap_edges: Vec<Swap> = vec![];
Expand Down Expand Up @@ -338,7 +347,15 @@ where
todo_nodes.is_empty(),
"The output final swap map is incomplete, this points to a bug in rustworkx, please open an issue."
);
Ok(swap_edges)
let result: Vec<(G::NodeId, G::NodeId)> = swap_edges
.into_iter()
.map(|(ni1, ni2)| {
let id1 = self.graph.from_index(ni1.index());
let id2 = self.graph.from_index(ni2.index());
(id1, id2)
})
.collect();
Ok(result)
}

fn swap(
Expand Down Expand Up @@ -412,7 +429,8 @@ where
/// trigger the use of parallel threads. If the number of nodes in the graph is less than this value
/// it will run in a single thread. The default value is 50.
///
/// It returns a list of tuples representing the swaps to perform. The result will be an
/// It returns a list of tuples representing the swaps to perform, where each tuple contains
/// node identifiers of type `(G::NodeId, G::NodeId)`. The result will be an
/// `Err(MapNotPossible)` if the `token_swapper()` function can't find a mapping.
///
/// This function is multithreaded and will launch a thread pool with threads equal to
Expand All @@ -439,13 +457,13 @@ where
/// assert_eq!(3, output.len());
///
/// ```
pub fn token_swapper<G>(
pub fn token_swapper<G, M>(
graph: G,
mapping: HashMap<G::NodeId, G::NodeId>,
mapping: M,
trials: Option<usize>,
seed: Option<u64>,
parallel_threshold: Option<usize>,
) -> Result<Vec<Swap>, MapNotPossible>
) -> SwapResult<G::NodeId>
where
G: NodeCount
+ EdgeCount
Expand All @@ -457,18 +475,21 @@ where
+ Send
+ Sync,
G::NodeId: Hash + Eq + Send + Sync,
M: MappingProvider<G> + Send + Sync,
{
let mut swapper = TokenSwapper::new(graph, mapping, trials, seed, parallel_threshold);
let swapper = TokenSwapper::new(graph, mapping, trials, seed, parallel_threshold);
swapper.map()
}

#[cfg(test)]
mod test_token_swapper {

use crate::petgraph;
use crate::token_swapper::token_swapper;
use crate::token_swapper::{token_swapper, MappingProvider};
use hashbrown::HashMap;
use petgraph::graph::NodeIndex;
use petgraph::visit::GraphBase;
use std::hash::Hash;

fn do_swap(mapping: &mut HashMap<NodeIndex, NodeIndex>, swaps: &Vec<(NodeIndex, NodeIndex)>) {
// Apply the swaps to the mapping to get final result
Expand All @@ -493,6 +514,80 @@ mod test_token_swapper {
}
}

struct VecMappingProvider<N> {
pairs: Vec<(N, N)>,
}

impl<N> VecMappingProvider<N> {
fn new(pairs: Vec<(N, N)>) -> Self {
VecMappingProvider { pairs }
}
}

// Implement MappingProvider for any G where G::NodeId = N
impl<G, N> MappingProvider<G> for VecMappingProvider<N>
where
G: GraphBase<NodeId = N>,
N: Eq + Hash + 'static, // 'static bound for Box<dyn Iterator>
{
fn get(&self, key: &N) -> Option<&N> {
self.pairs.iter().find(|(k, _)| k == key).map(|(_, v)| v)
}

fn iter(&self) -> Box<dyn Iterator<Item = (&N, &N)> + '_> {
Box::new(self.pairs.iter().map(|(k, v)| (k, v)))
}
}

#[test]
fn test_vec_mapping_provider() {
let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
let pairs = vec![
(NodeIndex::new(0), NodeIndex::new(0)),
(NodeIndex::new(1), NodeIndex::new(3)),
(NodeIndex::new(3), NodeIndex::new(1)),
(NodeIndex::new(2), NodeIndex::new(2)),
];
let mapping_provider = VecMappingProvider::new(pairs.clone());
let swaps = token_swapper(&g, mapping_provider, Some(4), Some(4), Some(50))
.expect("swap mapping errored");
assert_eq!(3, swaps.len());

let mut applied_map: HashMap<_, _> = pairs.into_iter().collect();
do_swap(&mut applied_map, &swaps);
let expected: HashMap<_, _> = vec![
(NodeIndex::new(0), NodeIndex::new(0)),
(NodeIndex::new(3), NodeIndex::new(3)),
(NodeIndex::new(1), NodeIndex::new(1)),
(NodeIndex::new(2), NodeIndex::new(2)),
]
.into_iter()
.collect();
assert_eq!(expected, applied_map);
}

#[test]
fn test_return_type_is_node_id() {
let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2)]);
let mapping = HashMap::from([
(NodeIndex::new(0), NodeIndex::new(2)),
(NodeIndex::new(2), NodeIndex::new(0)),
]);
let swaps = token_swapper(&g, mapping.clone(), Some(4), Some(4), Some(50))
.expect("swap mapping errored");
assert_eq!(swaps.len(), 3); // Adjusted to 3 steps for path graph
let first_swap = swaps[0];
// Explicitly check that swaps are (NodeIndex, NodeIndex), which is G::NodeId
let _: (NodeIndex, NodeIndex) = first_swap; // Type assertion
let mut new_map = mapping;
do_swap(&mut new_map, &swaps);
let expected = HashMap::from([
(NodeIndex::new(0), NodeIndex::new(0)),
(NodeIndex::new(2), NodeIndex::new(2)),
]);
assert_eq!(new_map, expected);
}

#[test]
fn test_simple_swap() {
// Simple arbitrary swap
Expand Down
2 changes: 1 addition & 1 deletion rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def graph_tensor_product(

def graph_token_swapper(
graph: PyGraph,
mapping: dict[int, int],
mapping: Mapping[int, int],
/,
trials: int | None = ...,
seed: int | None = ...,
Expand Down
Loading