Skip to content

Commit 28b1e9c

Browse files
authored
Fix pickling of graphs with contracted nodes (#1505)
* set node_removed if nodes contracted * format
1 parent 7318a80 commit 28b1e9c

File tree

4 files changed

+68
-0
lines changed

4 files changed

+68
-0
lines changed

src/digraph.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3007,6 +3007,7 @@ impl PyDiGraph {
30073007
}
30083008
(None, true) => self.graph.contract_nodes(nodes, obj, check_cycle)?,
30093009
};
3010+
self.node_removed = true;
30103011
Ok(res.index())
30113012
}
30123013

src/graph.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1899,6 +1899,7 @@ impl PyGraph {
18991899
}
19001900
(None, true) => self.graph.contract_nodes(nodes, obj),
19011901
};
1902+
self.node_removed = true;
19021903
Ok(res.index())
19031904
}
19041905

tests/digraph/test_pickle.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,26 @@ def test_weight_graph(self):
3939
self.assertEqual([1, 2, 3], gprime.node_indices())
4040
self.assertEqual(["B", "C", "D"], gprime.nodes())
4141
self.assertEqual({1: (1, 2, "B -> C"), 3: (3, 1, "D -> B")}, dict(gprime.edge_index_map()))
42+
43+
def test_contracted_nodes_pickle(self):
44+
"""Test pickle/unpickle of directed graphs with contracted nodes (issue #1503)"""
45+
g = rx.PyDiGraph()
46+
g.add_node("A") # Node 0
47+
g.add_node("B") # Node 1
48+
g.add_node("C") # Node 2
49+
50+
# Contract nodes 0 and 1 into a new node
51+
contracted_idx = g.contract_nodes([0, 1], "AB")
52+
g.add_edge(2, contracted_idx, "C -> AB")
53+
54+
# Verify initial state
55+
self.assertEqual([2, contracted_idx], g.node_indices())
56+
self.assertEqual([(2, contracted_idx)], g.edge_list())
57+
58+
# Test pickle/unpickle
59+
gprime = pickle.loads(pickle.dumps(g))
60+
61+
# Verify the unpickled graph matches
62+
self.assertEqual(g.node_indices(), gprime.node_indices())
63+
self.assertEqual(g.edge_list(), gprime.edge_list())
64+
self.assertEqual(g.nodes(), gprime.nodes())

tests/graph/test_pickle.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,46 @@ def test_weight_graph(self):
3939
self.assertEqual([1, 2, 3], gprime.node_indices())
4040
self.assertEqual(["B", "C", "D"], gprime.nodes())
4141
self.assertEqual({1: (1, 2, "B -> C"), 3: (3, 1, "D -> B")}, dict(gprime.edge_index_map()))
42+
43+
def test_contracted_nodes_pickle(self):
44+
"""Test pickle/unpickle of graphs with contracted nodes (issue #1503)"""
45+
g = rx.PyGraph()
46+
g.add_node("A") # Node 0
47+
g.add_node("B") # Node 1
48+
g.add_node("C") # Node 2
49+
50+
# Contract nodes 0 and 1 into a new node
51+
contracted_idx = g.contract_nodes([0, 1], "AB")
52+
g.add_edge(2, contracted_idx, "C -> AB")
53+
54+
# Verify initial state
55+
self.assertEqual([2, contracted_idx], g.node_indices())
56+
self.assertEqual([(2, contracted_idx)], g.edge_list())
57+
58+
# Test pickle/unpickle
59+
gprime = pickle.loads(pickle.dumps(g))
60+
61+
# Verify the unpickled graph matches
62+
self.assertEqual(g.node_indices(), gprime.node_indices())
63+
self.assertEqual(g.edge_list(), gprime.edge_list())
64+
self.assertEqual(g.nodes(), gprime.nodes())
65+
66+
def test_contracted_nodes_with_weights_pickle(self):
67+
"""Test pickle/unpickle of graphs with contracted nodes and edge weights"""
68+
g = rx.PyGraph()
69+
g.add_nodes_from(["Node0", "Node1", "Node2", "Node3"])
70+
g.add_edges_from([(0, 2, "edge_0_2"), (1, 3, "edge_1_3")])
71+
72+
# Contract multiple nodes
73+
contracted_idx = g.contract_nodes([0, 1], "Contracted_0_1")
74+
g.add_edge(contracted_idx, 2, "contracted_to_2")
75+
g.add_edge(3, contracted_idx, "3_to_contracted")
76+
77+
# Test pickle/unpickle
78+
gprime = pickle.loads(pickle.dumps(g))
79+
80+
# Verify complete graph state is preserved
81+
self.assertEqual(g.node_indices(), gprime.node_indices())
82+
self.assertEqual(g.edge_list(), gprime.edge_list())
83+
self.assertEqual(g.nodes(), gprime.nodes())
84+
self.assertEqual(dict(g.edge_index_map()), dict(gprime.edge_index_map()))

0 commit comments

Comments
 (0)