Skip to content

Commit 3d6a01d

Browse files
JoOkumaSILIZ4
authored andcommitted
New method subgraph_with_nodemap (Qiskit#1461)
* adding initial subgraph_with_nodemap implementation * adding digraph with nodemap * fix style * fixing return type annotation * adding subgraph reference sphinx note * add release note * trying to fix release notes
1 parent de87d83 commit 3d6a01d

File tree

6 files changed

+241
-16
lines changed

6 files changed

+241
-16
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
features:
3+
- |
4+
Added new :meth:`~.PyGraph.subgraph_with_nodemap` and :meth:`~.PyDiGraph.subgraph_with_nodemap`
5+
methods to the :class:`~.PyGraph` and :class:`~.PyDiGraph` classes. These methods extend the
6+
existing :meth:`~.PyGraph.subgraph` method by returning a :class:`~.NodeMap` object that maps
7+
the nodes of the subgraph to the nodes of the original graph.

rustworkx/rustworkx.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,6 +1320,9 @@ class PyGraph(Generic[_S, _T]):
13201320
def remove_node(self, node: int, /) -> None: ...
13211321
def remove_nodes_from(self, index_list: Iterable[int], /) -> None: ...
13221322
def subgraph(self, nodes: Sequence[int], /, preserve_attrs: bool = ...) -> PyGraph[_S, _T]: ...
1323+
def subgraph_with_nodemap(
1324+
self, nodes: Sequence[int], /, preserve_attrs: bool = ...
1325+
) -> tuple[PyGraph[_S, _T], NodeMap]: ...
13231326
def substitute_node_with_subgraph(
13241327
self,
13251328
node: int,
@@ -1528,6 +1531,9 @@ class PyDiGraph(Generic[_S, _T]):
15281531
def subgraph(
15291532
self, nodes: Sequence[int], /, preserve_attrs: bool = ...
15301533
) -> PyDiGraph[_S, _T]: ...
1534+
def subgraph_with_nodemap(
1535+
self, nodes: Sequence[int], /, preserve_attrs: bool = ...
1536+
) -> tuple[PyDiGraph[_S, _T], NodeMap]: ...
15311537
def substitute_node_with_subgraph(
15321538
self,
15331539
node: int,

src/digraph.rs

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3013,7 +3013,13 @@ impl PyDiGraph {
30133013
Ok(res.index())
30143014
}
30153015

3016-
/// Return a new PyDiGraph object for a subgraph of this graph
3016+
/// Return a new PyDiGraph object for a subgraph of this graph and a NodeMap
3017+
/// object that maps the nodes of the subgraph to the nodes of the original graph.
3018+
///
3019+
/// .. note::
3020+
/// This method is identical to :meth:`.subgraph()` but includes a
3021+
/// NodeMap object that maps the nodes of the subgraph to the nodes of
3022+
/// the original graph.
30173023
///
30183024
/// :param list[int] nodes: A list of node indices to generate the subgraph
30193025
/// from. If a node index is included that is not present in the graph
@@ -3023,24 +3029,33 @@ impl PyDiGraph {
30233029
/// subgraph. By default this is set to ``False`` and the :attr:`~.PyDiGraph.attrs`
30243030
/// attribute will be ``None`` in the subgraph.
30253031
///
3026-
/// :returns: A new PyDiGraph object representing a subgraph of this graph.
3032+
/// :returns: A tuple containing a new PyDiGraph object representing a subgraph of this graph
3033+
/// and a NodeMap object that maps the nodes of the subgraph to the nodes of the original graph.
30273034
/// It is worth noting that node and edge weight/data payloads are
30283035
/// passed by reference so if you update (not replace) an object used
30293036
/// as the weight in graph or the subgraph it will also be updated in
3030-
/// the other. Node and edge the indices will be recreated for the subgraph for compactness.
3031-
/// Therefore, do not access data using the original graph's indices.
3032-
/// :rtype: PyGraph
3037+
/// the other.
3038+
/// :rtype: tuple[PyDiGraph, NodeMap]
30333039
///
3034-
#[pyo3(signature=(nodes, preserve_attrs=false),text_signature = "(self, nodes, /, preserve_attrs=False)")]
3035-
pub fn subgraph(&self, py: Python, nodes: Vec<usize>, preserve_attrs: bool) -> PyDiGraph {
3040+
#[pyo3(signature=(nodes, preserve_attrs=false), text_signature = "(self, nodes, /, preserve_attrs=False)")]
3041+
pub fn subgraph_with_nodemap(
3042+
&self,
3043+
py: Python,
3044+
nodes: Vec<usize>,
3045+
preserve_attrs: bool,
3046+
) -> (PyDiGraph, NodeMap) {
30363047
let node_set: HashSet<usize> = nodes.iter().cloned().collect();
3048+
// mapping from original node index to new node index
30373049
let mut node_map: HashMap<NodeIndex, NodeIndex> = HashMap::with_capacity(nodes.len());
3050+
// mapping from new node index to original node index
3051+
let mut node_dict: DictMap<usize, usize> = DictMap::with_capacity(nodes.len());
30383052
let node_filter = |node: NodeIndex| -> bool { node_set.contains(&node.index()) };
30393053
let mut out_graph = StablePyGraph::<Directed>::new();
30403054
let filtered = NodeFiltered(&self.graph, node_filter);
30413055
for node in filtered.node_references() {
30423056
let new_node = out_graph.add_node(node.1.clone_ref(py));
30433057
node_map.insert(node.0, new_node);
3058+
node_dict.insert(new_node.index(), node.0.index());
30443059
}
30453060
for edge in filtered.edge_references() {
30463061
let new_source = *node_map.get(&edge.source()).unwrap();
@@ -3052,14 +3067,45 @@ impl PyDiGraph {
30523067
} else {
30533068
py.None()
30543069
};
3055-
PyDiGraph {
3070+
let node_map = NodeMap {
3071+
node_map: node_dict,
3072+
};
3073+
let subgraph = PyDiGraph {
30563074
graph: out_graph,
30573075
node_removed: false,
30583076
cycle_state: algo::DfsSpace::default(),
30593077
check_cycle: self.check_cycle,
30603078
multigraph: self.multigraph,
30613079
attrs,
3062-
}
3080+
};
3081+
(subgraph, node_map)
3082+
}
3083+
3084+
/// Return a new PyDiGraph object for a subgraph of this graph.
3085+
///
3086+
/// .. note::
3087+
/// To return a NodeMap object that maps the nodes of the subgraph to
3088+
/// the nodes of the original graph, use :meth:`.subgraph_with_nodemap()`.
3089+
///
3090+
/// :param list[int] nodes: A list of node indices to generate the subgraph
3091+
/// from. If a node index is included that is not present in the graph
3092+
/// it will silently be ignored.
3093+
/// :param bool preserve_attrs: If set to the True the attributes of the PyDiGraph
3094+
/// will be copied by reference to be the attributes of the output
3095+
/// subgraph. By default this is set to False and the :attr:`~.PyDiGraph.attrs`
3096+
/// attribute will be ``None`` in the subgraph.
3097+
///
3098+
/// :returns: A new PyDiGraph object representing a subgraph of this graph.
3099+
/// It is worth noting that node and edge weight/data payloads are
3100+
/// passed by reference so if you update (not replace) an object used
3101+
/// as the weight in graph or the subgraph it will also be updated in
3102+
/// the other.
3103+
/// :rtype: PyDiGraph
3104+
///
3105+
#[pyo3(signature=(nodes, preserve_attrs=false), text_signature = "(self, nodes, /, preserve_attrs=False)")]
3106+
pub fn subgraph(&self, py: Python, nodes: Vec<usize>, preserve_attrs: bool) -> PyDiGraph {
3107+
let (subgraph, _) = self.subgraph_with_nodemap(py, nodes, preserve_attrs);
3108+
subgraph
30633109
}
30643110

30653111
/// Return a new PyDiGraph object for an edge induced subgraph of this graph

src/graph.rs

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,7 +1907,13 @@ impl PyGraph {
19071907
Ok(res.index())
19081908
}
19091909

1910-
/// Return a new PyGraph object for a subgraph of this graph
1910+
/// Return a new PyGraph object for a subgraph of this graph and a NodeMap
1911+
/// object that maps the nodes of the subgraph to the nodes of the original graph.
1912+
///
1913+
/// .. note::
1914+
/// This method is identical to :meth:`.subgraph()` but includes a
1915+
/// NodeMap object that maps the nodes of the subgraph to the nodes of
1916+
/// the original graph.
19111917
///
19121918
/// :param list[int] nodes: A list of node indices to generate the subgraph
19131919
/// from. If a node index is included that is not present in the graph
@@ -1917,23 +1923,33 @@ impl PyGraph {
19171923
/// subgraph. By default this is set to False and the :attr:`~.PyGraph.attrs`
19181924
/// attribute will be ``None`` in the subgraph.
19191925
///
1920-
/// :returns: A new PyGraph object representing a subgraph of this graph.
1926+
/// :returns: A tuple containing a new PyGraph object representing a subgraph of this graph
1927+
/// and a NodeMap object that maps the nodes of the subgraph to the nodes of the original graph.
19211928
/// It is worth noting that node and edge weight/data payloads are
19221929
/// passed by reference so if you update (not replace) an object used
19231930
/// as the weight in graph or the subgraph it will also be updated in
19241931
/// the other.
1925-
/// :rtype: PyGraph
1932+
/// :rtype: tuple[PyGraph, NodeMap]
19261933
///
19271934
#[pyo3(signature=(nodes, preserve_attrs=false), text_signature = "(self, nodes, /, preserve_attrs=False)")]
1928-
pub fn subgraph(&self, py: Python, nodes: Vec<usize>, preserve_attrs: bool) -> PyGraph {
1935+
pub fn subgraph_with_nodemap(
1936+
&self,
1937+
py: Python,
1938+
nodes: Vec<usize>,
1939+
preserve_attrs: bool,
1940+
) -> (PyGraph, NodeMap) {
19291941
let node_set: HashSet<usize> = nodes.iter().cloned().collect();
1942+
// mapping from original node index to new node index
19301943
let mut node_map: HashMap<NodeIndex, NodeIndex> = HashMap::with_capacity(nodes.len());
1944+
// mapping from new node index to original node index
1945+
let mut node_dict: DictMap<usize, usize> = DictMap::with_capacity(nodes.len());
19311946
let node_filter = |node: NodeIndex| -> bool { node_set.contains(&node.index()) };
19321947
let mut out_graph = StablePyGraph::<Undirected>::default();
19331948
let filtered = NodeFiltered(&self.graph, node_filter);
19341949
for node in filtered.node_references() {
19351950
let new_node = out_graph.add_node(node.1.clone_ref(py));
19361951
node_map.insert(node.0, new_node);
1952+
node_dict.insert(new_node.index(), node.0.index());
19371953
}
19381954
for edge in filtered.edge_references() {
19391955
let new_source = *node_map.get(&edge.source()).unwrap();
@@ -1945,12 +1961,43 @@ impl PyGraph {
19451961
} else {
19461962
py.None()
19471963
};
1948-
PyGraph {
1964+
let node_map = NodeMap {
1965+
node_map: node_dict,
1966+
};
1967+
let subgraph = PyGraph {
19491968
graph: out_graph,
19501969
node_removed: false,
19511970
multigraph: self.multigraph,
19521971
attrs,
1953-
}
1972+
};
1973+
(subgraph, node_map)
1974+
}
1975+
1976+
/// Return a new PyGraph object for a subgraph of this graph.
1977+
///
1978+
/// .. note::
1979+
/// To return a NodeMap object that maps the nodes of the subgraph to
1980+
/// the nodes of the original graph, use :meth:`.subgraph_with_nodemap()`.
1981+
///
1982+
/// :param list[int] nodes: A list of node indices to generate the subgraph
1983+
/// from. If a node index is included that is not present in the graph
1984+
/// it will silently be ignored.
1985+
/// :param bool preserve_attrs: If set to the True the attributes of the PyGraph
1986+
/// will be copied by reference to be the attributes of the output
1987+
/// subgraph. By default this is set to False and the :attr:`~.PyGraph.attrs`
1988+
/// attribute will be ``None`` in the subgraph.
1989+
///
1990+
/// :returns: A new PyGraph object representing a subgraph of this graph.
1991+
/// It is worth noting that node and edge weight/data payloads are
1992+
/// passed by reference so if you update (not replace) an object used
1993+
/// as the weight in graph or the subgraph it will also be updated in
1994+
/// the other.
1995+
/// :rtype: PyGraph
1996+
///
1997+
#[pyo3(signature=(nodes, preserve_attrs=false), text_signature = "(self, nodes, /, preserve_attrs=False)")]
1998+
pub fn subgraph(&self, py: Python, nodes: Vec<usize>, preserve_attrs: bool) -> PyGraph {
1999+
let (subgraph, _) = self.subgraph_with_nodemap(py, nodes, preserve_attrs);
2000+
subgraph
19542001
}
19552002

19562003
/// Return a new PyGraph object for an edge induced subgraph of this graph

tests/digraph/test_subgraph.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_edge_subgraph_non_edge(self):
138138
self.assertEqual([(0, 1, 2), (0, 1, 3), (1, 2, 4)], subgraph.weighted_edge_list())
139139

140140
def test_preserve_attrs(self):
141-
graph = rustworkx.PyGraph(attrs="My attribute")
141+
graph = rustworkx.PyDiGraph(attrs="My attribute")
142142
graph.add_node("a")
143143
graph.add_node("b")
144144
graph.add_node("c")
@@ -148,3 +148,62 @@ def test_preserve_attrs(self):
148148
self.assertEqual([(0, 1, 4)], subgraph.weighted_edge_list())
149149
self.assertEqual(["b", "d"], subgraph.nodes())
150150
self.assertEqual(graph.attrs, subgraph.attrs)
151+
152+
def test_subgraph_with_nodemap(self):
153+
graph = rustworkx.PyDiGraph()
154+
graph.add_nodes_from(list(range(6)))
155+
graph.add_edges_from([(0, 1, 1), (1, 2, 2), (2, 3, 3), (3, 4, 4), (4, 5, 5)])
156+
157+
# Test basic subgraph with node mapping
158+
subgraph, node_map = graph.subgraph_with_nodemap([0, 2, 4])
159+
self.assertEqual([], subgraph.weighted_edge_list()) # No edges between disconnected nodes
160+
self.assertEqual([0, 2, 4], subgraph.nodes())
161+
self.assertEqual(dict(node_map), {0: 0, 1: 2, 2: 4})
162+
163+
# Test with connected nodes
164+
subgraph2, node_map2 = graph.subgraph_with_nodemap([1, 2, 3])
165+
self.assertEqual([(0, 1, 2), (1, 2, 3)], subgraph2.weighted_edge_list())
166+
self.assertEqual([1, 2, 3], subgraph2.nodes())
167+
self.assertEqual(dict(node_map2), {0: 1, 1: 2, 2: 3})
168+
169+
def test_subgraph_with_nodemap_edge_cases(self):
170+
graph = rustworkx.PyDiGraph()
171+
graph.add_nodes_from(["a", "b", "c"])
172+
graph.add_edges_from([(0, 1, 1), (1, 2, 2)])
173+
174+
# Test empty node list
175+
subgraph, node_map = graph.subgraph_with_nodemap([])
176+
self.assertEqual([], subgraph.weighted_edge_list())
177+
self.assertEqual(0, len(subgraph))
178+
self.assertEqual(dict(node_map), {})
179+
180+
# Test invalid node indices (should be silently ignored)
181+
subgraph, node_map = graph.subgraph_with_nodemap([42, 100])
182+
self.assertEqual([], subgraph.weighted_edge_list())
183+
self.assertEqual(0, len(subgraph))
184+
self.assertEqual(dict(node_map), {})
185+
186+
# Test single node (no edges in subgraph)
187+
subgraph, node_map = graph.subgraph_with_nodemap([1])
188+
self.assertEqual([], subgraph.weighted_edge_list())
189+
self.assertEqual(["b"], subgraph.nodes())
190+
self.assertEqual(dict(node_map), {0: 1})
191+
192+
# Test all nodes
193+
subgraph, node_map = graph.subgraph_with_nodemap([0, 1, 2])
194+
self.assertEqual([(0, 1, 1), (1, 2, 2)], subgraph.weighted_edge_list())
195+
self.assertEqual(["a", "b", "c"], subgraph.nodes())
196+
self.assertEqual(dict(node_map), {0: 0, 1: 1, 2: 2})
197+
198+
def test_subgraph_with_nodemap_preserve_attrs(self):
199+
graph = rustworkx.PyDiGraph(attrs="test_attrs")
200+
graph.add_nodes_from(["a", "b", "c"])
201+
graph.add_edges_from([(0, 1, 1), (1, 2, 2)])
202+
203+
# Test preserve_attrs=False (default)
204+
subgraph, node_map = graph.subgraph_with_nodemap([0, 1])
205+
self.assertIsNone(subgraph.attrs)
206+
207+
# Test preserve_attrs=True
208+
subgraph2, node_map2 = graph.subgraph_with_nodemap([0, 1], preserve_attrs=True)
209+
self.assertEqual(graph.attrs, subgraph2.attrs)

tests/graph/test_subgraph.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,63 @@ def test_preserve_attrs(self):
148148
self.assertEqual([(0, 1, 4)], subgraph.weighted_edge_list())
149149
self.assertEqual(["b", "d"], subgraph.nodes())
150150
self.assertEqual(graph.attrs, subgraph.attrs)
151+
152+
def test_subgraph_with_nodemap(self):
153+
graph = rustworkx.PyGraph()
154+
graph.add_nodes_from(list(range(6)))
155+
# Create a more complex graph: 0-1-2-3-4-5 with additional edges 1-4, 2-5
156+
graph.extend_from_edge_list([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (1, 4), (2, 5)])
157+
158+
# Test subset with multiple edges
159+
subgraph, node_map = graph.subgraph_with_nodemap([1, 2, 4])
160+
self.assertEqual(set(subgraph.node_indices()), {0, 1, 2})
161+
162+
# Check that we have the correct edges: 1-2 and 1-4 from original
163+
edge_list = list(subgraph.edge_list())
164+
self.assertEqual(len(edge_list), 2)
165+
self.assertIn((0, 1), edge_list) # 1-2 in original
166+
self.assertIn((0, 2), edge_list) # 1-4 in original
167+
168+
# Verify node mapping
169+
node_map_dict = dict(node_map)
170+
self.assertEqual(len(node_map_dict), 3)
171+
self.assertEqual(set(node_map_dict.values()), {1, 2, 4})
172+
173+
# Test disconnected nodes (nodes with no edges between them)
174+
graph2 = rustworkx.PyGraph()
175+
graph2.add_nodes_from(["a", "b", "c", "d", "e"])
176+
graph2.add_edges_from([(0, 1, 1), (2, 3, 2)]) # Two separate components
177+
subgraph, node_map = graph2.subgraph_with_nodemap([0, 2, 4])
178+
179+
self.assertEqual([], subgraph.weighted_edge_list()) # No edges between selected nodes
180+
self.assertEqual(["a", "c", "e"], subgraph.nodes())
181+
self.assertEqual(dict(node_map), {0: 0, 1: 2, 2: 4})
182+
183+
def test_subgraph_with_nodemap_edge_cases(self):
184+
graph = rustworkx.PyGraph()
185+
graph.add_nodes_from(["a", "b", "c"])
186+
graph.add_edges_from([(0, 1, 1), (1, 2, 2)])
187+
188+
# Test empty node list
189+
subgraph, node_map = graph.subgraph_with_nodemap([])
190+
self.assertEqual([], subgraph.weighted_edge_list())
191+
self.assertEqual(0, len(subgraph))
192+
self.assertEqual(dict(node_map), {})
193+
194+
# Test invalid node indices (should be silently ignored)
195+
subgraph, node_map = graph.subgraph_with_nodemap([42, 100])
196+
self.assertEqual([], subgraph.weighted_edge_list())
197+
self.assertEqual(0, len(subgraph))
198+
self.assertEqual(dict(node_map), {})
199+
200+
# Test single node (no edges in subgraph)
201+
subgraph, node_map = graph.subgraph_with_nodemap([1])
202+
self.assertEqual([], subgraph.weighted_edge_list())
203+
self.assertEqual(["b"], subgraph.nodes())
204+
self.assertEqual(dict(node_map), {0: 1})
205+
206+
# Test all nodes
207+
subgraph, node_map = graph.subgraph_with_nodemap([0, 1, 2])
208+
self.assertEqual([(0, 1, 1), (1, 2, 2)], subgraph.weighted_edge_list())
209+
self.assertEqual(["a", "b", "c"], subgraph.nodes())
210+
self.assertEqual(dict(node_map), {0: 0, 1: 1, 2: 2})

0 commit comments

Comments
 (0)