Skip to content

Commit 77200e8

Browse files
codebreaker32pre-commit-ci[bot]wang-boyu
authored
Add Cell.xy for RasterLayer cell coordinates and clarify pos/indices semantics (#299)
* Add Cell.xy for RasterLayer cell coordinates and clarify pos/indices semantics * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_RasterLayer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_RasterLayer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update cell coordinates and deprecation accessors * update raster layer docstring * remove cell.grid_pos and recover cell.pos --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wang Boyu <boyu.wby@gmail.com>
1 parent b67da65 commit 77200e8

File tree

2 files changed

+242
-39
lines changed

2 files changed

+242
-39
lines changed

mesa_geo/raster_layers.py

Lines changed: 185 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import copy
99
import itertools
1010
import math
11+
import warnings
1112
from collections.abc import Callable, Iterable, Iterator, Sequence
1213
from typing import Any, cast, overload
1314

@@ -16,7 +17,7 @@
1617
from affine import Affine
1718
from mesa import Model
1819
from mesa.agent import Agent
19-
from mesa.space import Coordinate, accept_tuple_argument
20+
from mesa.space import Coordinate, FloatCoordinate, accept_tuple_argument
2021
from rasterio.warp import (
2122
Resampling,
2223
calculate_default_transform,
@@ -99,6 +100,13 @@ def height(self, height: int) -> None:
99100

100101
@property
101102
def total_bounds(self) -> np.ndarray | None:
103+
"""
104+
Return the bounds of the raster layer in [min_x, min_y, max_x, max_y] format.
105+
106+
:return: Bounds of the raster layer in [min_x, min_y, max_x, max_y] format.
107+
:rtype: np.ndarray | None
108+
"""
109+
102110
return self._total_bounds
103111

104112
@total_bounds.setter
@@ -159,24 +167,105 @@ def out_of_bounds(self, pos: Coordinate) -> bool:
159167
class Cell(Agent):
160168
"""
161169
Cells are containers of raster attributes, and are building blocks of `RasterLayer`.
170+
171+
Deprecated:
172+
`Cell.indices` is deprecated. Use `Cell.rowcol` instead.
162173
"""
163174

164-
pos: Coordinate | None
165-
indices: Coordinate | None
175+
_pos: Coordinate | None
176+
_rowcol: Coordinate | None
177+
_xy: FloatCoordinate | None
166178

167-
def __init__(self, model, pos=None, indices=None):
179+
def __init__(
180+
self,
181+
model,
182+
pos=None,
183+
indices=None,
184+
*,
185+
rowcol=None,
186+
xy=None,
187+
):
168188
"""
169189
Initialize a cell.
170190
171-
:param pos: Position of the cell in (x, y) format.
191+
:param pos: Grid position of the cell in (grid_x, grid_y) format.
172192
Origin is at lower left corner of the grid
173-
:param indices: Indices of the cell in (row, col) format.
193+
:param indices: (Deprecated) Indices of the cell in (row, col) format.
194+
Origin is at upper left corner of the grid. Use rowcol instead.
195+
:param rowcol: Indices of the cell in (row, col) format.
174196
Origin is at upper left corner of the grid
197+
:param xy: Cell center coordinates in the CRS.
175198
"""
176199

177200
super().__init__(model)
178-
self.pos = pos
179-
self.indices = indices
201+
self._pos = pos
202+
self._rowcol = indices if rowcol is None else rowcol
203+
self._xy = xy
204+
205+
@property
206+
def pos(self) -> Coordinate | None:
207+
"""
208+
Grid position in (grid_x, grid_y) format with origin at lower left of the grid.
209+
"""
210+
return self._pos
211+
212+
@pos.setter
213+
def pos(self, pos: Coordinate | None) -> None:
214+
"""
215+
Deprecated setter for `pos`.
216+
"""
217+
warnings.warn(
218+
"Cell.pos setter is deprecated and will be read-only in a future release.",
219+
DeprecationWarning,
220+
stacklevel=2,
221+
)
222+
# set the pos for backward compatibility
223+
# in the future, this will be removed and raise an AttributeError,
224+
# because pos is read-only
225+
self._pos = pos
226+
227+
@property
228+
def indices(self) -> Coordinate | None:
229+
"""
230+
Deprecated alias of `rowcol`.
231+
"""
232+
warnings.warn(
233+
"Cell.indices is deprecated and will be removed in a future release. "
234+
"Use Cell.rowcol instead.",
235+
DeprecationWarning,
236+
stacklevel=2,
237+
)
238+
return self._rowcol
239+
240+
@indices.setter
241+
def indices(self, indices: Coordinate | None) -> None:
242+
"""
243+
Deprecated setter for `rowcol`.
244+
"""
245+
warnings.warn(
246+
"Cell.indices is deprecated and will be removed in a future release. "
247+
"Use Cell.rowcol instead.",
248+
DeprecationWarning,
249+
stacklevel=2,
250+
)
251+
# for backward compatibility, set the rowcol to the indices
252+
# in the future, this will be removed
253+
# and raise an AttributeError, because indices is read-only
254+
self._rowcol = indices
255+
256+
@property
257+
def rowcol(self) -> Coordinate | None:
258+
"""
259+
Raster indices in (row, col) format with origin at upper left of the grid.
260+
"""
261+
return self._rowcol
262+
263+
@property
264+
def xy(self) -> FloatCoordinate | None:
265+
"""
266+
Cell center coordinates in the CRS.
267+
"""
268+
return self._xy
180269

181270
def step(self):
182271
pass
@@ -228,13 +317,31 @@ def __init__(
228317
self._attributes = set()
229318
self._neighborhood_cache = {}
230319

320+
def _update_transform(self) -> None:
321+
super()._update_transform()
322+
if getattr(self, "cells", None):
323+
self._sync_cell_xy()
324+
325+
def _sync_cell_xy(self) -> None:
326+
for column in self.cells:
327+
for cell in column:
328+
row, col = cell.rowcol
329+
cell._xy = rio.transform.xy(self.transform, row, col, offset="center")
330+
231331
def _initialize_cells(self, model: Model, cell_cls: type[Cell]):
232332
self.cells = []
233-
for x in range(self.width):
333+
for grid_x in range(self.width):
234334
col: list[cell_cls] = []
235-
for y in range(self.height):
236-
row_idx, col_idx = self.height - y - 1, x
237-
col.append(self.cell_cls(model, pos=(x, y), indices=(row_idx, col_idx)))
335+
for grid_y in range(self.height):
336+
row_idx, col_idx = self.height - grid_y - 1, grid_x
337+
xy = rio.transform.xy(self.transform, row_idx, col_idx, offset="center")
338+
cell = self.cell_cls(
339+
model,
340+
pos=(grid_x, grid_y),
341+
rowcol=(row_idx, col_idx),
342+
xy=xy,
343+
)
344+
col.append(cell)
238345
self.cells.append(col)
239346

240347
@property
@@ -337,9 +444,13 @@ def apply_raster(self, data: np.ndarray, attr_name: str | None = None) -> None:
337444
if attr_name is None:
338445
attr_name = f"attribute_{len(self.cell_cls.__dict__)}"
339446
self._attributes.add(attr_name)
340-
for x in range(self.width):
341-
for y in range(self.height):
342-
setattr(self.cells[x][y], attr_name, data[0, self.height - y - 1, x])
447+
for grid_x in range(self.width):
448+
for grid_y in range(self.height):
449+
setattr(
450+
self.cells[grid_x][grid_y],
451+
attr_name,
452+
data[0, self.height - grid_y - 1, grid_x],
453+
)
343454

344455
def get_raster(self, attr_name: str | None = None) -> np.ndarray:
345456
"""
@@ -364,9 +475,11 @@ def get_raster(self, attr_name: str | None = None) -> np.ndarray:
364475
attr_names = {attr_name}
365476
data = np.empty((num_bands, self.height, self.width))
366477
for ind, name in enumerate(attr_names):
367-
for x in range(self.width):
368-
for y in range(self.height):
369-
data[ind, self.height - y - 1, x] = getattr(self.cells[x][y], name)
478+
for grid_x in range(self.width):
479+
for grid_y in range(self.height):
480+
data[ind, self.height - grid_y - 1, grid_x] = getattr(
481+
self.cells[grid_x][grid_y], name
482+
)
370483
return data
371484

372485
def iter_neighborhood(
@@ -380,12 +493,13 @@ def iter_neighborhood(
380493
Return an iterator over cell coordinates that are in the
381494
neighborhood of a certain point.
382495
383-
:param Coordinate pos: Coordinate tuple for the neighborhood to get.
496+
:param Coordinate pos: Grid coordinate tuple (grid_x, grid_y) for the
497+
neighborhood to get. Origin is at lower left corner of the grid.
384498
:param bool moore: Whether to use Moore neighborhood or not. If True,
385499
return Moore neighborhood (including diagonals). If False, return
386500
Von Neumann neighborhood (exclude diagonals).
387-
:param bool include_center: If True, return the (x, y) cell as well.
388-
Otherwise, return surrounding cells only. Default is False.
501+
:param bool include_center: If True, return the (grid_x, grid_y) cell as
502+
well. Otherwise, return surrounding cells only. Default is False.
389503
:param int radius: Radius, in cells, of the neighborhood. Default is 1.
390504
:return: An iterator over cell coordinates that are in the neighborhood.
391505
For example with radius 1, it will return list with number of elements
@@ -406,12 +520,13 @@ def iter_neighbors(
406520
"""
407521
Return an iterator over neighbors to a certain point.
408522
409-
:param Coordinate pos: Coordinate tuple for the neighborhood to get.
523+
:param Coordinate pos: Grid coordinate tuple (grid_x, grid_y) for the
524+
neighborhood to get. Origin is at lower left corner of the grid.
410525
:param bool moore: Whether to use Moore neighborhood or not. If True,
411526
return Moore neighborhood (including diagonals). If False, return
412527
Von Neumann neighborhood (exclude diagonals).
413-
:param bool include_center: If True, return the (x, y) cell as well.
414-
Otherwise, return surrounding cells only. Default is False.
528+
:param bool include_center: If True, return the (grid_x, grid_y) cell
529+
as well. Otherwise, return surrounding cells only. Default is False.
415530
:param int radius: Radius, in cells, of the neighborhood. Default is 1.
416531
:return: An iterator of cells that are in the neighborhood; at most 9 (8)
417532
if Moore, 5 (4) if Von Neumann (if not including the center).
@@ -429,8 +544,8 @@ def iter_cell_list_contents(
429544
Returns an iterator of the contents of the cells
430545
identified in cell_list.
431546
432-
:param Iterable[Coordinate] cell_list: Array-like of (x, y) tuples,
433-
or single tuple.
547+
:param Iterable[Coordinate] cell_list: Array-like of grid (grid_x, grid_y) tuples,
548+
or single tuple (grid_x, grid_y). Origin is at lower left corner of the grid.
434549
:return: An iterator of the contents of the cells identified in cell_list.
435550
:rtype: Iterator[Cell]
436551
"""
@@ -448,8 +563,8 @@ def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Cell]:
448563
449564
Note: this method returns a list of cells.
450565
451-
:param Iterable[Coordinate] cell_list: Array-like of (x, y) tuples,
452-
or single tuple.
566+
:param Iterable[Coordinate] cell_list: Array-like of grid (grid_x, grid_y) tuples,
567+
or single tuple (grid_x, grid_y). Origin is at lower left corner of the grid.
453568
:return: A list of the contents of the cells identified in cell_list.
454569
:rtype: List[Cell]
455570
"""
@@ -463,6 +578,24 @@ def get_neighborhood(
463578
include_center: bool = False,
464579
radius: int = 1,
465580
) -> list[Coordinate]:
581+
"""
582+
Return a list of cell coordinates that are in the
583+
neighborhood of a certain point.
584+
585+
:param Coordinate pos: Grid coordinate tuple (grid_x, grid_y) for the
586+
neighborhood to get. Origin is at lower left corner of the grid.
587+
:param bool moore: Whether to use Moore neighborhood or not. If True,
588+
return Moore neighborhood (including diagonals). If False, return
589+
Von Neumann neighborhood (exclude diagonals).
590+
:param bool include_center: If True, return the (grid_x, grid_y) cell as
591+
well. Otherwise, return surrounding cells only. Default is False.
592+
:param int radius: Radius, in cells, of the neighborhood. Default is 1.
593+
:return: A list of cell coordinates that are in the neighborhood.
594+
For example with radius 1, it will return list with number of elements
595+
equals at most 9 (8) if Moore, 5 (4) if Von Neumann (if not including
596+
the center).
597+
:rtype: List[Coordinate]
598+
"""
466599
cache_key = (pos, moore, include_center, radius)
467600
neighborhood = self._neighborhood_cache.get(cache_key, None)
468601

@@ -500,8 +633,18 @@ def get_neighboring_cells(
500633
return [self.cells[idx[0]][idx[1]] for idx in neighboring_cell_idx]
501634

502635
def to_crs(self, crs, inplace=False) -> RasterLayer | None:
636+
"""
637+
Transform the raster layer to a new coordinate reference system.
638+
639+
:param crs: The coordinate reference system to transform to.
640+
:param inplace: Whether to transform the raster layer in place or
641+
return a new raster layer. Defaults to False.
642+
:return: The transformed raster layer if not inplace.
643+
:rtype: RasterLayer | None
644+
"""
645+
503646
super()._to_crs_check(crs)
504-
layer = self if inplace else copy.copy(self)
647+
layer = self if inplace else copy.deepcopy(self)
505648

506649
src_crs = rio.crs.CRS.from_user_input(layer.crs)
507650
dst_crs = rio.crs.CRS.from_user_input(crs)
@@ -518,6 +661,8 @@ def to_crs(self, crs, inplace=False) -> RasterLayer | None:
518661
]
519662
layer.crs = crs
520663
layer._transform = transform
664+
if getattr(layer, "cells", None):
665+
layer._sync_cell_xy()
521666

522667
if not inplace:
523668
return layer
@@ -529,7 +674,7 @@ def to_image(self, colormap) -> ImageLayer:
529674

530675
values = np.empty(shape=(4, self.height, self.width))
531676
for cell in self:
532-
row, col = cell.indices
677+
row, col = cell.rowcol
533678
values[:, row, col] = colormap(cell)
534679
return ImageLayer(values=values, crs=self.crs, total_bounds=self.total_bounds)
535680

@@ -563,6 +708,7 @@ def from_file(
563708
]
564709
obj = cls(width, height, dataset.crs, total_bounds, model, cell_cls)
565710
obj._transform = dataset.transform
711+
obj._sync_cell_xy()
566712
obj.apply_raster(values, attr_name=attr_name)
567713
return obj
568714

@@ -639,6 +785,15 @@ def values(self, values: np.ndarray) -> None:
639785
self._update_transform()
640786

641787
def to_crs(self, crs, inplace=False) -> ImageLayer | None:
788+
"""
789+
Transform the image layer to a new coordinate reference system.
790+
791+
:param crs: The coordinate reference system to transform to.
792+
:param inplace: Whether to transform the image layer in place or
793+
return a new image layer. Defaults to False.
794+
:return: The transformed image layer if not inplace.
795+
:rtype: ImageLayer | None
796+
"""
642797
super()._to_crs_check(crs)
643798
layer = self if inplace else copy.copy(self)
644799

0 commit comments

Comments
 (0)