diff --git a/rustworkx-core/src/graph_ext/mod.rs b/rustworkx-core/src/graph_ext/mod.rs index 256d6ac0a5..00bd5f1565 100644 --- a/rustworkx-core/src/graph_ext/mod.rs +++ b/rustworkx-core/src/graph_ext/mod.rs @@ -71,12 +71,14 @@ use petgraph::{EdgeType, Graph}; pub mod contraction; pub mod multigraph; +pub mod substitution; pub use contraction::{ ContractNodesDirected, ContractNodesSimpleDirected, ContractNodesSimpleUndirected, ContractNodesUndirected, }; pub use multigraph::{HasParallelEdgesDirected, HasParallelEdgesUndirected}; +pub use substitution::SubstituteNodeWithGraph; /// A graph whose nodes may be removed. pub trait NodeRemovable: Data { diff --git a/rustworkx-core/src/graph_ext/substitution.rs b/rustworkx-core/src/graph_ext/substitution.rs new file mode 100644 index 0000000000..987de68fc0 --- /dev/null +++ b/rustworkx-core/src/graph_ext/substitution.rs @@ -0,0 +1,315 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! This module defines graph traits for node substitution. + +use crate::dictmap::{DictMap, InitWithHasher}; +use petgraph::data::DataMap; +use petgraph::stable_graph; +use petgraph::visit::{ + Data, EdgeRef, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeRef, +}; +use petgraph::{Directed, Direction}; +use std::convert::Infallible; +use std::error::Error; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; + +#[derive(Debug)] +pub enum SubstituteNodeWithGraphError { + ReplacementGraphIndexError(N), + CallbackError(E), +} + +impl Display for SubstituteNodeWithGraphError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SubstituteNodeWithGraphError::ReplacementGraphIndexError(n) => { + write!(f, "Node {:?} was not found in the replacement graph.", n) + } + SubstituteNodeWithGraphError::CallbackError(e) => { + write!(f, "Callback failed with: {}", e) + } + } + } +} + +impl Error for SubstituteNodeWithGraphError {} + +pub type SubstitutionResult = Result>; + +pub struct NoCallback; + +pub trait NodeFilter { + type CallbackError; + fn filter( + &mut self, + graph: &G, + node: G::NodeId, + ) -> SubstitutionResult; +} + +impl NodeFilter for NoCallback { + type CallbackError = Infallible; + #[inline] + fn filter( + &mut self, + _graph: &G, + _node: G::NodeId, + ) -> SubstitutionResult { + Ok(true) + } +} + +impl NodeFilter for F +where + G: GraphBase + DataMap, + F: FnMut(&G::NodeWeight) -> Result, +{ + type CallbackError = E; + #[inline] + fn filter( + &mut self, + graph: &G, + node: G::NodeId, + ) -> SubstitutionResult { + if let Some(x) = graph.node_weight(node) { + self(x).map_err(|e| SubstituteNodeWithGraphError::CallbackError(e)) + } else { + Ok(false) + } + } +} + +pub trait EdgeWeightMapper { + type CallbackError; + type MappedWeight; + fn map( + &mut self, + graph: &G, + edge: G::EdgeId, + ) -> SubstitutionResult; +} + +impl EdgeWeightMapper for NoCallback +where + G::EdgeWeight: Clone, +{ + type CallbackError = Infallible; + type MappedWeight = G::EdgeWeight; + #[inline] + fn map( + &mut self, + graph: &G, + edge: G::EdgeId, + ) -> SubstitutionResult { + Ok(graph.edge_weight(edge).unwrap().clone()) + } +} + +impl EdgeWeightMapper for F +where + G: GraphBase + DataMap, + F: FnMut(&G::EdgeWeight) -> Result, +{ + type CallbackError = E; + type MappedWeight = EW; + + #[inline] + fn map( + &mut self, + graph: &G, + edge: G::EdgeId, + ) -> SubstitutionResult { + if let Some(x) = graph.edge_weight(edge) { + self(x).map_err(|e| SubstituteNodeWithGraphError::CallbackError(e)) + } else { + panic!("Edge MUST exist in graph.") + } + } +} + +pub trait SubstituteNodeWithGraph: DataMap { + /// Substitute a node with a Graph. + /// + /// The nodes and edges of Graph `other` are cloned into this + /// graph and connected to its preexisting nodes using an edge mapping + /// function, `edge_map_fn`. + /// + /// The specified `edge_map_fn` is called for each of the edges between + /// the `node` being replaced and the rest of the graph and is expected + /// to return an index in `other` that the edge should be connected + /// to after the replacement, i.e. the node in `graph` that the edge + /// should be connected to once `node` is gone. It is also acceptable + /// for `edge_map_fn` to return `None`, in which case the edge is + /// ignored and will be dropped. + /// + /// It accepts the following three arguments: + /// - The [Direction], which designates whether the original edge was + /// incoming or outgoing to `node`. + /// - The [Self::NodeId] of the _other_ node of the original edge (i.e. the + /// one that isn't `node`). + /// - A reference to the edge weight of the original edge. + /// + /// An optional `node_filter` can be provided to ignore nodes in `other` that + /// should not be copied into this graph. This parameter accepts implementations + /// of the trait [NodeFilter], which has a blanket implementation for callables + /// which are `FnMut(&G1::NodeWeight) -> Result`, i.e. functions which + /// take a reference to a node weight in `other` and return a boolean to indicate + /// if the node corresponding to this weight should be included or not. To disable + /// filtering, simply provide [NoCallback]. + /// + /// A _sometimes_ optional `edge_weight_map` can be provided to transform edge weights from + /// the source graph `other` into weights of this graph. This parameter accepts + /// implementations of the trait [EdgeWeightMapper], which has a blanket + /// implementation for callables which are + /// `F: FnMut(&G1::EdgeWeight) -> Result`, + /// i.e. functions which take a reference to an edge weight in `graph` and return + /// an owned weight typed for this graph. An `edge_weight_map` must be provided + /// when `other` uses a different type for its edge weights, but can otherwise + /// be specified as [NoCallback] to disable mapping. + /// + /// This method returns a mapping of nodes in `other` to the copied node in + /// this graph. + #[allow(clippy::type_complexity)] + fn substitute_node_with_graph( + &mut self, + node: Self::NodeId, + other: &G, + edge_map_fn: EM, + node_filter: NF, + edge_weight_map: ET, + ) -> SubstitutionResult, G::NodeId, E> + where + G: Data + DataMap + NodeCount, + G::NodeId: Debug + Hash + Eq, + G::NodeWeight: Clone, + for<'a> &'a G: GraphBase + + Data + + IntoNodeReferences + + IntoEdgeReferences, + EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, E>, + NF: NodeFilter, + ET: EdgeWeightMapper; +} + +impl SubstituteNodeWithGraph for stable_graph::StableGraph +where + Ix: stable_graph::IndexType, + E: Clone, +{ + fn substitute_node_with_graph( + &mut self, + node: Self::NodeId, + other: &G, + mut edge_map_fn: EM, + mut node_filter: NF, + mut edge_weight_map: ET, + ) -> SubstitutionResult, G::NodeId, ER> + where + G: Data + DataMap + NodeCount, + G::NodeId: Debug + Hash + Eq, + G::NodeWeight: Clone, + for<'a> &'a G: GraphBase + + Data + + IntoNodeReferences + + IntoEdgeReferences, + EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, ER>, + NF: NodeFilter, + ET: EdgeWeightMapper, + { + let node_index = node; + if self.node_weight(node_index).is_none() { + panic!("Node `node` MUST be present in graph."); + } + // Copy nodes from other to self + let mut out_map: DictMap = + DictMap::with_capacity(other.node_count()); + for node in other.node_references() { + if !node_filter.filter(other, node.id())? { + continue; + } + let new_index = self.add_node(node.weight().clone()); + out_map.insert(node.id(), new_index); + } + // If no nodes are copied bail here since there is nothing left + // to do. + if out_map.is_empty() { + self.remove_node(node_index); + // Return a new empty map to clear allocation from out_map + return Ok(DictMap::new()); + } + // Copy edges from other to self + for edge in other.edge_references().filter(|edge| { + out_map.contains_key(&edge.target()) && out_map.contains_key(&edge.source()) + }) { + self.add_edge( + out_map[&edge.source()], + out_map[&edge.target()], + edge_weight_map.map(other, edge.id())?, + ); + } + // Add edges to/from node to nodes in other + let in_edges: Vec> = self + .edges_directed(node_index, petgraph::Direction::Incoming) + .map(|edge| { + let Some(target_in_other) = + edge_map_fn(Direction::Incoming, edge.source(), edge.weight()) + .map_err(|e| SubstituteNodeWithGraphError::CallbackError(e))? + else { + return Ok(None); + }; + let Some(target_in_self) = out_map.get(&target_in_other) else { + return Err(SubstituteNodeWithGraphError::ReplacementGraphIndexError( + target_in_other, + )); + }; + Ok(Some(( + edge.source(), + *target_in_self, + edge.weight().clone(), + ))) + }) + .collect::>()?; + let out_edges: Vec> = self + .edges_directed(node_index, petgraph::Direction::Outgoing) + .map(|edge| { + let Some(source_in_other) = + edge_map_fn(Direction::Outgoing, edge.target(), edge.weight()) + .map_err(|e| SubstituteNodeWithGraphError::CallbackError(e))? + else { + return Ok(None); + }; + let Some(source_in_self) = out_map.get(&source_in_other) else { + return Err(SubstituteNodeWithGraphError::ReplacementGraphIndexError( + source_in_other, + )); + }; + Ok(Some(( + *source_in_self, + edge.target(), + edge.weight().clone(), + ))) + }) + .collect::>()?; + for (source, target, weight) in in_edges + .into_iter() + .flatten() + .chain(out_edges.into_iter().flatten()) + { + self.add_edge(source, target, weight); + } + // Remove node + self.remove_node(node_index); + Ok(out_map) + } +} diff --git a/src/digraph.rs b/src/digraph.rs index b15b3dea06..88e771e2a1 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2550,6 +2550,10 @@ impl PyDiGraph { /// when iterated over (although the same object will have a consistent /// order when iterated over multiple times). /// + /// If the replacement graph ``other`` contains cycles or is not a + /// multigraph, then this graph will also contain cycles or become + /// a multigraph after the substitution. + /// #[pyo3( text_signature = "(self, node, other, edge_map_fn, /, node_filter=None, edge_weight_map=None)" )] @@ -2561,109 +2565,61 @@ impl PyDiGraph { edge_map_fn: PyObject, node_filter: Option, edge_weight_map: Option, - ) -> PyResult { - let weight_map_fn = |obj: &PyObject, weight_fn: &Option| -> PyResult { - match weight_fn { - Some(weight_fn) => weight_fn.call1(py, (obj,)), - None => Ok(obj.clone_ref(py)), - } - }; - let map_fn = |source: usize, target: usize, weight: &PyObject| -> PyResult> { - let res = edge_map_fn.call1(py, (source, target, weight))?; - res.extract(py) + ) -> RxPyResult { + let node_index: NodeIndex = NodeIndex::new(node); + if self.graph.node_weight(node_index).is_none() { + return Err(PyIndexError::new_err(format!( + "Specified node {} is not in this graph", + node + )) + .into()); + } + + let edge_map_fn = |direction: Direction, + node: NodeIndex, + weight: &PyObject| + -> PyResult> { + let edge = match direction { + Direction::Incoming => (node.index(), node_index.index(), weight), + Direction::Outgoing => (node_index.index(), node.index(), weight), + }; + let res = edge_map_fn.call1(py, edge)?; + let index: Option = res.extract(py)?; + Ok(index.map(|i| NodeIndex::new(i))) }; - let filter_fn = |obj: &PyObject, filter_fn: &Option| -> PyResult { - match filter_fn { - Some(filter) => { + + let node_filter = move |obj: &PyObject| -> PyResult { + match node_filter { + Some(ref filter) => { let res = filter.call1(py, (obj,))?; res.extract(py) } None => Ok(true), } }; - let node_index: NodeIndex = NodeIndex::new(node); - if self.graph.node_weight(node_index).is_none() { - return Err(PyIndexError::new_err(format!( - "Specified node {} is not in this graph", - node - ))); - } - // Copy nodes from other to self - let mut out_map: DictMap = DictMap::with_capacity(other.node_count()); - for node in other.graph.node_indices() { - let node_weight = other.graph[node].clone_ref(py); - if !filter_fn(&node_weight, &node_filter)? { - continue; + + let weight_map_fn = move |obj: &PyObject| -> PyResult { + match edge_weight_map { + Some(ref weight_fn) => weight_fn.call1(py, (obj,)), + None => Ok(obj.clone_ref(py)), } - let new_index = self.graph.add_node(node_weight); - out_map.insert(node.index(), new_index.index()); - } - // If no nodes are copied bail here since there is nothing left - // to do. - if out_map.is_empty() { - self.remove_node(node_index.index())?; - // Return a new empty map to clear allocation from out_map - return Ok(NodeMap { - node_map: DictMap::new(), - }); - } - // Copy edges from other to self - for edge in other.graph.edge_references().filter(|edge| { - out_map.contains_key(&edge.target().index()) - && out_map.contains_key(&edge.source().index()) - }) { - self._add_edge( - NodeIndex::new(out_map[&edge.source().index()]), - NodeIndex::new(out_map[&edge.target().index()]), - weight_map_fn(edge.weight(), &edge_weight_map)?, - )?; - } - // Add edges to/from node to nodes in other - let in_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self - .graph - .edges_directed(node_index, petgraph::Direction::Incoming) - .map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py))) - .collect(); - let out_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self - .graph - .edges_directed(node_index, petgraph::Direction::Outgoing) - .map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py))) - .collect(); - for (source, target, weight) in in_edges { - let old_index = map_fn(source.index(), target.index(), &weight)?; - let target_out = match old_index { - Some(old_index) => match out_map.get(&old_index) { - Some(new_index) => NodeIndex::new(*new_index), - None => { - return Err(PyIndexError::new_err(format!( - "No mapped index {} found", - old_index - ))) - } - }, - None => continue, - }; - self._add_edge(source, target_out, weight)?; - } - for (source, target, weight) in out_edges { - let old_index = map_fn(source.index(), target.index(), &weight)?; - let source_out = match old_index { - Some(old_index) => match out_map.get(&old_index) { - Some(new_index) => NodeIndex::new(*new_index), - None => { - return Err(PyIndexError::new_err(format!( - "No mapped index {} found", - old_index - ))) - } - }, - None => continue, - }; - self._add_edge(source_out, target, weight)?; - } - // Remove node - self.remove_node(node_index.index())?; - Ok(NodeMap { node_map: out_map }) + }; + + let out_map = self.graph.substitute_node_with_graph( + node_index, + &other.graph, + edge_map_fn, + node_filter, + weight_map_fn, + )?; + + self.node_removed = true; + Ok(NodeMap { + node_map: out_map + .into_iter() + .map(|(k, v)| (k.index(), v.index())) + .collect(), + }) } /// Substitute a set of nodes with a single new node. diff --git a/src/lib.rs b/src/lib.rs index 79f183462f..8866523a7b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,8 +70,8 @@ use hashbrown::HashMap; use numpy::Complex64; use pyo3::create_exception; -use pyo3::exceptions::PyException; use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyException, PyIndexError}; use pyo3::import_exception; use pyo3::prelude::*; use pyo3::wrap_pyfunction; @@ -88,9 +88,11 @@ use petgraph::EdgeType; use rustworkx_core::dag_algo::TopologicalSortError; use std::convert::TryFrom; +use std::fmt::Debug; use rustworkx_core::dictmap::*; use rustworkx_core::err::{ContractError, ContractSimpleError}; +use rustworkx_core::graph_ext::substitution::SubstituteNodeWithGraphError; /// An ergonomic error type used to map Rustworkx core errors to /// [PyErr] automatically, via [From::from]. @@ -144,6 +146,19 @@ impl From> for RxPyErr { } } +impl From> for RxPyErr { + fn from(value: SubstituteNodeWithGraphError) -> Self { + RxPyErr { + pyerr: match value { + SubstituteNodeWithGraphError::CallbackError(e) => e, + SubstituteNodeWithGraphError::ReplacementGraphIndexError(_) => { + PyIndexError::new_err(format!("{}", value)) + } + }, + } + } +} + impl From> for RxPyErr { fn from(value: TopologicalSortError) -> Self { RxPyErr {