diff --git a/rustworkx-core/src/token_swapper.rs b/rustworkx-core/src/token_swapper.rs index 8e43a11257..84bbdfc43a 100644 --- a/rustworkx-core/src/token_swapper.rs +++ b/rustworkx-core/src/token_swapper.rs @@ -35,6 +35,8 @@ use crate::traversal::dfs_edges; type Swap = (NodeIndex, NodeIndex); type Edge = (NodeIndex, NodeIndex); +type SwapList = Vec<(N, N)>; +type SwapResult = Result, MapNotPossible>; /// Error returned by token swapper if the request mapping /// is impossible @@ -48,27 +50,43 @@ impl fmt::Display for MapNotPossible { } } -struct TokenSwapper +// Define a trait for mapping providers +pub trait MappingProvider { + fn get(&self, key: &G::NodeId) -> Option<&G::NodeId>; + fn iter(&self) -> Box + '_>; +} + +// Implement the trait for HashMap +impl MappingProvider for HashMap +where + G::NodeId: Eq + Hash, +{ + fn get(&self, key: &G::NodeId) -> Option<&G::NodeId> { + HashMap::get(self, key) + } + + fn iter(&self) -> Box + '_> { + Box::new(HashMap::iter(self)) + } +} + +struct TokenSwapper> where G::NodeId: Eq + Hash, { // The input graph graph: G, // The user-supplied mapping to use for swapping tokens - mapping: HashMap, + mapping: M, // Number of trials trials: usize, // Seed for random selection of a node for a trial seed: Option, // Threshold for how many nodes will trigger parallel iterator parallel_threshold: usize, - // Map of NodeId to NodeIndex - node_map: HashMap, - // Map of NodeIndex to NodeId - rev_node_map: HashMap, } -impl TokenSwapper +impl TokenSwapper where G: NodeCount + EdgeCount @@ -80,10 +98,11 @@ where + Send + Sync, G::NodeId: Hash + Eq + Send + Sync, + M: MappingProvider + Send + Sync, { fn new( graph: G, - mapping: HashMap, + mapping: M, trials: Option, seed: Option, parallel_threshold: Option, @@ -94,12 +113,10 @@ where trials: trials.unwrap_or(4), seed, parallel_threshold: parallel_threshold.unwrap_or(50), - node_map: HashMap::with_capacity(graph.node_count()), - rev_node_map: HashMap::with_capacity(graph.node_count()), } } - fn map(&mut self) -> Result, MapNotPossible> { + fn map(&self) -> SwapResult { let num_nodes = self.graph.node_bound(); let num_edges = self.graph.edge_count(); @@ -124,21 +141,17 @@ where count += 1; } - // Create maps between NodeId and NodeIndex - for node in self.graph.node_identifiers() { - self.node_map - .insert(node, NodeIndex::new(self.graph.to_index(node))); - self.rev_node_map - .insert(NodeIndex::new(self.graph.to_index(node)), node); - } // sub will become same as digraph but with no self edges in add_token_edges let mut sub_digraph = digraph.clone(); - // The mapping in HashMap form using NodeIndex let mut tokens: HashMap = self .mapping .iter() - .map(|(k, v)| (self.node_map[k], self.node_map[v])) + .map(|(k, v)| { + let k_idx = NodeIndex::new(self.graph.to_index(*k)); + let v_idx = NodeIndex::new(self.graph.to_index(*v)); + (k_idx, v_idx) + }) .collect(); // todo_nodes are all the mapping entries where left != right @@ -150,12 +163,8 @@ where // Add initial edges to the digraph/sub_digraph for node in self.graph.node_identifiers() { - self.add_token_edges( - self.node_map[&node], - &mut digraph, - &mut sub_digraph, - &mut tokens, - )?; + let node_idx = NodeIndex::new(self.graph.to_index(node)); + self.add_token_edges(node_idx, &mut digraph, &mut sub_digraph, &mut tokens)?; } // First collect the self.trial number of random numbers // into a Vec based on the given seed @@ -205,15 +214,15 @@ where digraph.update_edge(node, node, ()); return Ok(()); } - let id_node = self.rev_node_map[&node]; - let id_token = self.rev_node_map[&tokens[&node]]; + let id_node = self.graph.from_index(node.index()); + let id_token = self.graph.from_index(tokens[&node].index()); if self.graph.neighbors(id_node).next().is_none() { return Err(MapNotPossible {}); } for id_neighbor in self.graph.neighbors(id_node) { - let neighbor = self.node_map[&id_neighbor]; + let neighbor = NodeIndex::new(self.graph.to_index(id_neighbor)); let dist_neighbor: DictMap = dijkstra( &self.graph, id_neighbor, @@ -255,7 +264,7 @@ where mut tokens: HashMap, mut todo_nodes: Vec, trial_seed: u64, - ) -> Result, MapNotPossible> { + ) -> SwapResult { // Create a random trial list of swaps to move tokens to optimal positions let mut steps = 0; let mut swap_edges: Vec = vec![]; @@ -338,7 +347,15 @@ where todo_nodes.is_empty(), "The output final swap map is incomplete, this points to a bug in rustworkx, please open an issue." ); - Ok(swap_edges) + let result: Vec<(G::NodeId, G::NodeId)> = swap_edges + .into_iter() + .map(|(ni1, ni2)| { + let id1 = self.graph.from_index(ni1.index()); + let id2 = self.graph.from_index(ni2.index()); + (id1, id2) + }) + .collect(); + Ok(result) } fn swap( @@ -412,7 +429,8 @@ where /// trigger the use of parallel threads. If the number of nodes in the graph is less than this value /// it will run in a single thread. The default value is 50. /// -/// It returns a list of tuples representing the swaps to perform. The result will be an +/// It returns a list of tuples representing the swaps to perform, where each tuple contains +/// node identifiers of type `(G::NodeId, G::NodeId)`. The result will be an /// `Err(MapNotPossible)` if the `token_swapper()` function can't find a mapping. /// /// This function is multithreaded and will launch a thread pool with threads equal to @@ -439,13 +457,13 @@ where /// assert_eq!(3, output.len()); /// /// ``` -pub fn token_swapper( +pub fn token_swapper( graph: G, - mapping: HashMap, + mapping: M, trials: Option, seed: Option, parallel_threshold: Option, -) -> Result, MapNotPossible> +) -> SwapResult where G: NodeCount + EdgeCount @@ -457,8 +475,9 @@ where + Send + Sync, G::NodeId: Hash + Eq + Send + Sync, + M: MappingProvider + Send + Sync, { - let mut swapper = TokenSwapper::new(graph, mapping, trials, seed, parallel_threshold); + let swapper = TokenSwapper::new(graph, mapping, trials, seed, parallel_threshold); swapper.map() } @@ -466,9 +485,11 @@ where mod test_token_swapper { use crate::petgraph; - use crate::token_swapper::token_swapper; + use crate::token_swapper::{token_swapper, MappingProvider}; use hashbrown::HashMap; use petgraph::graph::NodeIndex; + use petgraph::visit::GraphBase; + use std::hash::Hash; fn do_swap(mapping: &mut HashMap, swaps: &Vec<(NodeIndex, NodeIndex)>) { // Apply the swaps to the mapping to get final result @@ -493,6 +514,80 @@ mod test_token_swapper { } } + struct VecMappingProvider { + pairs: Vec<(N, N)>, + } + + impl VecMappingProvider { + fn new(pairs: Vec<(N, N)>) -> Self { + VecMappingProvider { pairs } + } + } + + // Implement MappingProvider for any G where G::NodeId = N + impl MappingProvider for VecMappingProvider + where + G: GraphBase, + N: Eq + Hash + 'static, // 'static bound for Box + { + fn get(&self, key: &N) -> Option<&N> { + self.pairs.iter().find(|(k, _)| k == key).map(|(_, v)| v) + } + + fn iter(&self) -> Box + '_> { + Box::new(self.pairs.iter().map(|(k, v)| (k, v))) + } + } + + #[test] + fn test_vec_mapping_provider() { + let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3)]); + let pairs = vec![ + (NodeIndex::new(0), NodeIndex::new(0)), + (NodeIndex::new(1), NodeIndex::new(3)), + (NodeIndex::new(3), NodeIndex::new(1)), + (NodeIndex::new(2), NodeIndex::new(2)), + ]; + let mapping_provider = VecMappingProvider::new(pairs.clone()); + let swaps = token_swapper(&g, mapping_provider, Some(4), Some(4), Some(50)) + .expect("swap mapping errored"); + assert_eq!(3, swaps.len()); + + let mut applied_map: HashMap<_, _> = pairs.into_iter().collect(); + do_swap(&mut applied_map, &swaps); + let expected: HashMap<_, _> = vec![ + (NodeIndex::new(0), NodeIndex::new(0)), + (NodeIndex::new(3), NodeIndex::new(3)), + (NodeIndex::new(1), NodeIndex::new(1)), + (NodeIndex::new(2), NodeIndex::new(2)), + ] + .into_iter() + .collect(); + assert_eq!(expected, applied_map); + } + + #[test] + fn test_return_type_is_node_id() { + let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2)]); + let mapping = HashMap::from([ + (NodeIndex::new(0), NodeIndex::new(2)), + (NodeIndex::new(2), NodeIndex::new(0)), + ]); + let swaps = token_swapper(&g, mapping.clone(), Some(4), Some(4), Some(50)) + .expect("swap mapping errored"); + assert_eq!(swaps.len(), 3); // Adjusted to 3 steps for path graph + let first_swap = swaps[0]; + // Explicitly check that swaps are (NodeIndex, NodeIndex), which is G::NodeId + let _: (NodeIndex, NodeIndex) = first_swap; // Type assertion + let mut new_map = mapping; + do_swap(&mut new_map, &swaps); + let expected = HashMap::from([ + (NodeIndex::new(0), NodeIndex::new(0)), + (NodeIndex::new(2), NodeIndex::new(2)), + ]); + assert_eq!(new_map, expected); + } + #[test] fn test_simple_swap() { // Simple arbitrary swap diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 32fb3c1a77..18f8f20187 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -1003,7 +1003,7 @@ def graph_tensor_product( def graph_token_swapper( graph: PyGraph, - mapping: dict[int, int], + mapping: Mapping[int, int], /, trials: int | None = ..., seed: int | None = ..., diff --git a/src/token_swapper.rs b/src/token_swapper.rs index a2ee68e21d..70a84fc96e 100644 --- a/src/token_swapper.rs +++ b/src/token_swapper.rs @@ -17,7 +17,6 @@ use crate::InvalidMapping; use hashbrown::HashMap; use petgraph::graph::NodeIndex; use pyo3::prelude::*; -use rustworkx_core::token_swapper; /// This module performs an approximately optimal Token Swapping algorithm /// Supports partial mappings (i.e. not-permutations) for graphs with missing tokens. @@ -28,10 +27,12 @@ use rustworkx_core::token_swapper; /// The inputs are a partial ``mapping`` to be implemented in swaps, and the number of ``trials`` /// to perform the mapping. It's minimized over the trials. /// -/// It returns a list of tuples representing the swaps to perform. +/// It returns a list of tuples representing the swaps to perform, where each tuple contains +/// the node identifiers (integers) of the nodes to swap. /// /// :param PyGraph graph: The input graph -/// :param dict[int: int] mapping: Map of (node, token) +/// :param Mapping[int, int] mapping: Map of (node, token). Can be any mapping-like object +/// that associates integer node indices to integer token positions (e.g., dict, or custom Mapping types). /// :param int trials: The number of trials to run /// :param int seed: The random seed to be used in producing random ints for selecting /// which nodes to process next @@ -44,8 +45,8 @@ use rustworkx_core::token_swapper; /// the ``RAYON_NUM_THREADS`` environment variable. For example, setting ``RAYON_NUM_THREADS=4`` /// would limit the thread pool to 4 threads. /// -/// :returns: A list of tuples which are the swaps to be applied to the mapping to rearrange -/// the tokens. +/// :returns: A list of tuples containing the node identifiers (integers) of the swaps to be +/// applied to the mapping to rearrange the tokens. /// :rtype: EdgeList #[pyfunction] #[pyo3( @@ -53,25 +54,37 @@ use rustworkx_core::token_swapper; signature = (graph, mapping, trials=None, seed=None, parallel_threshold=None) )] pub fn graph_token_swapper( + py: Python<'_>, graph: &graph::PyGraph, - mapping: HashMap, + mapping: Py, trials: Option, seed: Option, parallel_threshold: Option, ) -> PyResult { - let map: HashMap = mapping - .iter() - .map(|(s, t)| (NodeIndex::new(*s), NodeIndex::new(*t))) - .collect(); - let swaps = - match token_swapper::token_swapper(&graph.graph, map, trials, seed, parallel_threshold) { - Ok(swaps) => swaps, - Err(_) => { - return Err(InvalidMapping::new_err( - "Specified mapping could not be made on the given graph", - )) - } - }; + let items = mapping.getattr(py, "items")?.call0(py)?; + let mut map: HashMap = HashMap::new(); + + for item_result in items.bind(py).try_iter()? { + let item = item_result?; + let (key, value): (usize, usize) = item.extract()?; + map.insert(NodeIndex::new(key), NodeIndex::new(value)); + } + + let swaps = match rustworkx_core::token_swapper::token_swapper( + &graph.graph, + map, + trials, + seed, + parallel_threshold, + ) { + Ok(swaps) => swaps, + Err(_) => { + return Err(InvalidMapping::new_err( + "Specified mapping could not be made on the given graph", + )) + } + }; + Ok(EdgeList { edges: swaps .into_iter() diff --git a/tests/test_token_swapper.py b/tests/test_token_swapper.py index 17232498c2..78eb5e2a6f 100644 --- a/tests/test_token_swapper.py +++ b/tests/test_token_swapper.py @@ -13,7 +13,7 @@ import unittest import itertools import rustworkx as rx - +from collections.abc import Mapping from numpy import random @@ -39,6 +39,52 @@ def setUp(self) -> None: super().setUp() random.seed(0) + def test_simple_dict(self) -> None: + """Test a simple permutation on a path graph using a dictionary.""" + graph = rx.generators.path_graph(4) + mapping = {0: 0, 1: 3, 3: 1, 2: 2} + swaps = rx.graph_token_swapper(graph, mapping, 4, 4, 1) + self.assertIsInstance(swaps, rx.EdgeList) + self.assertTrue( + all( + isinstance(s, tuple) and len(s) == 2 and all(isinstance(n, int) for n in s) + for s in swaps + ) + ) + swap_permutation(mapping, swaps) + self.assertEqual(3, len(swaps)) + self.assertEqual({i: i for i in range(4)}, mapping) + + def test_custom_mapping(self) -> None: + """Test a permutation using a custom mapping class.""" + + class CustomMapping(Mapping): + def __init__(self, data): + self._data = data + + def __getitem__(self, key): + return self._data[key] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + graph = rx.generators.path_graph(4) + mapping = CustomMapping({0: 0, 1: 3, 3: 1, 2: 2}) + swaps = rx.graph_token_swapper(graph, mapping, 4, 4, 1) + self.assertIsInstance(swaps, rx.EdgeList) + self.assertTrue( + all( + isinstance(s, tuple) and len(s) == 2 and all(isinstance(n, int) for n in s) + for s in swaps + ) + ) + swap_permutation(mapping._data, swaps) + self.assertEqual(3, len(swaps)) + self.assertEqual({i: i for i in range(4)}, mapping._data) + def test_simple(self) -> None: """Test a simple permutation on a path graph of size 4.""" graph = rx.generators.path_graph(4)