Skip to content

Commit 696f123

Browse files
Make cell connections public and named (#2296)
make connections public and use a dict with relative coordinates as key --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b95dd19 commit 696f123

File tree

6 files changed

+102
-90
lines changed

6 files changed

+102
-90
lines changed

mesa/experimental/cell_space/cell.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
from mesa.experimental.cell_space.cell_collection import CellCollection
1010

1111
if TYPE_CHECKING:
12+
from mesa.agent import Agent
1213
from mesa.experimental.cell_space.cell_agent import CellAgent
1314

15+
Coordinate = tuple[int, ...]
16+
1417

1518
class Cell:
1619
"""The cell represents a position in a discrete space.
@@ -26,7 +29,7 @@ class Cell:
2629

2730
__slots__ = [
2831
"coordinate",
29-
"_connections",
32+
"connections",
3033
"agents",
3134
"capacity",
3235
"properties",
@@ -44,7 +47,7 @@ class Cell:
4447

4548
def __init__(
4649
self,
47-
coordinate: tuple[int, ...],
50+
coordinate: Coordinate,
4851
capacity: float | None = None,
4952
random: Random | None = None,
5053
) -> None:
@@ -58,20 +61,25 @@ def __init__(
5861
"""
5962
super().__init__()
6063
self.coordinate = coordinate
61-
self._connections: list[Cell] = [] # TODO: change to CellCollection?
62-
self.agents = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
64+
self.connections: dict[Coordinate, Cell] = {}
65+
self.agents: list[
66+
Agent
67+
] = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
6368
self.capacity = capacity
64-
self.properties: dict[str, object] = {}
69+
self.properties: dict[Coordinate, object] = {}
6570
self.random = random
6671

67-
def connect(self, other: Cell) -> None:
72+
def connect(self, other: Cell, key: Coordinate | None = None) -> None:
6873
"""Connects this cell to another cell.
6974
7075
Args:
7176
other (Cell): other cell to connect to
77+
key (Tuple[int, ...]): key for the connection. Should resemble a relative coordinate
7278
7379
"""
74-
self._connections.append(other)
80+
if key is None:
81+
key = other.coordinate
82+
self.connections[key] = other
7583

7684
def disconnect(self, other: Cell) -> None:
7785
"""Disconnects this cell from another cell.
@@ -80,7 +88,9 @@ def disconnect(self, other: Cell) -> None:
8088
other (Cell): other cell to remove from connections
8189
8290
"""
83-
self._connections.remove(other)
91+
keys_to_remove = [k for k, v in self.connections.items() if v == other]
92+
for key in keys_to_remove:
93+
del self.connections[key]
8494

8595
def add_agent(self, agent: CellAgent) -> None:
8696
"""Adds an agent to the cell.
@@ -123,30 +133,34 @@ def __repr__(self): # noqa
123133

124134
# FIXME: Revisit caching strategy on methods
125135
@cache # noqa: B019
126-
def neighborhood(self, radius=1, include_center=False):
136+
def neighborhood(self, radius: int = 1, include_center: bool = False):
127137
"""Returns a list of all neighboring cells."""
128-
return CellCollection(
138+
return CellCollection[Cell](
129139
self._neighborhood(radius=radius, include_center=include_center),
130140
random=self.random,
131141
)
132142

133143
# FIXME: Revisit caching strategy on methods
134144
@cache # noqa: B019
135-
def _neighborhood(self, radius=1, include_center=False):
145+
def _neighborhood(
146+
self, radius: int = 1, include_center: bool = False
147+
) -> dict[Cell, list[Agent]]:
136148
# if radius == 0:
137149
# return {self: self.agents}
138150
if radius < 1:
139151
raise ValueError("radius must be larger than one")
140152
if radius == 1:
141-
neighborhood = {neighbor: neighbor.agents for neighbor in self._connections}
153+
neighborhood = {
154+
neighbor: neighbor.agents for neighbor in self.connections.values()
155+
}
142156
if not include_center:
143157
return neighborhood
144158
else:
145159
neighborhood[self] = self.agents
146160
return neighborhood
147161
else:
148-
neighborhood = {}
149-
for neighbor in self._connections:
162+
neighborhood: dict[Cell, list[Agent]] = {}
163+
for neighbor in self.connections.values():
150164
neighborhood.update(
151165
neighbor._neighborhood(radius - 1, include_center=True)
152166
)

mesa/experimental/cell_space/discrete_space.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ def all_cells(self):
6262
def __iter__(self): # noqa
6363
return iter(self._cells.values())
6464

65-
def __getitem__(self, key): # noqa
65+
def __getitem__(self, key: tuple[int, ...]) -> T: # noqa: D105
6666
return self._cells[key]
6767

6868
@property
69-
def empties(self) -> CellCollection:
69+
def empties(self) -> CellCollection[T]:
7070
"""Return all empty in spaces."""
7171
return self.all_cells.select(lambda cell: cell.is_empty)
7272

mesa/experimental/cell_space/grid.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
T = TypeVar("T", bound=Cell)
1313

1414

15-
class Grid(DiscreteSpace, Generic[T]):
15+
class Grid(DiscreteSpace[T], Generic[T]):
1616
"""Base class for all grid classes.
1717
1818
Attributes:
@@ -100,7 +100,7 @@ def _connect_single_cell_nd(self, cell: T, offsets: list[tuple[int, ...]]) -> No
100100
if self.torus:
101101
n_coord = tuple(nc % d for nc, d in zip(n_coord, self.dimensions))
102102
if all(0 <= nc < d for nc, d in zip(n_coord, self.dimensions)):
103-
cell.connect(self._cells[n_coord])
103+
cell.connect(self._cells[n_coord], d_coord)
104104

105105
def _connect_single_cell_2d(self, cell: T, offsets: list[tuple[int, int]]) -> None:
106106
i, j = cell.coordinate
@@ -111,7 +111,7 @@ def _connect_single_cell_2d(self, cell: T, offsets: list[tuple[int, int]]) -> No
111111
if self.torus:
112112
ni, nj = ni % height, nj % width
113113
if 0 <= ni < height and 0 <= nj < width:
114-
cell.connect(self._cells[ni, nj])
114+
cell.connect(self._cells[ni, nj], (di, dj))
115115

116116

117117
class OrthogonalMooreGrid(Grid[T]):
@@ -133,7 +133,6 @@ def _connect_cells_2d(self) -> None:
133133
( 1, -1), ( 1, 0), ( 1, 1),
134134
]
135135
# fmt: on
136-
height, width = self.dimensions
137136

138137
for cell in self.all_cells:
139138
self._connect_single_cell_2d(cell, offsets)
@@ -165,13 +164,12 @@ def _connect_cells_2d(self) -> None:
165164
( 1, 0),
166165
]
167166
# fmt: on
168-
height, width = self.dimensions
169167

170168
for cell in self.all_cells:
171169
self._connect_single_cell_2d(cell, offsets)
172170

173171
def _connect_cells_nd(self) -> None:
174-
offsets = []
172+
offsets: list[tuple[int, ...]] = []
175173
dimensions = len(self.dimensions)
176174
for dim in range(dimensions):
177175
for delta in [

mesa/experimental/cell_space/network.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from mesa.experimental.cell_space.discrete_space import DiscreteSpace
88

99

10-
class Network(DiscreteSpace):
10+
class Network(DiscreteSpace[Cell]):
1111
"""A networked discrete space."""
1212

1313
def __init__(
@@ -37,6 +37,6 @@ def __init__(
3737
for cell in self.all_cells:
3838
self._connect_single_cell(cell)
3939

40-
def _connect_single_cell(self, cell):
40+
def _connect_single_cell(self, cell: Cell):
4141
for node_id in self.G.neighbors(cell.coordinate):
42-
cell.connect(self._cells[node_id])
42+
cell.connect(self._cells[node_id], node_id)

mesa/experimental/cell_space/voronoi.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def _connect_cells(self) -> None:
216216

217217
for point in self.triangulation.export_triangles():
218218
for i, j in combinations(point, 2):
219-
self._cells[i].connect(self._cells[j])
220-
self._cells[j].connect(self._cells[i])
219+
self._cells[i].connect(self._cells[j], (i, j))
220+
self._cells[j].connect(self._cells[i], (j, i))
221221

222222
def _validate_parameters(self) -> None:
223223
if self.capacity is not None and not isinstance(self.capacity, float | int):
@@ -233,9 +233,10 @@ def _validate_parameters(self) -> None:
233233

234234
def _get_voronoi_regions(self) -> tuple:
235235
if self.voronoi_coordinates is None or self.regions is None:
236-
self.voronoi_coordinates, self.regions = (
237-
self.triangulation.export_voronoi_regions()
238-
)
236+
(
237+
self.voronoi_coordinates,
238+
self.regions,
239+
) = self.triangulation.export_voronoi_regions()
239240
return self.voronoi_coordinates, self.regions
240241

241242
@staticmethod

0 commit comments

Comments
 (0)