Skip to content

Commit 3054bac

Browse files
Bugfix for deepcopy / pickling discrete spaces (#2378)
* bugfix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix for 3.10 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 045904a commit 3054bac

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

mesa/experimental/cell_space/cell.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,12 @@ def modify_property(
211211
self._mesa_property_layers[property_name].modify_cell(
212212
self.coordinate, operation, value
213213
)
214+
215+
def __getstate__(self):
216+
"""Return state of the Cell with connections set to empty."""
217+
# fixme, once we shift to 3.11, replace this with super. __getstate__
218+
state = (self.__dict__, {k: getattr(self, k) for k in self.__slots__})
219+
state[1][
220+
"connections"
221+
] = {} # replace this with empty connections to avoid infinite recursion error in pickle/deepcopy
222+
return state

mesa/experimental/cell_space/discrete_space.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class DiscreteSpace(Generic[T]):
2222
all_cells (CellCollection): The cells composing the discrete space
2323
random (Random): The random number generator
2424
cell_klass (Type) : the type of cell class
25-
empties (CellCollection) : collecction of all cells that are empty
25+
empties (CellCollection) : collection of all cells that are empty
2626
property_layers (dict[str, PropertyLayer]): the property layers of the discrete space
2727
"""
2828

@@ -55,6 +55,7 @@ def __init__(
5555
def cutoff_empties(self): # noqa
5656
return 7.953 * len(self._cells) ** 0.384
5757

58+
def _connect_cells(self): ...
5859
def _connect_single_cell(self, cell: T): ...
5960

6061
@cached_property
@@ -134,3 +135,8 @@ def modify_properties(
134135
condition: a function that takes a cell and returns a boolean (used to filter cells)
135136
"""
136137
self.property_layers[property_name].modify_cells(operation, value, condition)
138+
139+
def __setstate__(self, state):
140+
"""Set the state of the discrete space and rebuild the connections."""
141+
self.__dict__ = state
142+
self._connect_cells()

mesa/experimental/cell_space/network.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def __init__(
3434
node_id, capacity, random=self.random
3535
)
3636

37+
self._connect_cells()
38+
39+
def _connect_cells(self) -> None:
3740
for cell in self.all_cells:
3841
self._connect_single_cell(cell)
3942

tests/test_cell_space.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,3 +691,42 @@ def test_patch(): # noqa: D103
691691

692692
agent.remove()
693693
assert agent not in model._agents
694+
695+
696+
def test_copying_discrete_spaces(): # noqa: D103
697+
# inspired by #2373
698+
# we use deepcopy but this also applies to pickle
699+
import copy
700+
701+
import networkx as nx
702+
703+
grid = OrthogonalMooreGrid((100, 100))
704+
grid_copy = copy.deepcopy(grid)
705+
706+
c1 = grid[(5, 5)].connections
707+
c2 = grid_copy[(5, 5)].connections
708+
709+
for c1, c2 in zip(grid.all_cells, grid_copy.all_cells):
710+
for k, v in c1.connections.items():
711+
assert v.coordinate == c2.connections[k].coordinate
712+
713+
n = 10
714+
m = 20
715+
seed = 42
716+
G = nx.gnm_random_graph(n, m, seed=seed) # noqa: N806
717+
grid = Network(G)
718+
grid_copy = copy.deepcopy(grid)
719+
720+
for c1, c2 in zip(grid.all_cells, grid_copy.all_cells):
721+
for k, v in c1.connections.items():
722+
assert v.coordinate == c2.connections[k].coordinate
723+
724+
grid = HexGrid((100, 100))
725+
grid_copy = copy.deepcopy(grid)
726+
727+
c1 = grid[(5, 5)].connections
728+
c2 = grid_copy[(5, 5)].connections
729+
730+
for c1, c2 in zip(grid.all_cells, grid_copy.all_cells):
731+
for k, v in c1.connections.items():
732+
assert v.coordinate == c2.connections[k].coordinate

0 commit comments

Comments
 (0)