@@ -138,7 +138,7 @@ def test_edge_subgraph_non_edge(self):
138138 self .assertEqual ([(0 , 1 , 2 ), (0 , 1 , 3 ), (1 , 2 , 4 )], subgraph .weighted_edge_list ())
139139
140140 def test_preserve_attrs (self ):
141- graph = rustworkx .PyGraph (attrs = "My attribute" )
141+ graph = rustworkx .PyDiGraph (attrs = "My attribute" )
142142 graph .add_node ("a" )
143143 graph .add_node ("b" )
144144 graph .add_node ("c" )
@@ -148,3 +148,62 @@ def test_preserve_attrs(self):
148148 self .assertEqual ([(0 , 1 , 4 )], subgraph .weighted_edge_list ())
149149 self .assertEqual (["b" , "d" ], subgraph .nodes ())
150150 self .assertEqual (graph .attrs , subgraph .attrs )
151+
152+ def test_subgraph_with_nodemap (self ):
153+ graph = rustworkx .PyDiGraph ()
154+ graph .add_nodes_from (list (range (6 )))
155+ graph .add_edges_from ([(0 , 1 , 1 ), (1 , 2 , 2 ), (2 , 3 , 3 ), (3 , 4 , 4 ), (4 , 5 , 5 )])
156+
157+ # Test basic subgraph with node mapping
158+ subgraph , node_map = graph .subgraph_with_nodemap ([0 , 2 , 4 ])
159+ self .assertEqual ([], subgraph .weighted_edge_list ()) # No edges between disconnected nodes
160+ self .assertEqual ([0 , 2 , 4 ], subgraph .nodes ())
161+ self .assertEqual (dict (node_map ), {0 : 0 , 1 : 2 , 2 : 4 })
162+
163+ # Test with connected nodes
164+ subgraph2 , node_map2 = graph .subgraph_with_nodemap ([1 , 2 , 3 ])
165+ self .assertEqual ([(0 , 1 , 2 ), (1 , 2 , 3 )], subgraph2 .weighted_edge_list ())
166+ self .assertEqual ([1 , 2 , 3 ], subgraph2 .nodes ())
167+ self .assertEqual (dict (node_map2 ), {0 : 1 , 1 : 2 , 2 : 3 })
168+
169+ def test_subgraph_with_nodemap_edge_cases (self ):
170+ graph = rustworkx .PyDiGraph ()
171+ graph .add_nodes_from (["a" , "b" , "c" ])
172+ graph .add_edges_from ([(0 , 1 , 1 ), (1 , 2 , 2 )])
173+
174+ # Test empty node list
175+ subgraph , node_map = graph .subgraph_with_nodemap ([])
176+ self .assertEqual ([], subgraph .weighted_edge_list ())
177+ self .assertEqual (0 , len (subgraph ))
178+ self .assertEqual (dict (node_map ), {})
179+
180+ # Test invalid node indices (should be silently ignored)
181+ subgraph , node_map = graph .subgraph_with_nodemap ([42 , 100 ])
182+ self .assertEqual ([], subgraph .weighted_edge_list ())
183+ self .assertEqual (0 , len (subgraph ))
184+ self .assertEqual (dict (node_map ), {})
185+
186+ # Test single node (no edges in subgraph)
187+ subgraph , node_map = graph .subgraph_with_nodemap ([1 ])
188+ self .assertEqual ([], subgraph .weighted_edge_list ())
189+ self .assertEqual (["b" ], subgraph .nodes ())
190+ self .assertEqual (dict (node_map ), {0 : 1 })
191+
192+ # Test all nodes
193+ subgraph , node_map = graph .subgraph_with_nodemap ([0 , 1 , 2 ])
194+ self .assertEqual ([(0 , 1 , 1 ), (1 , 2 , 2 )], subgraph .weighted_edge_list ())
195+ self .assertEqual (["a" , "b" , "c" ], subgraph .nodes ())
196+ self .assertEqual (dict (node_map ), {0 : 0 , 1 : 1 , 2 : 2 })
197+
198+ def test_subgraph_with_nodemap_preserve_attrs (self ):
199+ graph = rustworkx .PyDiGraph (attrs = "test_attrs" )
200+ graph .add_nodes_from (["a" , "b" , "c" ])
201+ graph .add_edges_from ([(0 , 1 , 1 ), (1 , 2 , 2 )])
202+
203+ # Test preserve_attrs=False (default)
204+ subgraph , node_map = graph .subgraph_with_nodemap ([0 , 1 ])
205+ self .assertIsNone (subgraph .attrs )
206+
207+ # Test preserve_attrs=True
208+ subgraph2 , node_map2 = graph .subgraph_with_nodemap ([0 , 1 ], preserve_attrs = True )
209+ self .assertEqual (graph .attrs , subgraph2 .attrs )
0 commit comments