diff --git a/mapchete_eo/array/convert.py b/mapchete_eo/array/convert.py
index d6942e8d..26c76616 100644
--- a/mapchete_eo/array/convert.py
+++ b/mapchete_eo/array/convert.py
@@ -2,6 +2,7 @@
 import numpy as np
 import numpy.ma as ma
+from numpy.typing import DTypeLike
 import xarray as xr
 from mapchete.types import NodataVal
@@ -19,7 +20,9 @@ def to_masked_array(
-    xarr: Union[xr.Dataset, xr.DataArray], copy: bool = False
+    xarr: Union[xr.Dataset, xr.DataArray],
+    copy: bool = False,
+    out_dtype: Optional[DTypeLike] = None,
 ) -> ma.MaskedArray:
     """Convert xr.DataArray to ma.MaskedArray."""
     if isinstance(xarr, xr.Dataset):
@@ -31,6 +34,9 @@ def to_masked_array(
             "Cannot create masked_array because DataArray fill value is None"
         )
 
+    if out_dtype:
+        xarr = xarr.astype(out_dtype, copy=False)
+
     if xarr.dtype in _NUMPY_FLOAT_DTYPES:
         return ma.masked_values(xarr, fill_value, copy=copy, shrink=False)
     else:
diff --git a/mapchete_eo/base.py b/mapchete_eo/base.py
index aa3bb599..90100ea3 100644
--- a/mapchete_eo/base.py
+++ b/mapchete_eo/base.py
@@ -6,6 +6,7 @@
 import croniter
 from mapchete import Bounds
+import numpy as np
 import numpy.ma as ma
 import xarray as xr
 from dateutil.tz import tzutc
@@ -18,6 +19,8 @@
 from mapchete.types import MPathLike, NodataVal, NodataVals
 from pydantic import BaseModel
 from rasterio.enums import Resampling
+from rasterio.features import geometry_mask
+from shapely.geometry import mapping
 from shapely.geometry.base import BaseGeometry
 
 from mapchete_eo.archives.base import Archive
@@ -62,6 +65,7 @@ class EODataCube(base.InputTile):
     eo_bands: dict
     time: List[TimeRange]
     area: BaseGeometry
+    area_pixelbuffer: int = 0
 
     def __init__(
         self,
@@ -367,6 +371,29 @@ def default_read_values(
             nodatavals=nodatavals,
             merge_products_by=merge_products_by,
             merge_method=merge_method,
+            read_mask=self.get_read_mask(),
+        )
+
+    def get_read_mask(self) -> np.ndarray:
+        """
+        Determine the read mask according to the input area.
+
+        This generates a boolean numpy array in which pixels overlapping the
+        input area are set to True and thus will be filled by the read
+        function. Pixels outside of the area are not considered for reading.
+
+        On staged reading, i.e. first checking the product masks to assess
+        valid pixels, this avoids reading product bands in cases where the
+        product only covers pixels outside of the intended reading area.
+        """
+        area = self.area.buffer(self.area_pixelbuffer * self.tile.pixel_x_size)
+        if area.is_empty:
+            return np.zeros(self.tile.shape, dtype=bool)
+        return geometry_mask(
+            geometries=[mapping(area)],
+            out_shape=self.tile.shape,
+            transform=self.tile.transform,
+            invert=True,
         )
@@ -443,8 +470,9 @@ def _init_area(self, input_params: dict) -> BaseGeometry:
             input_params.get("delimiters", {}).get("bounds"),
             crs=getattr(input_params.get("pyramid"), "crs"),
         ),
+        raise_if_empty=False,
     )
-    return process_area.intersection(
+    process_area = process_area.intersection(
         reproject_geometry(
             configured_area,
             src_crs=configured_area_crs or self.crs,
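For reference, the masking behaviour `get_read_mask()` builds on can be reproduced in isolation. Below is a minimal sketch assuming a made-up tile transform and area polygon (both hypothetical, not taken from the codebase); `rasterio.features.geometry_mask()` with `invert=True` marks pixels overlapping the geometry as True:

```python
from affine import Affine
from rasterio.features import geometry_mask
from shapely.geometry import box, mapping

# hypothetical 256x256 tile covering roughly [16, 46, 16.1, 46.1]
out_shape = (256, 256)
transform = Affine(0.1 / 256, 0, 16.0, 0, -0.1 / 256, 46.1)

# hypothetical input area covering only the western half of the tile
area = box(16.0, 46.0, 16.05, 46.1)

# invert=True: True marks pixels overlapping the area, i.e. pixels to be read
read_mask = geometry_mask(
    geometries=[mapping(area)],
    out_shape=out_shape,
    transform=transform,
    invert=True,
)

# only part of the tile is flagged for reading
assert read_mask.any() and not read_mask.all()
```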
    | ``collection``     | The collection ID of an Item's collection.            |
    +--------------------+--------------------------------------------------------+
    """
-    if property in ["year", "month", "day", "date", "datetime"]:
-        if item.datetime is None:
+    if property == "id":
+        return item.id
+    elif property in ["year", "month", "day", "date", "datetime"]:
+        if item.datetime is None:  # pragma: no cover
             raise ValueError(
                 f"STAC item has no datetime attached, thus cannot get property {property}"
             )
diff --git a/mapchete_eo/io/levelled_cubes.py b/mapchete_eo/io/levelled_cubes.py
index 3c107c6a..082b99b8 100644
--- a/mapchete_eo/io/levelled_cubes.py
+++ b/mapchete_eo/io/levelled_cubes.py
@@ -40,27 +40,50 @@ def read_levelled_cube_to_np_array(
     raise_empty: bool = True,
     out_dtype: DTypeLike = np.uint16,
     out_fill_value: NodataVal = 0,
+    read_mask: Optional[np.ndarray] = None,
 ) -> ma.MaskedArray:
     """
     Read products as slices into a cube by filling up nodata gaps with next slice.
+
+    If a read_mask is provided, only the pixels marked True are considered to be read.
     """
-    if len(products) == 0:
+    if len(products) == 0:  # pragma: no cover
         raise NoSourceProducts("no products to read")
+
     bands = assets or eo_bands
-    if bands is None:
+    if bands is None:  # pragma: no cover
         raise ValueError("either assets or eo_bands have to be set")
-    out_shape = (target_height, len(bands), *grid.shape)
+
+    # 2D read_mask shape
+    if read_mask is None:
+        read_mask = np.ones(grid.shape, dtype=bool)
+    elif read_mask.ndim != 2:  # pragma: no cover
+        raise ValueError(
+            f"read_mask must be 2-dimensional, not {read_mask.ndim}-dimensional"
+        )
+
+    out_shape = (target_height, len(bands), *grid.shape)
     out: ma.MaskedArray = ma.masked_array(
-        data=np.zeros(out_shape, dtype=out_dtype),
-        mask=np.ones(out_shape, dtype=out_dtype),
+        data=np.full(out_shape, out_fill_value, dtype=out_dtype),
+        mask=np.ones(out_shape, dtype=bool),
         fill_value=out_fill_value,
     )
+
+    if not read_mask.any():
+        logger.debug("nothing to read")
+        return out
+
+    # extrude mask to match each layer
+    layer_read_mask = np.stack([read_mask for _ in bands])
+
+    def _cube_read_mask() -> np.ndarray:
+        # only needed for debug output, thus no need to always materialize it
+        return np.stack([layer_read_mask for _ in range(target_height)])
+
     logger.debug(
-        "empty cube with shape %s has %s",
+        "empty cube with shape %s has %s and %s pixels to be filled",
         out.shape,
         pretty_bytes(out.size * out.itemsize),
+        _cube_read_mask().sum(),
     )
 
     logger.debug("sort products into slices ...")
@@ -76,25 +99,25 @@ def read_levelled_cube_to_np_array(
     slices_read_count, slices_skip_count = 0, 0
 
     # pick slices one by one
-    for slice_count, slice in enumerate(slices, 1):
+    for slice_count, slice_ in enumerate(slices, 1):
         # all filled up? let's get outta here!
         if not out.mask.any():
-            logger.debug("cube is full, quitting!")
+            logger.debug("cube has no pixels to be filled, quitting!")
             break
 
         # generate 2D mask of holes to be filled in output cube
-        cube_nodata_mask = out.mask.any(axis=0).any(axis=0)
+        cube_nodata_mask = np.logical_and(out.mask.any(axis=0).any(axis=0), read_mask)
 
         # read slice
         try:
             logger.debug(
                 "see if slice %s %s has some of the %s unmasked pixels for cube",
                 slice_count,
-                slice,
+                slice_,
                 cube_nodata_mask.sum(),
             )
-            with slice.cached():
-                slice_array = slice.read(
+            with slice_.cached():
+                slice_array = slice_.read(
                     merge_method=merge_method,
                     product_read_kwargs=dict(
                         product_read_kwargs,
@@ -104,17 +127,18 @@ def read_levelled_cube_to_np_array(
                         resampling=resampling,
                         nodatavals=nodatavals,
                         raise_empty=raise_empty,
-                        target_mask=~cube_nodata_mask.copy(),
+                        read_mask=cube_nodata_mask.copy(),
+                        out_dtype=out_dtype,
                     ),
                 )
             slices_read_count += 1
         except (EmptySliceException, CorruptedSlice) as exc:
-            logger.debug("skipped slice %s: %s", slice, str(exc))
+            logger.debug("skipped slice %s: %s", slice_, str(exc))
             slices_skip_count += 1
             continue
 
         # if slice was not empty, fill pixels into cube
-        logger.debug("add slice %s array to cube", slice)
+        logger.debug("add slice %s array to cube", slice_)
 
         # iterate through layers of cube
         for layer_index in range(target_height):
@@ -124,34 +148,35 @@ def read_levelled_cube_to_np_array(
                 continue
 
             # determine empty patches of current layer
-            empty_patches = out[layer_index].mask.copy()
-            pixels_for_layer = (~slice_array[empty_patches].mask).sum()
+            empty_patches = np.logical_and(out[layer_index].mask, layer_read_mask)
+            remaining_pixels_for_layer = (~slice_array[empty_patches].mask).sum()
 
             # when slice has nothing to offer for this layer, skip
-            if pixels_for_layer == 0:
+            if remaining_pixels_for_layer == 0:
                 logger.debug(
                     "layer %s: slice has no pixels for this layer, jump to next",
                     layer_index,
                 )
                 continue
 
+            # insert slice data into empty patches of layer
             logger.debug(
                 "layer %s: fill with %s pixels ...",
                 layer_index,
-                pixels_for_layer,
+                remaining_pixels_for_layer,
             )
-            # insert slice data into empty patches of layer
             out[layer_index][empty_patches] = slice_array[empty_patches]
-            masked_pixels = out[layer_index].mask.sum()
-            total_pixels = out[layer_index].size
-            percent_full = round(
-                100 * ((total_pixels - masked_pixels) / total_pixels), 2
-            )
+
+            # report on layer fill status
             logger.debug(
-                "layer %s: %s%% filled (%s empty pixels remaining)",
+                "layer %s: %s",
                 layer_index,
-                percent_full,
-                out[layer_index].mask.sum(),
+                _percent_full(
+                    remaining=np.logical_and(
+                        out[layer_index].mask, layer_read_mask
+                    ).sum(),
+                    total=layer_read_mask.sum(),
+                ),
             )
 
             # remove slice values which were just inserted for next layer
@@ -161,13 +186,13 @@ def read_levelled_cube_to_np_array(
             logger.debug("slice fully inserted into cube, skipping")
             break
 
-    masked_pixels = out.mask.sum()
-    total_pixels = out.size
-    percent_full = round(100 * ((total_pixels - masked_pixels) / total_pixels), 2)
+    # report on cube fill status
     logger.debug(
-        "cube is %s%% filled (%s empty pixels remaining)",
-        percent_full,
-        masked_pixels,
+        "cube is %s",
+        _percent_full(
+            remaining=np.logical_and(out.mask, _cube_read_mask()).sum(),
+            total=_cube_read_mask().sum(),
+        ),
     )
 
     logger.debug(
@@ -197,6 +222,7 @@ def read_levelled_cube_to_xarray(
     band_axis_name: str = "bands",
     x_axis_name: str = "x",
     y_axis_name: str = "y",
+    read_mask: Optional[np.ndarray] = None,
 ) -> xr.Dataset:
     """
     Read products as slices into a cube by filling up nodata gaps with next slice.
@@ -218,6 +244,7 @@ def read_levelled_cube_to_xarray(
             sort=sort,
             product_read_kwargs=product_read_kwargs,
             raise_empty=raise_empty,
+            read_mask=read_mask,
         ),
         slice_names=[f"layer-{ii}" for ii in range(target_height)],
         band_names=variables,
@@ -226,3 +253,7 @@ def read_levelled_cube_to_xarray(
         x_axis_name=x_axis_name,
         y_axis_name=y_axis_name,
     )
+
+
+def _percent_full(remaining: int, total: int, ndigits: int = 2) -> str:
+    return f"{round(100 * (total - remaining) / total, ndigits=ndigits)}% full ({remaining} remaining empty pixels)"
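The per-slice fill logic above is hard to follow in diff form. Here is a self-contained toy version of the same pattern (all names and values made up), showing how `np.logical_and(out.mask, read_mask)` restricts which holes a slice may fill, and how each slice value is consumed by exactly one layer:

```python
import numpy as np
import numpy.ma as ma

target_height, shape = 3, (4, 4)
read_mask = np.zeros(shape, dtype=bool)
read_mask[:, 2:] = True  # only the right half may be filled

out = ma.masked_array(
    data=np.zeros((target_height, *shape), dtype="uint16"),
    mask=np.ones((target_height, *shape), dtype=bool),
)

# two fake slices; the second one is only partially valid
slices = [
    ma.masked_array(np.full(shape, 7, dtype="uint16"), mask=np.zeros(shape, dtype=bool)),
    ma.masked_array(np.full(shape, 9, dtype="uint16"), mask=np.eye(4, dtype=bool)),
]

for slice_array in slices:
    for layer_index in range(target_height):
        # holes of this layer, restricted to the area we are allowed to read
        empty_patches = np.logical_and(out[layer_index].mask, read_mask)
        if (~slice_array[empty_patches].mask).sum() == 0:
            continue
        out[layer_index][empty_patches] = slice_array[empty_patches]
        # a slice value is consumed by one layer only
        slice_array[empty_patches] = ma.masked
        if slice_array.mask.all():
            break

# pixels outside the read mask stay masked
assert out.mask[:, ~read_mask].all()
```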
diff --git a/mapchete_eo/io/products.py b/mapchete_eo/io/products.py
index cb8855c3..524b01e7 100644
--- a/mapchete_eo/io/products.py
+++ b/mapchete_eo/io/products.py
@@ -10,6 +10,7 @@
 from mapchete import Timer
 import numpy as np
 import numpy.ma as ma
+from numpy.typing import DTypeLike
 import xarray as xr
 from mapchete.config import get_hash
 from mapchete.geometry import to_shape
@@ -49,11 +50,13 @@ def products_to_np_array(
     sort: Optional[SortMethodConfig] = None,
     product_read_kwargs: dict = {},
     raise_empty: bool = True,
+    out_dtype: Optional[DTypeLike] = None,
+    read_mask: Optional[np.ndarray] = None,
 ) -> ma.MaskedArray:
     """Read grid window of EOProducts and merge into a 4D xarray."""
     return ma.stack(
         [
-            to_masked_array(s)
+            to_masked_array(s, out_dtype=out_dtype)
             for s in generate_slice_dataarrays(
                 products=products,
                 assets=assets,
@@ -66,6 +69,7 @@ def products_to_np_array(
                 sort=sort,
                 product_read_kwargs=product_read_kwargs,
                 raise_empty=raise_empty,
+                read_mask=read_mask,
             )
         ]
     )
@@ -87,6 +91,7 @@ def products_to_xarray(
     sort: Optional[SortMethodConfig] = None,
     raise_empty: bool = True,
     product_read_kwargs: dict = {},
+    read_mask: Optional[np.ndarray] = None,
 ) -> xr.Dataset:
     """Read grid window of EOProducts and merge into a 4D xarray."""
     data_vars = [
@@ -103,6 +108,7 @@ def products_to_xarray(
             sort=sort,
             product_read_kwargs=product_read_kwargs,
             raise_empty=raise_empty,
+            read_mask=read_mask,
         )
     ]
     if merge_products_by and merge_products_by not in ["date", "datetime"]:
@@ -322,8 +328,11 @@ def _generate_arrays(
     valid_arrays = [a for a in arrays if not ma.getmaskarray(a).all()]
 
     if valid_arrays:
-        stacked = ma.stack(valid_arrays, dtype=out.dtype)
-        out = stacked.mean(axis=0, dtype=out.dtype)
+        out_dtype = out.dtype
+        out_fill_value = out.fill_value
+        stacked = ma.stack(valid_arrays, dtype=out_dtype)
+        out = stacked.mean(axis=0, dtype=out_dtype).astype(out_dtype, copy=False)
+        out.set_fill_value(out_fill_value)
     else:
         # All arrays were fully masked: return fully masked output
         out = ma.masked_all(out.shape, dtype=out.dtype)
@@ -351,10 +360,12 @@ def generate_slice_dataarrays(
     sort: Optional[SortMethodConfig] = None,
     product_read_kwargs: dict = {},
     raise_empty: bool = True,
+    read_mask: Optional[np.ndarray] = None,
 ) -> Iterator[xr.DataArray]:
     """
     Yield products or merged products into slices as DataArrays.
     """
+
     if len(products) == 0:
         raise NoSourceProducts("no products to read")
@@ -396,6 +407,7 @@ def generate_slice_dataarrays(
                 resampling=resampling,
                 nodatavals=nodatavals,
                 raise_empty=raise_empty,
+                read_mask=read_mask,
             ),
             raise_empty=raise_empty,
         ),
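The reworked `_generate_arrays()` merge guards against a numpy.ma quirk: `mean()` promotes to float and derives a fresh default fill value, so dtype and fill value have to be restored explicitly. A standalone sketch of that pattern (input arrays are made up):

```python
import numpy as np
import numpy.ma as ma

a = ma.masked_array(np.array([10, 20, 0], dtype="uint16"), mask=[False, False, True])
b = ma.masked_array(np.array([30, 40, 0], dtype="uint16"), mask=[False, False, True])
out = ma.masked_array(np.zeros(3, dtype="uint16"), fill_value=0)

out_dtype = out.dtype
out_fill_value = out.fill_value

stacked = ma.stack([a, b], dtype=out_dtype)
# mean() would otherwise return float64 with a default fill value
out = stacked.mean(axis=0, dtype=out_dtype).astype(out_dtype, copy=False)
out.set_fill_value(out_fill_value)

assert out.dtype == np.uint16 and out.fill_value == 0
```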
diff --git a/mapchete_eo/platforms/sentinel2/metadata_parser.py b/mapchete_eo/platforms/sentinel2/metadata_parser.py
index 5a21198f..bf591fe6 100644
--- a/mapchete_eo/platforms/sentinel2/metadata_parser.py
+++ b/mapchete_eo/platforms/sentinel2/metadata_parser.py
@@ -161,10 +161,8 @@ def __repr__(self):
         return f""
 
     def clear_cached_data(self):
-        logger.debug("clear S2Metadata internal caches")
         self._cache = dict(viewing_incidence_angles=dict(), detector_footprints=dict())
         if self._cached_xml_root is not None:
-            logger.debug("clear S2Metadata xml cache")
             self._cached_xml_root.clear()
             self._cached_xml_root = None
         self.path_mapper.clear_cached_data()
diff --git a/mapchete_eo/platforms/sentinel2/product.py b/mapchete_eo/platforms/sentinel2/product.py
index f586d154..f9f88ba4 100644
--- a/mapchete_eo/platforms/sentinel2/product.py
+++ b/mapchete_eo/platforms/sentinel2/product.py
@@ -195,7 +195,6 @@ def __repr__(self):
         return f""
 
     def clear_cached_data(self):
-        logger.debug("clear S2Product caches")
         if self._metadata is not None:
             self._metadata.clear_cached_data()
             self._metadata = None
@@ -215,7 +214,7 @@ def read_np_array(
         mask_config: MaskConfig = MaskConfig(),
         brdf_config: Optional[BRDFConfig] = None,
         fill_value: int = 0,
-        target_mask: Optional[np.ndarray] = None,
+        read_mask: Optional[np.ndarray] = None,
         **kwargs,
     ) -> ma.MaskedArray:
         assets = assets or []
@@ -228,7 +227,9 @@ def read_np_array(
         count = len(assets)
         if isinstance(grid, Resolution):
             grid = self.metadata.grid(grid)
-        mask = self.get_mask(grid, mask_config, target_mask=target_mask).data
+        mask = self.get_mask(
+            grid, mask_config, target_mask=None if read_mask is None else ~read_mask
+        ).data
         if nodatavals is None:
             nodatavals = fill_value
         elif fill_value is None and nodatavals is not None:
@@ -464,13 +465,12 @@ def get_mask(
         if isinstance(grid, Resolution)
         else Grid.from_obj(grid)
     )
-
     if target_mask is None:
         target_mask = np.zeros(shape=grid.shape, dtype=bool)
     else:
         if target_mask.shape != grid.shape:
             raise ValueError("a target mask must have the same shape as the grid")
-        logger.debug("got custom target mask to start with: %s", target_mask)
+        logger.debug("got custom target mask to start with: %s", target_mask.shape)
 
     def _check_full(arr):
         # ATTENTION: target_mask and out have to be combined *after* mask was buffered!
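Note the convention flip above: `read_np_array()` now takes a read_mask (True = pixels to read) while `get_mask()` keeps its target_mask convention (True = pixels to ignore), hence the `~read_mask` inversion. A small sketch of why the inverted mask composes naturally with a product's nodata mask (the OR combination is an assumption for illustration; the actual combination in `get_mask()` also involves mask buffering):

```python
import numpy as np

# read_np_array() convention: True marks pixels that shall be read
read_mask = np.array([[True, True], [False, False]])

# get_mask() convention: True marks pixels to be ignored
target_mask = ~read_mask

# a hypothetical product nodata mask on the same grid
product_mask = np.array([[True, False], [False, True]])

# ignore everything that is either invalid or outside the read area
combined = product_mask | target_mask
assert (combined == np.array([[True, False], [True, True]])).all()
```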
""" + if len(products) == 0: raise NoSourceProducts("no products to read") @@ -396,6 +407,7 @@ def generate_slice_dataarrays( resampling=resampling, nodatavals=nodatavals, raise_empty=raise_empty, + read_mask=read_mask, ), raise_empty=raise_empty, ), diff --git a/mapchete_eo/platforms/sentinel2/metadata_parser.py b/mapchete_eo/platforms/sentinel2/metadata_parser.py index 5a21198f..bf591fe6 100644 --- a/mapchete_eo/platforms/sentinel2/metadata_parser.py +++ b/mapchete_eo/platforms/sentinel2/metadata_parser.py @@ -161,10 +161,8 @@ def __repr__(self): return f"" def clear_cached_data(self): - logger.debug("clear S2Metadata internal caches") self._cache = dict(viewing_incidence_angles=dict(), detector_footprints=dict()) if self._cached_xml_root is not None: - logger.debug("clear S2Metadata xml cache") self._cached_xml_root.clear() self._cached_xml_root = None self.path_mapper.clear_cached_data() diff --git a/mapchete_eo/platforms/sentinel2/product.py b/mapchete_eo/platforms/sentinel2/product.py index f586d154..f9f88ba4 100644 --- a/mapchete_eo/platforms/sentinel2/product.py +++ b/mapchete_eo/platforms/sentinel2/product.py @@ -195,7 +195,6 @@ def __repr__(self): return f"" def clear_cached_data(self): - logger.debug("clear S2Product caches") if self._metadata is not None: self._metadata.clear_cached_data() self._metadata = None @@ -215,7 +214,7 @@ def read_np_array( mask_config: MaskConfig = MaskConfig(), brdf_config: Optional[BRDFConfig] = None, fill_value: int = 0, - target_mask: Optional[np.ndarray] = None, + read_mask: Optional[np.ndarray] = None, **kwargs, ) -> ma.MaskedArray: assets = assets or [] @@ -228,7 +227,9 @@ def read_np_array( count = len(assets) if isinstance(grid, Resolution): grid = self.metadata.grid(grid) - mask = self.get_mask(grid, mask_config, target_mask=target_mask).data + mask = self.get_mask( + grid, mask_config, target_mask=None if read_mask is None else ~read_mask + ).data if nodatavals is None: nodatavals = fill_value elif fill_value is None and nodatavals is not None: @@ -464,13 +465,12 @@ def get_mask( if isinstance(grid, Resolution) else Grid.from_obj(grid) ) - if target_mask is None: target_mask = np.zeros(shape=grid.shape, dtype=bool) else: if target_mask.shape != grid.shape: raise ValueError("a target mask must have the same shape as the grid") - logger.debug("got custom target mask to start with: %s", target_mask) + logger.debug("got custom target mask to start with: %s", target_mask.shape) def _check_full(arr): # ATTENTION: target_mask and out have to be combined *after* mask was buffered! 
diff --git a/mapchete_eo/processes/merge_rasters.py b/mapchete_eo/processes/merge_rasters.py index 92e08481..35dfbbf0 100644 --- a/mapchete_eo/processes/merge_rasters.py +++ b/mapchete_eo/processes/merge_rasters.py @@ -181,15 +181,19 @@ def gradient_merge( # footprint coverage) # set 1 to 0: gradient_1band[gradient_1band == 1] = 0 - logger.debug(f"gradient_1band: {gradient_1band}") + logger.debug( + f"gradient_1band; min: {np.min(gradient_1band)}, max: {np.max(gradient_1band)}" + ) # extrude array to match number of raster bands gradient_8bit = np.stack([gradient_1band for _ in range(raster.shape[0])]) - logger.debug(f"gradient_8bit: {gradient_8bit}") + logger.debug( + f"gradient_8bit; min: {np.min(gradient_8bit)}, max: {np.max(gradient_8bit)}" + ) # scale gradient from 0 to 1 gradient = gradient_8bit / 255 - logger.debug(f"gradient: {gradient}") + logger.debug(f"gradient; min: {np.min(gradient)} , max: {np.max(gradient)}") # now only apply the gradient where out and raster have values # otherwise pick the remaining existing value or keep a masked diff --git a/mapchete_eo/product.py b/mapchete_eo/product.py index d43ebf37..5391cccd 100644 --- a/mapchete_eo/product.py +++ b/mapchete_eo/product.py @@ -113,7 +113,6 @@ def read_np_array( nodatavals: NodataVals = None, raise_empty: bool = True, apply_offset: bool = True, - apply_scale: bool = False, **kwargs, ) -> ma.MaskedArray: assets = assets or [] diff --git a/mapchete_eo/search/base.py b/mapchete_eo/search/base.py index 2378fbbb..65a93b12 100644 --- a/mapchete_eo/search/base.py +++ b/mapchete_eo/search/base.py @@ -1,3 +1,4 @@ +from functools import cached_property import json import logging from abc import ABC, abstractmethod @@ -48,13 +49,25 @@ class CatalogSearcher(ABC): This class serves as a bridge between an Archive and a catalog implementation. """ - eo_bands: List[str] - id: str - description: str - stac_extensions: List[str] collections: List[str] config_cls: Type[BaseModel] + @abstractmethod + @cached_property + def eo_bands(self) -> List[str]: ... + + @abstractmethod + @cached_property + def id(self) -> str: ... + + @abstractmethod + @cached_property + def description(self) -> str: ... + + @abstractmethod + @cached_property + def stac_extensions(self) -> List[str]: ... 
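The decorator order in `CatalogSearcher` is deliberate: `abstractmethod()` simply sets `__isabstractmethod__` on the object it wraps, so applying it on top of `cached_property` lets `ABCMeta` treat the attribute as abstract while subclasses satisfy the contract with concrete cached properties. A minimal self-contained sketch of the pattern:

```python
from abc import ABC, abstractmethod
from functools import cached_property


class Searcher(ABC):
    @abstractmethod
    @cached_property
    def id(self) -> str: ...


class ConcreteSearcher(Searcher):
    @cached_property
    def id(self) -> str:
        return "my-catalog"


# the abstract base cannot be instantiated ...
try:
    Searcher()
except TypeError:
    pass

# ... but the subclass can, and its property is computed once, then cached
assert ConcreteSearcher().id == "my-catalog"
```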
+    @abstractmethod
     def search(
         self,
@@ -66,10 +79,10 @@
 
 
 class StaticCatalogWriterMixin(CatalogSearcher):
-    client: Client
-    id: str
-    description: str
-    stac_extensions: List[str]
 
     @abstractmethod
     def get_collections(self) -> List[Collection]:  # pragma: no cover
diff --git a/mapchete_eo/search/stac_search.py b/mapchete_eo/search/stac_search.py
index 4e6877e5..8b782bf7 100644
--- a/mapchete_eo/search/stac_search.py
+++ b/mapchete_eo/search/stac_search.py
@@ -36,17 +36,34 @@ def __init__(
         stac_item_modifiers: Optional[List[Callable[[Item], Item]]] = None,
         endpoint: Optional[MPathLike] = None,
     ):
+        if endpoint is not None:
+            self.endpoint = endpoint
         if collections:
             self.collections = collections
         else:  # pragma: no cover
             raise ValueError("collections must be given")
-        self.client = Client.open(endpoint or self.endpoint)
-        self.id = self.client.id
-        self.description = self.client.description
-        self.stac_extensions = self.client.stac_extensions
-        self.eo_bands = self._eo_bands()
         self.stac_item_modifiers = stac_item_modifiers
 
+    @cached_property
+    def client(self) -> Client:
+        return Client.open(self.endpoint)
+
+    @cached_property
+    def eo_bands(self) -> List[str]:
+        return self._eo_bands()
+
+    @cached_property
+    def id(self) -> str:
+        return self.client.id
+
+    @cached_property
+    def description(self) -> str:
+        return self.client.description
+
+    @cached_property
+    def stac_extensions(self) -> List[str]:
+        return self.client.stac_extensions
+
     def search(
         self,
         time: Optional[Union[TimeRange, List[TimeRange]]] = None,
diff --git a/mapchete_eo/search/stac_static.py b/mapchete_eo/search/stac_static.py
index 32f97842..3fbc6ac0 100644
--- a/mapchete_eo/search/stac_static.py
+++ b/mapchete_eo/search/stac_static.py
@@ -1,3 +1,4 @@
+from functools import cached_property
 import logging
 import warnings
 from typing import Any, Callable, Dict, Generator, List, Optional, Union
@@ -37,13 +38,25 @@ def __init__(
         stac_item_modifiers: Optional[List[Callable[[Item], Item]]] = None,
     ):
         self.client = Client.from_file(str(baseurl), stac_io=FSSpecStacIO())
-        self.id = self.client.id
-        self.description = self.client.description
-        self.stac_extensions = self.client.stac_extensions
         self.collections = [c.id for c in self.client.get_children()]
-        self.eo_bands = self._eo_bands()
         self.stac_item_modifiers = stac_item_modifiers
 
+    @cached_property
+    def eo_bands(self) -> List[str]:
+        return self._eo_bands()
+
+    @cached_property
+    def id(self) -> str:
+        return self.client.id
+
+    @cached_property
+    def description(self) -> str:
+        return self.client.description
+
+    @cached_property
+    def stac_extensions(self) -> List[str]:
+        return self.client.stac_extensions
+
     def search(
         self,
         time: Optional[Union[TimeRange, List[TimeRange]]] = None,
diff --git a/mapchete_eo/search/utm_search.py b/mapchete_eo/search/utm_search.py
index 8b0b462d..69796b6f 100644
--- a/mapchete_eo/search/utm_search.py
+++ b/mapchete_eo/search/utm_search.py
@@ -1,4 +1,5 @@
 import datetime
+from functools import cached_property
 import logging
 from typing import Any, Callable, Dict, Generator, List, Optional, Set, Union
@@ -51,9 +52,12 @@ def __init__(
         if len(collections) == 0:  # pragma: no cover
             raise ValueError("no collections provided")
         self.collections = collections
-        self.eo_bands = self._eo_bands()
         self.stac_item_modifiers = stac_item_modifiers
 
+    @cached_property
+    def eo_bands(self) -> List[str]:  # pragma: no cover
+        return self._eo_bands()
+
     def search(
         self,
         time: Optional[Union[TimeRange, List[TimeRange]]] = None,
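Moving `Client.open()` out of `__init__` into a `cached_property` means constructing a searcher no longer touches the network; the connection happens at most once, on first attribute access. The effect, sketched with a stand-in class instead of pystac's `Client`:

```python
from functools import cached_property


class LazySearcher:
    def __init__(self, endpoint: str):
        self.endpoint = endpoint  # no connection is opened here

    @cached_property
    def client(self) -> str:
        # stands in for Client.open(self.endpoint); runs once, on first access
        print(f"opening {self.endpoint} ...")
        return f"client for {self.endpoint}"

    @cached_property
    def description(self) -> str:
        return f"catalog behind {self.client}"


searcher = LazySearcher("https://example.com/stac")
# nothing has been opened yet; the first access triggers it, later ones hit the cache
assert searcher.description == searcher.description
```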
diff --git a/mapchete_eo/sort.py b/mapchete_eo/sort.py
index 23eba217..f2a0e17c 100644
--- a/mapchete_eo/sort.py
+++ b/mapchete_eo/sort.py
@@ -5,7 +5,9 @@
 from typing import Callable, List, Optional
 
 from pydantic import BaseModel
+from pystac import Item
 
+from mapchete_eo.io.items import get_item_property
 from mapchete_eo.protocols import DateTimeProtocol
 from mapchete_eo.time import timedelta, to_datetime
 from mapchete_eo.types import DateTimeLike
@@ -22,7 +24,7 @@ def sort_objects_by_target_date(
     **kwargs,
 ) -> List[DateTimeProtocol]:
     """
-    Return sorted list of onjects according to their distance to the target_date.
+    Return sorted list of objects according to their distance to the target_date.
 
     Default for target date is the middle between the objects start date and end date.
     """
@@ -46,3 +48,17 @@ class TargetDateSort(SortMethodConfig):
     func: Callable = sort_objects_by_target_date
     target_date: Optional[DateTimeLike] = None
     reverse: bool = False
+
+
+def sort_objects_by_cloud_cover(
+    objects: List[Item], reverse: bool = False
+) -> List[Item]:
+    if len(objects) == 0:  # pragma: no cover
+        return objects
+    objects.sort(key=lambda x: get_item_property(x, "eo:cloud_cover"), reverse=reverse)
+    return objects
+
+
+class CloudCoverSort(SortMethodConfig):
+    func: Callable = sort_objects_by_cloud_cover
+    reverse: bool = False
diff --git a/tests/conftest.py b/tests/conftest.py
index d21ace05..dc619b62 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -763,7 +763,7 @@ def set_cdse_test_env(monkeypatch, request):
         monkeypatch.setenv("AWS_DEFAULT_REGION", "default")
         monkeypatch.delenv("AWS_REQUEST_PAYER", raising=False)
     else:
-        pytest.fail("CDSE AWS credentials not found in environment")
+        pytest.skip("CDSE AWS credentials not found in environment")
 
 
 @pytest.fixture
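`sort_objects_by_cloud_cover()` is an in-place `list.sort()` keyed on the `eo:cloud_cover` item property. The same pattern, shown with lightweight stand-in objects instead of pystac Items:

```python
from dataclasses import dataclass


@dataclass
class FakeItem:
    properties: dict

    def get_cloud_cover(self) -> float:
        # stands in for get_item_property(item, "eo:cloud_cover")
        return self.properties["eo:cloud_cover"]


items = [FakeItem({"eo:cloud_cover": c}) for c in (42.0, 3.5, 17.0)]
items.sort(key=lambda item: item.get_cloud_cover(), reverse=False)

assert [item.get_cloud_cover() for item in items] == [3.5, 17.0, 42.0]
```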
diff --git a/tests/platforms/sentinel2/test_product.py b/tests/platforms/sentinel2/test_product.py
index c6a9d730..7eb8f541 100644
--- a/tests/platforms/sentinel2/test_product.py
+++ b/tests/platforms/sentinel2/test_product.py
@@ -615,7 +615,19 @@ def test_read_levelled_cube_xarray(s2_stac_items, test_tile):
     assert isinstance(xarr, xr.Dataset)
 
 
-def test_read_levelled_cube_np_array(s2_stac_items, test_tile):
+@pytest.mark.parametrize(
+    "read_mask",
+    [
+        None,
+        np.zeros((256, 256), dtype=bool),
+        np.ones((256, 256), dtype=bool),
+        np.concatenate(
+            [np.zeros((128, 256), dtype=bool), np.ones((128, 256), dtype=bool)],
+            dtype=bool,
+        ),
+    ],
+)
+def test_read_levelled_cube_np_array(s2_stac_items, test_tile, read_mask):
     assets = ["red"]
     target_height = 5
     arr = read_levelled_cube_to_np_array(
@@ -630,12 +642,29 @@ def test_read_levelled_cube_np_array(s2_stac_items, test_tile):
                 cloud_probability_threshold=50,
             )
         ),
+        read_mask=read_mask,
     )
     assert isinstance(arr, ma.MaskedArray)
-    assert arr.any()
-    assert not arr.mask.all()
     assert arr.shape[0] == target_height
 
+    # no read_mask given or fully set to True
+    if read_mask is None or read_mask.all():
+        assert arr.any()
+        assert not arr.mask.all()
+
+    # read_mask full of False
+    elif not read_mask.any():
+        assert not arr.any()
+        assert arr.mask.all()
+
+    # mixed read_mask
+    else:
+        assert arr.any()
+        # flatten cube to 2D indicating where data was inserted
+        flattened_arr = arr.any(axis=0).any(axis=0)
+        # any area *not* indicated by read_mask *must* be empty
+        assert not flattened_arr[~read_mask].any()
+
+    # no better way to test this than to make sure the cube is filled from the bottom up
     layers = list(range(target_height))
     for lower, higher in zip(layers[:-1], layers[1:]):
diff --git a/tests/test_array.py b/tests/test_array.py
index 463efcb0..314d709d 100644
--- a/tests/test_array.py
+++ b/tests/test_array.py
@@ -101,10 +101,17 @@ def test_to_dataset_4d(test_4d_array):
         lazy_fixture("test_3d_array"),
     ],
 )
-def test_dataarray_to_masked_array(masked_array):
-    converted = to_masked_array(to_dataarray(masked_array))
+@pytest.mark.parametrize(
+    "out_dtype",
+    [None, "int8", "int16"],
+)
+def test_dataarray_to_masked_array(masked_array, out_dtype):
+    out_dtype = out_dtype or masked_array.dtype
+
+    converted = to_masked_array(to_dataarray(masked_array), out_dtype=out_dtype)
+
     assert converted.shape == masked_array.shape
-    assert converted.dtype == masked_array.dtype
+    assert converted.dtype == out_dtype
 
 
 @pytest.mark.parametrize(
diff --git a/tests/test_io.py b/tests/test_io.py
index ea4c98c3..ced03fbb 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -14,10 +14,10 @@
 )
 from mapchete_eo.io.products import Slice
 from mapchete_eo.product import EOProduct
-from mapchete_eo.sort import TargetDateSort
 
 
 def test_get_item_property_date(s2_stac_item):
+    assert get_item_property(s2_stac_item, "id") == s2_stac_item.id
     assert get_item_property(s2_stac_item, "day") == s2_stac_item.datetime.day
     assert get_item_property(s2_stac_item, "month") == s2_stac_item.datetime.month
     assert get_item_property(s2_stac_item, "year") == s2_stac_item.datetime.year
@@ -108,11 +108,8 @@ def test_products_to_slices(s2_stac_items):
 
 
 def test_products_to_slices_empty():
-    slices = products_to_slices([], group_by_property="day", sort=TargetDateSort())
-    for slice_ in slices:
-        assert len(slice_.products) > 1
-        for product in slice_.products:
-            assert slice_.name == product.item.datetime.day
+    slices = products_to_slices([])
+    assert not slices
 
 
 @pytest.mark.parametrize(
diff --git a/tests/test_sort.py b/tests/test_sort.py
new file mode 100644
index 00000000..eabc67df
--- /dev/null
+++ b/tests/test_sort.py
@@ -0,0 +1,24 @@
+import pytest
+
+from mapchete_eo.io import products_to_slices
+from mapchete_eo.product import EOProduct
+from mapchete_eo.sort import TargetDateSort, CloudCoverSort
+
+
+@pytest.mark.parametrize("sort_by", ["date", "cloud_cover"])
+@pytest.mark.parametrize("reverse", [True, False])
+def test_sort(s2_stac_items, sort_by, reverse):
+    if sort_by == "date":
+        sort_method = TargetDateSort(target_date="1970-01-01", reverse=reverse)
+        sort_property = "datetime"
+    elif sort_by == "cloud_cover":
+        sort_method = CloudCoverSort(reverse=reverse)
+        sort_property = "eo:cloud_cover"
+    slices = products_to_slices(
+        [EOProduct.from_stac_item(item) for item in s2_stac_items],
+        group_by_property="id",
+        sort=sort_method,
+    )
+    assert slices
+    sorted_properties = [slice_.get_property(sort_property) for slice_ in slices]
+    assert sorted_properties == sorted(sorted_properties, reverse=reverse)
diff --git a/tests/testdata/sentinel2_stac.mapchete b/tests/testdata/sentinel2_stac.mapchete
index 4531728c..29b5d11c 100644
--- a/tests/testdata/sentinel2_stac.mapchete
+++ b/tests/testdata/sentinel2_stac.mapchete
@@ -16,4 +16,3 @@ output:
 pyramid:
     grid: geodetic
 zoom_levels: 13
-bounds: [16, 46, 16.1, 46.1]
\ No newline at end of file