Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 52 additions & 34 deletions rustworkx-core/src/token_swapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ use crate::traversal::dfs_edges;

type Swap = (NodeIndex, NodeIndex);
type Edge = (NodeIndex, NodeIndex);
type SwapList<N> = Vec<(N, N)>;
type SwapResult<N> = Result<SwapList<N>, MapNotPossible>;

/// Error returned by token swapper if the request mapping
/// is impossible
Expand All @@ -48,27 +50,43 @@ impl fmt::Display for MapNotPossible {
}
}

struct TokenSwapper<G: GraphBase>
// Define a trait for mapping providers
pub trait MappingProvider<G: GraphBase> {
fn get(&self, key: &G::NodeId) -> Option<&G::NodeId>;
fn iter(&self) -> Box<dyn Iterator<Item = (&G::NodeId, &G::NodeId)> + '_>;
}

// Implement the trait for HashMap
impl<G: GraphBase> MappingProvider<G> for HashMap<G::NodeId, G::NodeId>
where
G::NodeId: Eq + Hash,
{
fn get(&self, key: &G::NodeId) -> Option<&G::NodeId> {
HashMap::get(self, key)
}

fn iter(&self) -> Box<dyn Iterator<Item = (&G::NodeId, &G::NodeId)> + '_> {
Box::new(HashMap::iter(self))
}
}

struct TokenSwapper<G: GraphBase, M: MappingProvider<G>>
where
G::NodeId: Eq + Hash,
{
// The input graph
graph: G,
// The user-supplied mapping to use for swapping tokens
mapping: HashMap<G::NodeId, G::NodeId>,
mapping: M,
// Number of trials
trials: usize,
// Seed for random selection of a node for a trial
seed: Option<u64>,
// Threshold for how many nodes will trigger parallel iterator
parallel_threshold: usize,
// Map of NodeId to NodeIndex
node_map: HashMap<G::NodeId, NodeIndex>,
// Map of NodeIndex to NodeId
rev_node_map: HashMap<NodeIndex, G::NodeId>,
}

impl<G> TokenSwapper<G>
impl<G, M> TokenSwapper<G, M>
where
G: NodeCount
+ EdgeCount
Expand All @@ -80,10 +98,11 @@ where
+ Send
+ Sync,
G::NodeId: Hash + Eq + Send + Sync,
M: MappingProvider<G> + Send + Sync,
{
fn new(
graph: G,
mapping: HashMap<G::NodeId, G::NodeId>,
mapping: M,
trials: Option<usize>,
seed: Option<u64>,
parallel_threshold: Option<usize>,
Expand All @@ -94,12 +113,10 @@ where
trials: trials.unwrap_or(4),
seed,
parallel_threshold: parallel_threshold.unwrap_or(50),
node_map: HashMap::with_capacity(graph.node_count()),
rev_node_map: HashMap::with_capacity(graph.node_count()),
}
}

fn map(&mut self) -> Result<Vec<Swap>, MapNotPossible> {
fn map(&self) -> SwapResult<G::NodeId> {
let num_nodes = self.graph.node_bound();
let num_edges = self.graph.edge_count();

Expand All @@ -124,21 +141,17 @@ where
count += 1;
}

// Create maps between NodeId and NodeIndex
for node in self.graph.node_identifiers() {
self.node_map
.insert(node, NodeIndex::new(self.graph.to_index(node)));
self.rev_node_map
.insert(NodeIndex::new(self.graph.to_index(node)), node);
}
// sub will become same as digraph but with no self edges in add_token_edges
let mut sub_digraph = digraph.clone();

// The mapping in HashMap form using NodeIndex
let mut tokens: HashMap<NodeIndex, NodeIndex> = self
.mapping
.iter()
.map(|(k, v)| (self.node_map[k], self.node_map[v]))
.map(|(k, v)| {
let k_idx = NodeIndex::new(self.graph.to_index(*k));
let v_idx = NodeIndex::new(self.graph.to_index(*v));
(k_idx, v_idx)
})
.collect();

// todo_nodes are all the mapping entries where left != right
Expand All @@ -150,12 +163,8 @@ where

// Add initial edges to the digraph/sub_digraph
for node in self.graph.node_identifiers() {
self.add_token_edges(
self.node_map[&node],
&mut digraph,
&mut sub_digraph,
&mut tokens,
)?;
let node_idx = NodeIndex::new(self.graph.to_index(node));
self.add_token_edges(node_idx, &mut digraph, &mut sub_digraph, &mut tokens)?;
}
// First collect the self.trial number of random numbers
// into a Vec based on the given seed
Expand Down Expand Up @@ -205,15 +214,15 @@ where
digraph.update_edge(node, node, ());
return Ok(());
}
let id_node = self.rev_node_map[&node];
let id_token = self.rev_node_map[&tokens[&node]];
let id_node = self.graph.from_index(node.index());
let id_token = self.graph.from_index(tokens[&node].index());

if self.graph.neighbors(id_node).next().is_none() {
return Err(MapNotPossible {});
}

for id_neighbor in self.graph.neighbors(id_node) {
let neighbor = self.node_map[&id_neighbor];
let neighbor = NodeIndex::new(self.graph.to_index(id_neighbor));
let dist_neighbor: DictMap<G::NodeId, usize> = dijkstra(
&self.graph,
id_neighbor,
Expand Down Expand Up @@ -255,7 +264,7 @@ where
mut tokens: HashMap<NodeIndex, NodeIndex>,
mut todo_nodes: Vec<NodeIndex>,
trial_seed: u64,
) -> Result<Vec<Swap>, MapNotPossible> {
) -> SwapResult<G::NodeId> {
// Create a random trial list of swaps to move tokens to optimal positions
let mut steps = 0;
let mut swap_edges: Vec<Swap> = vec![];
Expand Down Expand Up @@ -338,7 +347,15 @@ where
todo_nodes.is_empty(),
"The output final swap map is incomplete, this points to a bug in rustworkx, please open an issue."
);
Ok(swap_edges)
let result: Vec<(G::NodeId, G::NodeId)> = swap_edges
.into_iter()
.map(|(ni1, ni2)| {
let id1 = self.graph.from_index(ni1.index());
let id2 = self.graph.from_index(ni2.index());
(id1, id2)
})
.collect();
Ok(result)
}

fn swap(
Expand Down Expand Up @@ -412,7 +429,8 @@ where
/// trigger the use of parallel threads. If the number of nodes in the graph is less than this value
/// it will run in a single thread. The default value is 50.
///
/// It returns a list of tuples representing the swaps to perform. The result will be an
/// It returns a list of tuples representing the swaps to perform, where each tuple contains
/// node identifiers of type `(G::NodeId, G::NodeId)`. The result will be an
/// `Err(MapNotPossible)` if the `token_swapper()` function can't find a mapping.
///
/// This function is multithreaded and will launch a thread pool with threads equal to
Expand Down Expand Up @@ -445,7 +463,7 @@ pub fn token_swapper<G>(
trials: Option<usize>,
seed: Option<u64>,
parallel_threshold: Option<usize>,
) -> Result<Vec<Swap>, MapNotPossible>
) -> SwapResult<G::NodeId>
where
G: NodeCount
+ EdgeCount
Expand All @@ -458,7 +476,7 @@ where
+ Sync,
G::NodeId: Hash + Eq + Send + Sync,
{
let mut swapper = TokenSwapper::new(graph, mapping, trials, seed, parallel_threshold);
let swapper = TokenSwapper::new(graph, mapping, trials, seed, parallel_threshold);
swapper.map()
}

Expand Down
7 changes: 4 additions & 3 deletions src/token_swapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ use rustworkx_core::token_swapper;
/// The inputs are a partial ``mapping`` to be implemented in swaps, and the number of ``trials``
/// to perform the mapping. It's minimized over the trials.
///
/// It returns a list of tuples representing the swaps to perform.
/// It returns a list of tuples representing the swaps to perform, where each tuple contains
/// the node identifiers (integers) of the nodes to swap.
///
/// :param PyGraph graph: The input graph
/// :param dict[int: int] mapping: Map of (node, token)
Expand All @@ -44,8 +45,8 @@ use rustworkx_core::token_swapper;
/// the ``RAYON_NUM_THREADS`` environment variable. For example, setting ``RAYON_NUM_THREADS=4``
/// would limit the thread pool to 4 threads.
///
/// :returns: A list of tuples which are the swaps to be applied to the mapping to rearrange
/// the tokens.
/// :returns: A list of tuples containing the node identifiers (integers) of the swaps to be
/// applied to the mapping to rearrange the tokens.
/// :rtype: EdgeList
#[pyfunction]
#[pyo3(
Expand Down