Skip to content

Commit b19c8ba

Browse files
authored
handle list of attr_name in apply/get raster data (#303)
1 parent 77200e8 commit b19c8ba

File tree

2 files changed

+356
-37
lines changed

2 files changed

+356
-37
lines changed

mesa_geo/raster_layers.py

Lines changed: 87 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def __init__(
194194
Origin is at upper left corner of the grid. Use rowcol instead.
195195
:param rowcol: Indices of the cell in (row, col) format.
196196
Origin is at upper left corner of the grid
197-
:param xy: Cell center coordinates in the CRS.
197+
:param xy: Geographic/projected (x, y) coordinates of the cell center in the CRS.
198198
"""
199199

200200
super().__init__(model)
@@ -263,7 +263,7 @@ def rowcol(self) -> Coordinate | None:
263263
@property
264264
def xy(self) -> FloatCoordinate | None:
265265
"""
266-
Cell center coordinates in the CRS.
266+
Geographic/projected (x, y) coordinates of the cell center in the CRS.
267267
"""
268268
return self._xy
269269

@@ -426,53 +426,104 @@ def coord_iter(self) -> Iterator[tuple[Cell, int, int]]:
426426
for col in range(self.height):
427427
yield self.cells[row][col], row, col # cell, x, y
428428

429-
def apply_raster(self, data: np.ndarray, attr_name: str | None = None) -> None:
429+
def apply_raster(
430+
self, data: np.ndarray, attr_name: str | Sequence[str] | None = None
431+
) -> None:
430432
"""
431433
Apply raster data to the cells.
432434
433-
:param np.ndarray data: 2D numpy array with shape (1, height, width).
434-
:param str | None attr_name: Name of the attribute to be added to the cells.
435-
If None, a random name will be generated. Default is None.
436-
:raises ValueError: If the shape of the data is not (1, height, width).
435+
:param np.ndarray data: 3D numpy array with shape (bands, height, width).
436+
:param str | Sequence[str] | None attr_name: Attribute name(s) to be added to the
437+
cells. For multi-band rasters, pass a list of names with length equal to
438+
the number of bands, or a single base name to be suffixed per band. If None,
439+
names are generated. Default is None.
440+
:raises ValueError: If the shape of the data does not match the raster.
437441
"""
438442

439-
if data.shape != (1, self.height, self.width):
443+
if data.ndim != 3 or data.shape[1:] != (self.height, self.width):
440444
raise ValueError(
441445
f"Data shape does not match raster shape. "
442-
f"Expected {(1, self.height, self.width)}, received {data.shape}."
446+
f"Expected (*, {self.height}, {self.width}), received {data.shape}."
443447
)
444-
if attr_name is None:
445-
attr_name = f"attribute_{len(self.cell_cls.__dict__)}"
446-
self._attributes.add(attr_name)
447-
for grid_x in range(self.width):
448-
for grid_y in range(self.height):
449-
setattr(
450-
self.cells[grid_x][grid_y],
451-
attr_name,
452-
data[0, self.height - grid_y - 1, grid_x],
453-
)
448+
num_bands = data.shape[0]
449+
450+
if num_bands == 1:
451+
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
452+
if len(attr_name) != 1:
453+
raise ValueError(
454+
"attr_name sequence length must match the number of raster bands; "
455+
f"expected {num_bands} band names, got {len(attr_name)}."
456+
)
457+
names = [attr_name[0]]
458+
else:
459+
names = [cast(str | None, attr_name)]
460+
else:
461+
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
462+
if len(attr_name) != num_bands:
463+
raise ValueError(
464+
"attr_name sequence length must match the number of raster bands; "
465+
f"expected {num_bands} band names, got {len(attr_name)}."
466+
)
467+
names = list(attr_name)
468+
elif isinstance(attr_name, str):
469+
names = [f"{attr_name}_{band_idx + 1}" for band_idx in range(num_bands)]
470+
else:
471+
names = [None] * num_bands
472+
473+
def _default_attr_name() -> str:
474+
base = f"attribute_{len(self.cell_cls.__dict__)}"
475+
if base not in self._attributes:
476+
return base
477+
suffix = 1
478+
candidate = f"{base}_{suffix}"
479+
while candidate in self._attributes:
480+
suffix += 1
481+
candidate = f"{base}_{suffix}"
482+
return candidate
483+
484+
for band_idx, name in enumerate(names):
485+
attr = _default_attr_name() if name is None else name
486+
self._attributes.add(attr)
487+
for grid_x in range(self.width):
488+
for grid_y in range(self.height):
489+
setattr(
490+
self.cells[grid_x][grid_y],
491+
attr,
492+
data[band_idx, self.height - grid_y - 1, grid_x],
493+
)
454494

455-
def get_raster(self, attr_name: str | None = None) -> np.ndarray:
495+
def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray:
456496
"""
457497
Return the values of given attribute.
458498
459-
:param str | None attr_name: Name of the attribute to be returned. If None,
460-
returns all attributes. Default is None.
461-
:return: The values of given attribute as a 2D numpy array with shape (1, height, width).
499+
:param str | Sequence[str] | None attr_name: Name(s) of attributes to be returned.
500+
If None, returns all attributes. Default is None.
501+
:return: The values of given attribute(s) as a numpy array with shape
502+
(bands, height, width).
462503
:rtype: np.ndarray
463504
"""
464505

465-
if attr_name is not None and attr_name not in self.attributes:
506+
if isinstance(attr_name, str) and attr_name not in self.attributes:
466507
raise ValueError(
467508
f"Attribute {attr_name} does not exist. "
468509
f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all."
469510
)
511+
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
512+
missing = [name for name in attr_name if name not in self.attributes]
513+
if missing:
514+
raise ValueError(
515+
f"Attribute {missing[0]} does not exist. "
516+
f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all."
517+
)
470518
if attr_name is None:
471519
num_bands = len(self.attributes)
472520
attr_names = self.attributes
521+
elif isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
522+
num_bands = len(attr_name)
523+
attr_names = list(attr_name)
473524
else:
474525
num_bands = 1
475-
attr_names = {attr_name}
526+
attr_names = [attr_name]
476527
data = np.empty((num_bands, self.height, self.width))
477528
for ind, name in enumerate(attr_names):
478529
for grid_x in range(self.width):
@@ -684,16 +735,18 @@ def from_file(
684735
raster_file: str,
685736
model: Model,
686737
cell_cls: type[Cell] = Cell,
687-
attr_name: str | None = None,
738+
attr_name: str | Sequence[str] | None = None,
688739
rio_opener: Callable | None = None,
689740
) -> RasterLayer:
690741
"""
691742
Creates a RasterLayer from a raster file.
692743
693744
:param str raster_file: Path to the raster file.
694745
:param Type[Cell] cell_cls: The class of the cells in the layer.
695-
:param str | None attr_name: The name of the attribute to use for the cell values.
696-
If None, a random name will be generated. Default is None.
746+
:param str | Sequence[str] | None attr_name: Attribute name(s) to use for the cell
747+
values. For multi-band rasters, pass a list of names with length equal to
748+
the number of bands, or a single base name to be suffixed per band. If None,
749+
names are generated. Default is None.
697750
:param Callable | None rio_opener: A callable passed to Rasterio open() function.
698751
"""
699752

@@ -713,14 +766,17 @@ def from_file(
713766
return obj
714767

715768
def to_file(
716-
self, raster_file: str, attr_name: str | None = None, driver: str = "GTiff"
769+
self,
770+
raster_file: str,
771+
attr_name: str | Sequence[str] | None = None,
772+
driver: str = "GTiff",
717773
) -> None:
718774
"""
719775
Writes a raster layer to a file.
720776
721777
:param str raster_file: The path to the raster file to write to.
722-
:param str | None attr_name: The name of the attribute to write to the raster.
723-
If None, all attributes are written. Default is None.
778+
:param str | Sequence[str] | None attr_name: The name(s) of attributes to write
779+
to the raster. If None, all attributes are written. Default is None.
724780
:param str driver: The GDAL driver to use for writing the raster file.
725781
Default is 'GTiff'. See GDAL docs at https://gdal.org/drivers/raster/index.html.
726782
"""

0 commit comments

Comments
 (0)