diff --git a/mesa_geo/raster_layers.py b/mesa_geo/raster_layers.py index 30f96299..7601f755 100644 --- a/mesa_geo/raster_layers.py +++ b/mesa_geo/raster_layers.py @@ -194,7 +194,7 @@ def __init__( Origin is at upper left corner of the grid. Use rowcol instead. :param rowcol: Indices of the cell in (row, col) format. Origin is at upper left corner of the grid - :param xy: Cell center coordinates in the CRS. + :param xy: Geographic/projected (x, y) coordinates of the cell center in the CRS. """ super().__init__(model) @@ -263,7 +263,7 @@ def rowcol(self) -> Coordinate | None: @property def xy(self) -> FloatCoordinate | None: """ - Cell center coordinates in the CRS. + Geographic/projected (x, y) coordinates of the cell center in the CRS. """ return self._xy @@ -426,53 +426,104 @@ def coord_iter(self) -> Iterator[tuple[Cell, int, int]]: for col in range(self.height): yield self.cells[row][col], row, col # cell, x, y - def apply_raster(self, data: np.ndarray, attr_name: str | None = None) -> None: + def apply_raster( + self, data: np.ndarray, attr_name: str | Sequence[str] | None = None + ) -> None: """ Apply raster data to the cells. - :param np.ndarray data: 2D numpy array with shape (1, height, width). - :param str | None attr_name: Name of the attribute to be added to the cells. - If None, a random name will be generated. Default is None. - :raises ValueError: If the shape of the data is not (1, height, width). + :param np.ndarray data: 3D numpy array with shape (bands, height, width). + :param str | Sequence[str] | None attr_name: Attribute name(s) to be added to the + cells. For multi-band rasters, pass a list of names with length equal to + the number of bands, or a single base name to be suffixed per band. If None, + names are generated. Default is None. + :raises ValueError: If the shape of the data does not match the raster. """ - if data.shape != (1, self.height, self.width): + if data.ndim != 3 or data.shape[1:] != (self.height, self.width): raise ValueError( f"Data shape does not match raster shape. " - f"Expected {(1, self.height, self.width)}, received {data.shape}." + f"Expected (*, {self.height}, {self.width}), received {data.shape}." ) - if attr_name is None: - attr_name = f"attribute_{len(self.cell_cls.__dict__)}" - self._attributes.add(attr_name) - for grid_x in range(self.width): - for grid_y in range(self.height): - setattr( - self.cells[grid_x][grid_y], - attr_name, - data[0, self.height - grid_y - 1, grid_x], - ) + num_bands = data.shape[0] + + if num_bands == 1: + if isinstance(attr_name, Sequence) and not isinstance(attr_name, str): + if len(attr_name) != 1: + raise ValueError( + "attr_name sequence length must match the number of raster bands; " + f"expected {num_bands} band names, got {len(attr_name)}." + ) + names = [attr_name[0]] + else: + names = [cast(str | None, attr_name)] + else: + if isinstance(attr_name, Sequence) and not isinstance(attr_name, str): + if len(attr_name) != num_bands: + raise ValueError( + "attr_name sequence length must match the number of raster bands; " + f"expected {num_bands} band names, got {len(attr_name)}." + ) + names = list(attr_name) + elif isinstance(attr_name, str): + names = [f"{attr_name}_{band_idx + 1}" for band_idx in range(num_bands)] + else: + names = [None] * num_bands + + def _default_attr_name() -> str: + base = f"attribute_{len(self.cell_cls.__dict__)}" + if base not in self._attributes: + return base + suffix = 1 + candidate = f"{base}_{suffix}" + while candidate in self._attributes: + suffix += 1 + candidate = f"{base}_{suffix}" + return candidate + + for band_idx, name in enumerate(names): + attr = _default_attr_name() if name is None else name + self._attributes.add(attr) + for grid_x in range(self.width): + for grid_y in range(self.height): + setattr( + self.cells[grid_x][grid_y], + attr, + data[band_idx, self.height - grid_y - 1, grid_x], + ) - def get_raster(self, attr_name: str | None = None) -> np.ndarray: + def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray: """ Return the values of given attribute. - :param str | None attr_name: Name of the attribute to be returned. If None, - returns all attributes. Default is None. - :return: The values of given attribute as a 2D numpy array with shape (1, height, width). + :param str | Sequence[str] | None attr_name: Name(s) of attributes to be returned. + If None, returns all attributes. Default is None. + :return: The values of given attribute(s) as a numpy array with shape + (bands, height, width). :rtype: np.ndarray """ - if attr_name is not None and attr_name not in self.attributes: + if isinstance(attr_name, str) and attr_name not in self.attributes: raise ValueError( f"Attribute {attr_name} does not exist. " f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all." ) + if isinstance(attr_name, Sequence) and not isinstance(attr_name, str): + missing = [name for name in attr_name if name not in self.attributes] + if missing: + raise ValueError( + f"Attribute {missing[0]} does not exist. " + f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all." + ) if attr_name is None: num_bands = len(self.attributes) attr_names = self.attributes + elif isinstance(attr_name, Sequence) and not isinstance(attr_name, str): + num_bands = len(attr_name) + attr_names = list(attr_name) else: num_bands = 1 - attr_names = {attr_name} + attr_names = [attr_name] data = np.empty((num_bands, self.height, self.width)) for ind, name in enumerate(attr_names): for grid_x in range(self.width): @@ -684,7 +735,7 @@ def from_file( raster_file: str, model: Model, cell_cls: type[Cell] = Cell, - attr_name: str | None = None, + attr_name: str | Sequence[str] | None = None, rio_opener: Callable | None = None, ) -> RasterLayer: """ @@ -692,8 +743,10 @@ def from_file( :param str raster_file: Path to the raster file. :param Type[Cell] cell_cls: The class of the cells in the layer. - :param str | None attr_name: The name of the attribute to use for the cell values. - If None, a random name will be generated. Default is None. + :param str | Sequence[str] | None attr_name: Attribute name(s) to use for the cell + values. For multi-band rasters, pass a list of names with length equal to + the number of bands, or a single base name to be suffixed per band. If None, + names are generated. Default is None. :param Callable | None rio_opener: A callable passed to Rasterio open() function. """ @@ -713,14 +766,17 @@ def from_file( return obj def to_file( - self, raster_file: str, attr_name: str | None = None, driver: str = "GTiff" + self, + raster_file: str, + attr_name: str | Sequence[str] | None = None, + driver: str = "GTiff", ) -> None: """ Writes a raster layer to a file. :param str raster_file: The path to the raster file to write to. - :param str | None attr_name: The name of the attribute to write to the raster. - If None, all attributes are written. Default is None. + :param str | Sequence[str] | None attr_name: The name(s) of attributes to write + to the raster. If None, all attributes are written. Default is None. :param str driver: The GDAL driver to use for writing the raster file. Default is 'GTiff'. See GDAL docs at https://gdal.org/drivers/raster/index.html. """ diff --git a/tests/test_RasterLayer.py b/tests/test_RasterLayer.py index 4e8e4cfb..78f5499b 100644 --- a/tests/test_RasterLayer.py +++ b/tests/test_RasterLayer.py @@ -1,3 +1,5 @@ +import os +import tempfile import unittest import warnings @@ -23,9 +25,50 @@ def setUp(self) -> None: ], model=self.model, ) + self._tmpdir = tempfile.TemporaryDirectory() + self.tmpdir = self._tmpdir.name + self._setup_raster_files() def tearDown(self) -> None: - pass + self._tmpdir.cleanup() + + def _setup_raster_files(self) -> None: + self.multi_band_values = np.array( + [ + [[1, 2], [3, 4]], + [[10, 20], [30, 40]], + ] + ) + self.single_band_values = np.array([[[9, 8], [7, 6]]]) + transform = rio.transform.from_bounds(-1, -1, 1, 1, 2, 2) + + self.multi_band_path = os.path.join(self.tmpdir, "multi_band.tif") + with rio.open( + self.multi_band_path, + "w", + driver="GTiff", + width=2, + height=2, + count=2, + dtype=self.multi_band_values.dtype, + crs="epsg:4326", + transform=transform, + ) as dataset: + dataset.write(self.multi_band_values) + + self.single_band_path = os.path.join(self.tmpdir, "single_band.tif") + with rio.open( + self.single_band_path, + "w", + driver="GTiff", + width=2, + height=2, + count=1, + dtype=self.single_band_values.dtype, + crs="epsg:4326", + transform=transform, + ) as dataset: + dataset.write(self.single_band_values) def test_apply_raster(self): raster_data = np.array([[[1, 2], [3, 4], [5, 6]]]) @@ -51,6 +94,98 @@ def test_apply_raster(self): with self.assertRaises(ValueError): self.raster_layer.apply_raster(np.empty((1, 100, 100))) + def test_apply_raster_single_band_attr_name_none(self): + raster_data = np.array([[[7, 8], [9, 10], [11, 12]]]) + self.raster_layer.apply_raster(raster_data) + + self.assertEqual(len(self.raster_layer.attributes), 1) + np.testing.assert_array_equal(self.raster_layer.get_raster(), raster_data) + + def test_apply_raster_single_band_attr_name_list(self): + raster_data = np.array([[[7, 8], [9, 10], [11, 12]]]) + self.raster_layer.apply_raster(raster_data, attr_name=["elevation"]) + + self.assertEqual(self.raster_layer.attributes, {"elevation"}) + np.testing.assert_array_equal( + self.raster_layer.get_raster(attr_name="elevation"), raster_data + ) + + def test_apply_raster_single_band_attr_name_list_mismatch(self): + raster_data = np.array([[[7, 8], [9, 10], [11, 12]]]) + with self.assertRaises(ValueError): + self.raster_layer.apply_raster( + raster_data, attr_name=["elevation", "water_level"] + ) + + def test_apply_raster_multiband_attr_name_list(self): + raster_data = np.array( + [ + [[1, 2], [3, 4], [5, 6]], + [[10, 20], [30, 40], [50, 60]], + ] + ) + self.raster_layer.apply_raster( + raster_data, attr_name=["elevation", "water_level"] + ) + + np.testing.assert_array_equal( + self.raster_layer.get_raster(attr_name="elevation"), raster_data[0:1] + ) + np.testing.assert_array_equal( + self.raster_layer.get_raster(attr_name="water_level"), raster_data[1:2] + ) + + def test_apply_raster_multiband_attr_name_none(self): + raster_data = np.array( + [ + [[1, 2], [3, 4], [5, 6]], + [[10, 20], [30, 40], [50, 60]], + ] + ) + self.raster_layer.apply_raster(raster_data) + + data = self.raster_layer.get_raster() + self.assertEqual(data.shape, raster_data.shape) + self.assertTrue( + any( + np.array_equal(data[idx], raster_data[0]) + for idx in range(data.shape[0]) + ) + ) + self.assertTrue( + any( + np.array_equal(data[idx], raster_data[1]) + for idx in range(data.shape[0]) + ) + ) + + def test_apply_raster_multiband_attr_name_string(self): + raster_data = np.array( + [ + [[1, 2], [3, 4], [5, 6]], + [[10, 20], [30, 40], [50, 60]], + ] + ) + self.raster_layer.apply_raster(raster_data, attr_name="band") + + self.assertEqual(self.raster_layer.attributes, {"band_1", "band_2"}) + np.testing.assert_array_equal( + self.raster_layer.get_raster(attr_name="band_1"), raster_data[0:1] + ) + np.testing.assert_array_equal( + self.raster_layer.get_raster(attr_name="band_2"), raster_data[1:2] + ) + + def test_apply_raster_multiband_attr_name_list_mismatch(self): + raster_data = np.array( + [ + [[1, 2], [3, 4], [5, 6]], + [[10, 20], [30, 40], [50, 60]], + ] + ) + with self.assertRaises(ValueError): + self.raster_layer.apply_raster(raster_data, attr_name=["only_one"]) + def test_get_raster(self): raster_data = np.array([[[1, 2], [3, 4], [5, 6]]]) self.raster_layer.apply_raster(raster_data, attr_name="val") @@ -71,14 +206,60 @@ def test_get_raster(self): ) self.raster_layer.apply_raster(raster_data) - # We expect 3 layers: val, elevation, and the new unnamed one. - # Since they are all identical raster_data, the order doesn't matter for equality check. + data = self.raster_layer.get_raster() + self.assertEqual(data.shape, (3, 3, 2)) + for band in data: + np.testing.assert_array_equal(band, raster_data[0]) + with self.assertRaises(ValueError): + self.raster_layer.get_raster("not_existing_attr") + + def test_get_raster_attr_name_list(self): + raster_data = np.array( + [ + [[1, 2], [3, 4], [5, 6]], + [[10, 20], [30, 40], [50, 60]], + ] + ) + self.raster_layer.apply_raster( + raster_data, attr_name=["elevation", "water_level"] + ) np.testing.assert_array_equal( - self.raster_layer.get_raster(), - np.concatenate((raster_data, raster_data, raster_data)), + self.raster_layer.get_raster(attr_name=["water_level", "elevation"]), + np.array([raster_data[1], raster_data[0]]), + ) + + def test_get_raster_attr_name_list_missing(self): + raster_data = np.array( + [ + [[1, 2], [3, 4], [5, 6]], + [[10, 20], [30, 40], [50, 60]], + ] + ) + self.raster_layer.apply_raster( + raster_data, attr_name=["elevation", "water_level"] ) with self.assertRaises(ValueError): - self.raster_layer.get_raster("not_existing_attr") + self.raster_layer.get_raster(attr_name=["elevation", "missing"]) + + def test_to_file_attr_name_list(self): + raster_data = np.array( + [ + [[1, 2], [3, 4], [5, 6]], + [[10, 20], [30, 40], [50, 60]], + ] + ) + self.raster_layer.apply_raster( + raster_data, attr_name=["elevation", "water_level"] + ) + + path = os.path.join(self.tmpdir, "selected_bands.tif") + self.raster_layer.to_file(path, attr_name=["water_level", "elevation"]) + + with rio.open(path, "r") as dataset: + values = dataset.read() + + np.testing.assert_array_equal(values[0], raster_data[1]) + np.testing.assert_array_equal(values[1], raster_data[0]) def test_get_min_cell(self): self.raster_layer.apply_raster( @@ -171,3 +352,85 @@ def test_cell_xy_updates_after_to_crs(self): self.assertEqual(transformed_cell.xy, expected_xy) self.assertEqual(self.raster_layer.cells[0][0].xy, original_xy) self.assertNotEqual(transformed_cell.xy, original_xy) + + def test_from_file_multiband_attr_name_list(self): + layer = mg.RasterLayer.from_file( + self.multi_band_path, self.model, attr_name=["band_1", "band_2"] + ) + + self.assertEqual(layer.attributes, {"band_1", "band_2"}) + np.testing.assert_array_equal( + layer.get_raster(attr_name="band_1"), self.multi_band_values[0:1] + ) + np.testing.assert_array_equal( + layer.get_raster(attr_name="band_2"), self.multi_band_values[1:2] + ) + + def test_from_file_multiband_attr_name_base(self): + layer = mg.RasterLayer.from_file( + self.multi_band_path, self.model, attr_name="band" + ) + + self.assertEqual(layer.attributes, {"band_1", "band_2"}) + np.testing.assert_array_equal( + layer.get_raster(attr_name="band_1"), self.multi_band_values[0:1] + ) + np.testing.assert_array_equal( + layer.get_raster(attr_name="band_2"), self.multi_band_values[1:2] + ) + + def test_from_file_multiband_attr_name_length_mismatch(self): + with self.assertRaises(ValueError): + mg.RasterLayer.from_file( + self.multi_band_path, self.model, attr_name=["only_one"] + ) + + def test_from_file_single_band_attr_name_list(self): + layer = mg.RasterLayer.from_file( + self.single_band_path, self.model, attr_name=["elevation"] + ) + + self.assertEqual(layer.attributes, {"elevation"}) + np.testing.assert_array_equal( + layer.get_raster(attr_name="elevation"), self.single_band_values + ) + + def test_from_file_single_band_attr_name_string(self): + layer = mg.RasterLayer.from_file( + self.single_band_path, self.model, attr_name="elevation" + ) + + self.assertEqual(layer.attributes, {"elevation"}) + np.testing.assert_array_equal( + layer.get_raster(attr_name="elevation"), self.single_band_values + ) + + def test_from_file_single_band_attr_name_none(self): + layer = mg.RasterLayer.from_file(self.single_band_path, self.model) + + self.assertEqual(len(layer.attributes), 1) + np.testing.assert_array_equal(layer.get_raster(), self.single_band_values) + + def test_from_file_single_band_attr_name_length_mismatch(self): + with self.assertRaises(ValueError): + mg.RasterLayer.from_file( + self.single_band_path, self.model, attr_name=["a", "b"] + ) + + def test_from_file_multiband_attr_name_none(self): + layer = mg.RasterLayer.from_file(self.multi_band_path, self.model) + + data = layer.get_raster() + self.assertEqual(data.shape, self.multi_band_values.shape) + self.assertTrue( + any( + np.array_equal(data[idx], self.multi_band_values[0]) + for idx in range(data.shape[0]) + ) + ) + self.assertTrue( + any( + np.array_equal(data[idx], self.multi_band_values[1]) + for idx in range(data.shape[0]) + ) + )