diff --git a/.gitignore b/.gitignore index 1ec3a75..eb4ff94 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,5 @@ venv.bak/ .mypy_cache/ .dmypy.json dmypy.json + +.vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fd8647..84d87f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,13 @@ HyBIG follows semantic versioning. All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased] + +### Changed + +* Refactored code be more memory efficient, reducing peak memory usage when processing large granules by about 90% at no performance hit. This is accomplished by fundamentally changing the way data is read, doing away with rioxarray and instead using rasterio directly. The DatasetReader allows for windowed reads into the source data, at the expense of having to reimplement the mask_and_scale functionality that is provided by rioxarray. +* Minor changes related to resolving tech debt and code quality improvements throughout. + ## [v2.5.1] - 2026-01-06 ### Changed diff --git a/docker/service_version.txt b/docker/service_version.txt index 73462a5..e70b452 100644 --- a/docker/service_version.txt +++ b/docker/service_version.txt @@ -1 +1 @@ -2.5.1 +2.6.0 diff --git a/hybig/browse.py b/hybig/browse.py index f596dd3..5f949ec 100644 --- a/hybig/browse.py +++ b/hybig/browse.py @@ -5,20 +5,17 @@ from logging import Logger, getLogger from pathlib import Path -import matplotlib import numpy as np import rasterio from affine import dumpsw from harmony_service_lib.message import Message as HarmonyMessage from harmony_service_lib.message import Source as HarmonySource -from matplotlib.colors import Normalize -from numpy import ndarray, uint8 +from matplotlib.colors import BoundaryNorm, Normalize +from numpy.typing import NDArray from osgeo_utils.auxiliary.color_palette import ColorPalette -from PIL import Image from rasterio.io import DatasetReader -from rasterio.warp import Resampling, reproject -from rioxarray import open_rasterio -from xarray import DataArray +from rasterio.warp import Resampling, reproject, transform_bounds +from rasterio.windows import Window from hybig.browse_utility import get_harmony_message_from_params from hybig.color_utility import ( @@ -26,7 +23,6 @@ OPAQUE, TRANSPARENT, ColorMap, - all_black_color_map, colormap_from_colors, get_color_palette, greyscale_colormap, @@ -44,9 +40,9 @@ def create_browse( source_tiff: str, - params: dict = None, + params: dict | None = None, palette: str | ColorPalette | None = None, - logger: Logger = None, + logger: Logger | None = None, ) -> list[tuple[Path, Path, Path]]: """Create browse imagery from an input geotiff. @@ -156,50 +152,31 @@ def create_browse_imagery( xml file. """ - output_driver = image_driver(message.format.mime) + output_driver = image_driver(message.format.mime) # type: ignore out_image_file = output_image_file(Path(input_file_path), driver=output_driver) out_world_file = output_world_file(Path(input_file_path), driver=output_driver) try: - with open_rasterio( - input_file_path, mode='r', mask_and_scale=True - ) as rio_in_array: - in_dataset = rio_in_array.rio._manager.acquire() - validate_file_type(in_dataset) - validate_file_crs(rio_in_array) - - if rio_in_array.rio.count == 1: - color_palette = get_color_palette( - in_dataset, source, item_color_palette - ) - if output_driver == 'JPEG': - # For JPEG output, convert to RGB - # color_map will be None - raster, color_map = convert_singleband_to_rgb( - rio_in_array, color_palette - ) - else: - # For PNG output, use palettized approach - raster, color_map = convert_singleband_to_raster( - rio_in_array, color_palette - ) - elif rio_in_array.rio.count in (3, 4): - raster = convert_mulitband_to_raster(rio_in_array) - color_map = None - if output_driver == 'JPEG': - raster = raster[0:3, :, :] - else: - raise HyBIGError( - f'incorrect number of bands for image: {rio_in_array.rio.count}' - ) + with rasterio.open(input_file_path) as src_ds: + validate_file_type(src_ds) + validate_file_crs(src_ds) + + band_count = src_ds.count + color_palette = None - grid_parameters = get_target_grid_parameters(message, rio_in_array) + if band_count == 1: + color_palette = get_color_palette(src_ds, source, item_color_palette) + elif band_count not in (3, 4): + raise HyBIGError(f'incorrect number of bands for image: {src_ds.count}') + + grid_parameters = get_target_grid_parameters(message, src_ds) grid_parameter_list, tile_locators = create_tiled_output_parameters( grid_parameters ) - processed_files = [] - for grid_parameters, tile_location in zip_longest( + # A list of (image_path, world_file_path, aux_xml_path) + processed_files: list[tuple[Path, Path, Path]] = [] + for grid_params, tile_location in zip_longest( grid_parameter_list, tile_locators ): tiled_out_image_file = get_tiled_filename(out_image_file, tile_location) @@ -207,15 +184,14 @@ def create_browse_imagery( tiled_out_aux_xml_file = get_aux_xml_filename(tiled_out_image_file) logger.info(f'out image file: {tiled_out_image_file}: {tile_location}') - write_georaster_as_browse( - rio_in_array, - raster, - color_map, - grid_parameters, - logger=logger, - driver=output_driver, - out_file_name=tiled_out_image_file, - out_world_name=tiled_out_world_file, + process_tile( + src_ds, + grid_params, + color_palette, + output_driver, + tiled_out_image_file, + tiled_out_world_file, + logger, ) processed_files.append( (tiled_out_image_file, tiled_out_world_file, tiled_out_aux_xml_file) @@ -227,7 +203,197 @@ def create_browse_imagery( return processed_files -def convert_mulitband_to_raster(data_array: DataArray) -> ndarray[uint8]: +def process_tile( + src_ds: DatasetReader, + grid_params: GridParams, + color_palette: ColorPalette | None, + output_driver: str, + out_file_name: Path, + out_world_name: Path, + logger: Logger, +) -> None: + """Read a region from the source dataset, convert raster, and write output.""" + band_count = src_ds.count + + src_window = calculate_source_window(src_ds, grid_params) + + # Tile is outside source bounds + if src_window is None: + return + + # Explicitly load a subset of ds + tile_source = read_window_with_mask_and_scale(src_ds, src_window) + src_crs = src_ds.crs + src_transform = src_ds.window_transform(src_window) + + dst_nodata: int | np.uint8 + + if band_count == 1: + if output_driver == 'JPEG': + raster, color_map = convert_singleband_to_rgb(tile_source, color_palette) + dst_nodata = TRANSPARENT # Not really used for JPEG + else: + raster, color_map, dst_nodata = convert_singleband_to_raster( + tile_source, color_palette + ) + else: + raster = convert_multiband_to_raster(tile_source) + color_map = None + dst_nodata = TRANSPARENT + if output_driver == 'JPEG': + raster = raster[0:3, :, :] + + write_georaster_as_browse( + raster, + src_crs, + src_transform, + color_map, + dst_nodata, + grid_params, + logger, + driver=output_driver, + out_file_name=out_file_name, + out_world_name=out_world_name, + ) + + # Explicit cleanup + del raster + del tile_source + + +def calculate_source_window( + src_ds: DatasetReader, + grid_params: GridParams, +) -> Window | None: + """Calculate the source window needed to cover the target tile. + + Returns a Window defining which portion of the source to read, + with some buffer for reprojection edge effects. + """ + try: + # Get target tile bounds in target CRS + dst_height = grid_params['height'] + dst_width = grid_params['width'] + dst_crs = grid_params['crs'] + dst_transform = grid_params['transform'] + + # Calculate tile bounds in destination CRS + dst_left = dst_transform.c + dst_top = dst_transform.f + dst_right = dst_left + dst_width * dst_transform.a + dst_bottom = dst_top + dst_height * dst_transform.e + + dst_bounds = ( + min(dst_left, dst_right), + min(dst_top, dst_bottom), + max(dst_left, dst_right), + max(dst_top, dst_bottom), + ) + + # Transform bounds to source CRS + src_crs = src_ds.crs + src_bounds = transform_bounds(dst_crs, src_crs, *dst_bounds) + + # Add buffer for reprojection (10% on each side) + width = src_bounds[2] - src_bounds[0] + height = src_bounds[3] - src_bounds[1] + buffer_x = width * 0.1 + buffer_y = height * 0.1 + + buffered_bounds = ( + src_bounds[0] - buffer_x, + src_bounds[1] - buffer_y, + src_bounds[2] + buffer_x, + src_bounds[3] + buffer_y, + ) + + # Convert to window in source pixel coordinates + src_transform = src_ds.transform + if len(src_ds.shape) == 3: + src_height, src_width = src_ds.shape[1], src_ds.shape[2] + else: + # Single band + src_height, src_width = src_ds.shape[0], src_ds.shape[1] + + # Inverse transform to get pixel coordinates from geographic coordinates + # For a point (x, y), the pixel coordinate is: + # col = (x - transform.c) / transform.a + # row = (y - transform.f) / transform.e + # Works like rasterio.windows.from_bounds but also handles positive y pixel size + # which is an edge case for some datasets like PODAAC's GHRSST MUR. + left, bottom, right, top = buffered_bounds + col_left = (left - src_transform.c) / src_transform.a + col_right = (right - src_transform.c) / src_transform.a + + row_top = (top - src_transform.f) / src_transform.e + row_bottom = (bottom - src_transform.f) / src_transform.e + + # Handle both positive and negative y scales + col_min = min(col_left, col_right) + col_max = max(col_left, col_right) + row_min = min(row_top, row_bottom) + row_max = max(row_top, row_bottom) + + # Convert to integer pixel bounds and clip to image extent + col_off = max(0, int(np.floor(col_min))) + row_off = max(0, int(np.floor(row_min))) + col_end = min(src_width, int(np.ceil(col_max))) + row_end = min(src_height, int(np.ceil(row_max))) + + win_width = col_end - col_off + win_height = row_end - row_off + + if win_width <= 0 or win_height <= 0: + return None + + return Window(col_off, row_off, win_width, win_height) # type: ignore + + except Exception: + # If calculation fails, return None + return None + + +def read_window_with_mask_and_scale( + src_ds: DatasetReader, + window: Window, + bands: list[int] | None = None, +) -> NDArray: + """Read a window from a rasterio dataset with masking and scaling applied. + + Replicates the behavior of rioxarray's mask_and_scale=True option. + """ + if bands is None: + bands = list(range(1, src_ds.count + 1)) + + data = src_ds.read(bands, window=window) + + # Convert to float for NaN support + data = data.astype('float64') + + # Apply masking and scaling per band + for i, band_idx in enumerate(bands): + band_data = data[i] # note that this passes by reference + + # Get nodata value for this band + nodata = src_ds.nodatavals[band_idx - 1] # nodatavals is 0-indexed + + # Mask nodata values + if nodata is not None: + mask = np.isnan(band_data) if np.isnan(nodata) else (band_data == nodata) + band_data[mask] = np.nan + + scale = (src_ds.scales or [None])[band_idx - 1] or 1.0 + offset = (src_ds.offsets or [None])[band_idx - 1] or 0.0 + + # Apply scale/offset to non-NaN values + if scale != 1.0 or offset != 0.0: + valid_mask = ~np.isnan(band_data) + band_data[valid_mask] = band_data[valid_mask] * scale + offset + + return data + + +def convert_multiband_to_raster(data_array: NDArray) -> NDArray[np.uint8]: """Convert multiband to a raster image. Return a 4-band raster, where the alpha layer is presumed to be the missing @@ -237,63 +403,43 @@ def convert_mulitband_to_raster(data_array: DataArray) -> ndarray[uint8]: any missing data in the RGB bands. """ - if data_array.rio.count not in [3, 4]: + if data_array.shape[0] not in [3, 4]: raise HyBIGError( - f'Cannot create image from {data_array.rio.count} band image. ' + f'Cannot create image from {data_array.shape[0]} band image. ' 'Expecting 3 or 4 bands.' ) - bands = data_array.to_numpy() - - if data_array.rio.count == 4: - return convert_to_uint8(bands, original_dtype(data_array)) + if data_array.shape[0] == 4: + return convert_to_uint8(data_array, str(data_array.dtype)) # Input NaNs in any of the RGB bands are made transparent. - nan_mask = np.isnan(bands).any(axis=0) + nan_mask = np.isnan(data_array).any(axis=0) nan_alpha = np.where(nan_mask, TRANSPARENT, OPAQUE) - raster = convert_to_uint8(bands, original_dtype(data_array)) + raster = convert_to_uint8(data_array, str(data_array.dtype)) return np.concatenate((raster, nan_alpha[None, ...]), axis=0) -def convert_to_uint8(bands: ndarray, dtype: str | None) -> ndarray[uint8]: - """Convert Banded data with NaNs (missing) into a uint8 data cube. - - Nearly all of the time this will simply pass through the data coercing it - back into unsigned ints and setting the missing values to 0 that will be - masked as transparent in the output png. +def convert_to_uint8(bands: NDArray, dtype: str | None) -> NDArray[np.uint8]: + """Convert banded data with NaNs (missing) into a uint8 data cube.""" + max_val = np.nanmax(bands) - There is a some small non-zero chance that the input RGB image was 16-bit - and if any of the values exceed 255, we must normalize all of input data to - the range 0-255. + # previously this used scaled.filled(0) which only works on masked arrays + if dtype != 'uint8' and max_val > 255: + min_val = np.nanmin(bands) + # Normalize to 0-255 range + with np.errstate(invalid='ignore'): # Suppress NaN warnings + scaled = (bands - min_val) / (max_val - min_val) * 255.0 + return np.nan_to_num(np.around(scaled), nan=0).astype('uint8') - """ - - if dtype != 'uint8' and np.nanmax(bands) > 255: - norm = Normalize(vmin=np.nanmin(bands), vmax=np.nanmax(bands)) - scaled = np.around(norm(bands) * 255.0) - raster = scaled.filled(0).astype('uint8') - else: - raster = np.nan_to_num(bands).astype('uint8') - - return raster - - -def original_dtype(data_array: DataArray) -> str | None: - """Return the original input data's type. - - rastero_open retains the input dtype in the encoding dictionary and is used - to understand what kind of casts are safe. - - """ - return data_array.encoding.get('dtype') or data_array.encoding.get('rasterio_dtype') + return np.nan_to_num(bands, nan=0).astype('uint8') def convert_singleband_to_raster( - data_array: DataArray, + data_array: NDArray, color_palette: ColorPalette | None = None, -) -> tuple[ndarray, ColorMap]: +) -> tuple[NDArray, ColorMap, np.uint8]: """Convert input dataset to a 1-band palettized image with colormap. Uses a palette if provided otherwise returns a greyscale image. @@ -303,26 +449,29 @@ def convert_singleband_to_raster( return scale_paletted_1band(data_array, color_palette) -def scale_grey_1band(data_array: DataArray) -> tuple[ndarray, ColorMap]: +def scale_grey_1band(data_array: NDArray) -> tuple[NDArray, ColorMap, np.uint8]: """Normalize input array and return scaled data with greyscale ColorMap.""" band = data_array[0, :, :] - norm = Normalize(vmin=np.nanmin(band), vmax=np.nanmax(band)) # Scale input data from 0 to 254 + norm = Normalize(vmin=np.nanmin(band), vmax=np.nanmax(band)) normalized_data = norm(band) * 254.0 - # Set any missing to missing - normalized_data[np.isnan(band)] = NODATA_IDX + # Set any missing (nan) to palette's NODATA_IDX + result = np.round(normalized_data) + result[np.isnan(band)] = NODATA_IDX - grey_colormap = greyscale_colormap() - raster_data = np.expand_dims(np.round(normalized_data).data, 0) - return np.array(raster_data, dtype='uint8'), grey_colormap + return ( + result.astype('uint8')[np.newaxis, :, :], + greyscale_colormap(), + np.uint8(NODATA_IDX), + ) def convert_singleband_to_rgb( - data_array: DataArray, + data_array: NDArray, color_palette: ColorPalette | None = None, -) -> tuple[ndarray, None]: +) -> tuple[NDArray, None]: """Convert input 1-band dataset to RGB image for JPEG output. Uses a palette if provided, otherwise returns a greyscale RGB image. @@ -333,39 +482,56 @@ def convert_singleband_to_rgb( return scale_paletted_1band_to_rgb(data_array, color_palette) -def scale_grey_1band_to_rgb(data_array: DataArray) -> tuple[ndarray, None]: +def scale_grey_1band_to_rgb(data_array: NDArray) -> tuple[NDArray, None]: """Normalize input array and return as 3-band RGB grayscale image.""" band = data_array[0, :, :] - norm = Normalize(vmin=np.nanmin(band), vmax=np.nanmax(band)) - # Scale input data from 0 to 254 (palettized data is 254-level quantized) + # Scale input data from 0 to 254. Note that this means nodata and + # the valid data min will occupy the same color level + norm = Normalize(vmin=np.nanmin(band), vmax=np.nanmax(band)) normalized_data = norm(band) * 254.0 - # Set any missing to 0 (black), no transparency + # Set any missing (nan) to 0 black normalized_data[np.isnan(band)] = 0 grey_data = np.round(normalized_data).astype('uint8') - rgb_data = np.stack([grey_data, grey_data, grey_data], axis=0) + return np.stack([grey_data, grey_data, grey_data], axis=0), None - return rgb_data, None +def prepare_palette_colors( + palette: ColorPalette, with_alpha: bool = True +) -> tuple[list[tuple], tuple, int | None]: + """Extract colors and nodata handling from a palette. -def scale_paletted_1band_to_rgb( - data_array: DataArray, palette: ColorPalette -) -> tuple[ndarray, None]: - """Scale a 1-band image with palette into RGB image for JPEG output.""" - band = data_array[0, :, :] - levels = list(palette.pal.keys()) + Returns: + Tuple of (colors_list, nodata_color, nodata_index_or_none) + """ colors = [ - palette.color_to_color_entry(value, with_alpha=True) + palette.color_to_color_entry(value, with_alpha=with_alpha) for value in palette.pal.values() ] - norm = matplotlib.colors.BoundaryNorm(levels, len(levels) - 1) - # handle palette no data value - nodata_color = (0, 0, 0) + nodata_color = (0, 0, 0, 0) if with_alpha else (0, 0, 0) + nodata_index = None + if palette.ndv is not None: - nodata_color = palette.color_to_color_entry(palette.ndv, with_alpha=True) + nodata_color = palette.color_to_color_entry(palette.ndv, with_alpha=with_alpha) + if palette.ndv in palette.pal.values(): + nodata_index = list(palette.pal.values()).index(palette.ndv) + + return colors, nodata_color, nodata_index + + +def scale_paletted_1band_to_rgb( + data_array: NDArray, palette: ColorPalette +) -> tuple[NDArray, None]: + """Scale a 1-band image with palette into RGB image for JPEG output.""" + band = data_array[0, :, :] + levels = list(palette.pal.keys()) + colors, nodata_color, _ = prepare_palette_colors(palette, with_alpha=False) + colors_array = np.array(colors, dtype='uint8') + + norm = BoundaryNorm(levels, len(levels) - 1) # Store NaN mask before normalization nan_mask = np.isnan(band) @@ -373,94 +539,85 @@ def scale_paletted_1band_to_rgb( # Replace NaN with first level to avoid issues during normalization band_clean = np.where(nan_mask, levels[0], band) - # Apply normalization to get palette indices - indexed_band = norm(band_clean) - - # Clip indices to valid range [0, len(colors)-1] - indexed_band = np.clip(indexed_band, 0, len(colors) - 1) + # Get palette indices and clip to valid range + indexed_band = np.clip(norm(band_clean), 0, len(colors) - 1).astype(int) - # Create RGB output array - height, width = band.shape - rgb_array = np.zeros((3, height, width), dtype='uint8') + # Vectorized color lookup + rgb_array = colors_array[indexed_band].transpose(2, 0, 1) - # Apply colors based on palette indices - for i, color in enumerate(colors): - mask = indexed_band == i - rgb_array[0, mask] = color[0] # Red - rgb_array[1, mask] = color[1] # Green - rgb_array[2, mask] = color[2] # Blue - - # Handle NaN/nodata values (overwrite any color assignment) + # Handle nodata (overwrite any color assignment) if nan_mask.any(): rgb_array[0, nan_mask] = nodata_color[0] rgb_array[1, nan_mask] = nodata_color[1] rgb_array[2, nan_mask] = nodata_color[2] - return rgb_array, None + return np.ascontiguousarray(rgb_array), None def scale_paletted_1band( - data_array: DataArray, palette: ColorPalette -) -> tuple[ndarray, ColorMap]: + data_array: NDArray, palette: ColorPalette +) -> tuple[NDArray, ColorMap, np.uint8]: """Scale a 1-band image with palette into modified image and associated color_map. Use the palette's levels and values, transform the input data_array into - the correct levels indexed from 0-255 return the scaled array along side of + the correct levels indexed from 0-255 return the scaled array alongside a colormap corresponding to the new levels. Values below the minimum palette level are clipped to the lowest color. Values above the maximum palette level are clipped to the highest color. Only NaN values are mapped to the nodata index. + + Returns: + Tuple of (raster_data, color_map, nodata_index) """ - global DST_NODATA band = data_array[0, :, :] levels = list(palette.pal.keys()) - colors = [ - palette.color_to_color_entry(value, with_alpha=True) - for value in palette.pal.values() - ] - norm = matplotlib.colors.BoundaryNorm(levels, len(levels) - 1) + colors, nodata_color, existing_nodata_idx = prepare_palette_colors( + palette, with_alpha=True + ) - # handle palette no data value - nodata_color = (0, 0, 0, 0) - if palette.ndv is not None: - nodata_color = palette.color_to_color_entry(palette.ndv, with_alpha=True) - # Check if nodata color already exists in palette - if palette.ndv in palette.pal.values(): - DST_NODATA = list(palette.pal.values()).index(palette.ndv) - # Don't add nodata_color; it's already in colors - else: - # Nodata not in palette, add it at the beginning - DST_NODATA = 0 - colors = [nodata_color, *colors] + # Determine where nodata sits in the final colormap + if existing_nodata_idx is not None: + # Don't add nodata_color; it's already in colors + dst_nodata = np.uint8(existing_nodata_idx) + elif palette.ndv is not None: + # Nodata not in palette, add it at the beginning + dst_nodata = np.uint8(0) + colors = [nodata_color, *colors] else: # if there is no ndv, add one to the end of the colormap - DST_NODATA = len(colors) + dst_nodata = np.uint8(len(colors)) colors = [*colors, nodata_color] + norm = BoundaryNorm(levels, len(levels) - 1) + nan_mask = np.isnan(band) - band_clean = np.where(nan_mask, levels[0], band) + if band.flags.writeable: + band[nan_mask] = levels[0] + band_clean = band + else: + band_clean = np.where(nan_mask, levels[0], band) scaled_band = norm(band_clean) - if DST_NODATA == 0: + # Apply offset and clip to valid palette range + if dst_nodata == 0: # boundary norm indexes [0, levels) by default, so if the NODATA index is 0, # all the palette indices need to be incremented by 1. scaled_band = scaled_band + 1 - - # Clip to valid palette range (excluding nodata index) - if DST_NODATA == 0: - # Palette occupies indices 1 to len(colors)-1 - scaled_band = np.clip(scaled_band, 1, len(colors) - 1) + np.clip(scaled_band, 1, len(colors) - 1, out=scaled_band) else: - # Palette occupies indices 0 to DST_NODATA-1 - scaled_band = np.clip(scaled_band, 0, DST_NODATA - 1) + # Palette occupies indices 0 to dst_nodata-1 + np.clip(scaled_band, 0, dst_nodata - 1, out=scaled_band) # Only set NaN values to nodata index - scaled_band[nan_mask] = DST_NODATA + scaled_band[nan_mask] = dst_nodata + + del nan_mask + del band_clean color_map = colormap_from_colors(colors) - raster_data = np.expand_dims(scaled_band.data, 0) - return np.array(raster_data, dtype='uint8'), color_map + raster = scaled_band.data.astype('uint8')[np.newaxis, :, :] + return raster, color_map, dst_nodata def image_driver(mime: str) -> str: @@ -470,26 +627,12 @@ def image_driver(mime: str) -> str: return 'PNG' -def get_color_map_from_image(image: Image) -> dict: - """Get a writable color map - - Read the RGBA palette from a PIL Image and covert into a dictionary - that can be written by rasterio. - - """ - color_tuples = np.array(image.getpalette(rawmode='RGBA')).reshape(-1, 4) - color_map = all_black_color_map() - for idx, color_tuple in enumerate(color_tuples): - color_map[idx] = tuple(color_tuple) - return color_map - - def get_aux_xml_filename(image_filename: Path) -> Path: """Get aux.xml filenames.""" return image_filename.with_suffix(image_filename.suffix + '.aux.xml') -def get_tiled_filename(input_file: Path, locator: dict | None = None) -> Path: +def get_tiled_filename(input_file: Path, locator: dict[str, int] | None = None) -> Path: """Add a column, row identifier to the output files. Only update if there is a valid locator dict. @@ -519,26 +662,26 @@ def output_world_file(input_file_path: Path, driver: str = 'PNG'): return input_file_path.with_suffix(ext) -def validate_file_crs(data_array: DataArray) -> None: +def validate_file_crs(src_ds: DatasetReader) -> None: """Explicit check for a CRS on the input geotiff. Raises HyBIGError if crs is missing. """ - if data_array.rio.crs is None: + if src_ds.crs is None: raise HyBIGError('Input geotiff must have defined CRS.') -def validate_file_type(dsr: DatasetReader) -> None: +def validate_file_type(src_ds: DatasetReader) -> None: """Ensure we can work with the input data file. Raise an exception if this file is unusable by the service. """ - if dsr.driver != 'GTiff': - raise HyBIGError(f'Input file type not supported: {dsr.driver}') + if src_ds.driver != 'GTiff': + raise HyBIGError(f'Input file type not supported: {src_ds.driver}') -def get_destination(grid_parameters: GridParams, n_bands: int) -> ndarray: +def get_destination(grid_parameters: GridParams, n_bands: int) -> NDArray: """Initialize an array for writing an output raster.""" return np.zeros( (n_bands, grid_parameters['height'], grid_parameters['width']), dtype='uint8' @@ -546,14 +689,16 @@ def get_destination(grid_parameters: GridParams, n_bands: int) -> ndarray: def write_georaster_as_browse( - data_array: DataArray, - raster: ndarray, + raster: NDArray, + src_crs: rasterio.CRS, + src_transform: rasterio.Affine, color_map: dict | None, + dst_nodata: int | np.uint8, grid_parameters: GridParams, - driver='PNG', - out_file_name='outfile.png', - out_world_name='outfile.pgw', - logger=Logger, + logger: Logger, + driver: str = 'PNG', + out_file_name: str | Path = 'outfile.png', + out_world_name: str | Path = 'outfile.pgw', ) -> None: """Write raster data to output file. @@ -564,14 +709,6 @@ def write_georaster_as_browse( """ n_bands = raster.shape[0] - if color_map is not None: - # DST_NODATA is a global that was set earlier in scale_grey_1band or - # scale_paletted_1band - dst_nodata = DST_NODATA - else: - # for banded data set each band's destination nodata to zero (TRANSPARENT). - dst_nodata = int(TRANSPARENT) - creation_options = { **grid_parameters, 'driver': driver, @@ -579,7 +716,7 @@ def write_georaster_as_browse( 'count': n_bands, } - dest_array = get_destination(grid_parameters, n_bands) + dst_array = get_destination(grid_parameters, n_bands) logger.info(f'Create output image with options: {creation_options}') @@ -587,16 +724,16 @@ def write_georaster_as_browse( for dim in range(0, n_bands): reproject( source=raster[dim, :, :], - destination=dest_array[dim, :, :], - src_transform=data_array.rio.transform(), - src_crs=data_array.rio.crs, + destination=dst_array[dim, :, :], + src_transform=src_transform, + src_crs=src_crs, dst_transform=grid_parameters['transform'], dst_crs=grid_parameters['crs'], - dst_nodata=dst_nodata, + dst_nodata=int(dst_nodata), resampling=Resampling.nearest, ) - dst_raster.write(dest_array) + dst_raster.write(dst_array) if color_map is not None: dst_raster.write_colormap(1, color_map) diff --git a/hybig/color_utility.py b/hybig/color_utility.py index 47981fd..33b5bd8 100644 --- a/hybig/color_utility.py +++ b/hybig/color_utility.py @@ -18,15 +18,16 @@ HyBIGNoColorInformation, ) -ColorMap = dict[uint8, tuple[uint8, uint8, uint8, uint8]] +# Can be tuple[uint8 * 4] for rgba or tuple[uint8 * 3] for rgb +ColorMap = dict[uint8, tuple] # Constants for output PNG images # Applied to transparent pixels where alpha < 255 TRANSPARENT = uint8(0) OPAQUE = uint8(255) # Applied to off grid areas during reprojection -NODATA_RGBA = (0, 0, 0, 0) -NODATA_IDX = 255 +NODATA_RGBA = (uint8(0), uint8(0), uint8(0), TRANSPARENT) +NODATA_IDX = OPAQUE def remove_alpha(raster: np.ndarray) -> tuple[np.ndarray, np.ndarray | None]: @@ -68,8 +69,8 @@ def get_color_palette_from_item(item: Item) -> ColorPalette | None: def get_color_palette( - dataset: DatasetReader, - source: HarmonySource = None, + src_ds: DatasetReader, + source: HarmonySource, item_color_palette: ColorPalette | None = None, ) -> ColorPalette | None: """Get a color palette for the single band image @@ -90,9 +91,9 @@ def get_color_palette( return get_remote_palette_from_source(source) except HyBIGNoColorInformation: try: - ds_cmap = dataset.colormap(1) + ds_cmap = src_ds.colormap(1) # very defensive since this function is not documented in rasterio - ndv_tuple: tuple[float, ...] = dataset.get_nodatavals() + ndv_tuple: tuple[float, ...] = src_ds.get_nodatavals() if ndv_tuple is not None and len(ndv_tuple) > 0: # this service only supports one ndv, so just use the first one # (usually the only one) @@ -104,17 +105,18 @@ def get_color_palette( return None -def get_remote_palette_from_source(source: HarmonySource) -> dict: +def get_remote_palette_from_source(source: HarmonySource) -> ColorPalette: """Get a colormap from a remote url Checks the HarmonySource object for a URL to download a color map for the input raster. """ + remote_colortable_url = '' try: - if len(source.variables) != 1: + if len(source.variables) != 1: # type: ignore raise TypeError('Palette must come from a single variable') - variable = source.variables[0] + variable = source.variables[0] # type: ignore remote_colortable_url = next( r_url.url for r_url in variable.relatedUrls @@ -134,22 +136,24 @@ def get_remote_palette_from_source(source: HarmonySource) -> dict: def all_black_color_map() -> ColorMap: """Return a full length rgba color map with all black values.""" - return {idx: (0, 0, 0, 255) for idx in range(256)} + return {uint8(idx): (uint8(0), uint8(0), uint8(0), OPAQUE) for idx in range(256)} def colormap_from_colors( - colors: list[tuple[uint8, uint8, uint8, uint8]], + colors: list[tuple[int, int, int, int] | tuple[int, int, int]], ) -> ColorMap: + """Return a ColorMap object from a list of colors read from a color map.""" color_map = {} for idx, rgba in enumerate(colors): - color_map[idx] = rgba + color_map[uint8(idx)] = rgba return color_map def greyscale_colormap() -> ColorMap: + """Return a simple greyscale ColorMap.""" color_map = {} for idx in range(255): - color_map[idx] = (idx, idx, idx, 255) + color_map[uint8(idx)] = (uint8(idx), uint8(idx), uint8(idx), OPAQUE) color_map[NODATA_IDX] = NODATA_RGBA return color_map diff --git a/hybig/crs.py b/hybig/crs.py index 8343056..b9fa986 100644 --- a/hybig/crs.py +++ b/hybig/crs.py @@ -10,12 +10,10 @@ """ +import rasterio from harmony_service_lib.message import SRS -from pyproj.crs import CRS as pyCRS - -# pylint: disable-next=no-name-in-module -from rasterio.crs import CRS -from xarray import DataArray +from pyproj import CRS as pyCRS +from rasterio.io import DatasetReader from hybig.exceptions import HyBIGValueError @@ -29,7 +27,7 @@ } -def choose_target_crs(srs: SRS, data_array: DataArray) -> CRS: +def choose_target_crs(srs: SRS | None, src_ds: DatasetReader) -> rasterio.CRS: """Return the target CRS for the output image. If a harmony message defines a SRS, we use that as the target ouptut CRS. @@ -39,7 +37,7 @@ def choose_target_crs(srs: SRS, data_array: DataArray) -> CRS: """ if srs is not None: return choose_crs_from_srs(srs) - return choose_crs_from_metadata(data_array) + return choose_crs_from_metadata(src_ds) def choose_crs_from_srs(srs: SRS): @@ -54,30 +52,32 @@ def choose_crs_from_srs(srs: SRS): """ try: - if srs.epsg is not None and srs.epsg != '': - return CRS.from_string(srs.epsg) - if srs.wkt is not None and srs.wkt != '': - return CRS.from_string(srs.wkt) - return CRS.from_string(srs.proj4) + # harmony defines properties for classes in a way that type checkers + # can't pick up on, so we use type: ignore to suppress it + if srs.epsg is not None and srs.epsg != '': # type: ignore + return rasterio.CRS.from_string(srs.epsg) # type: ignore + if srs.wkt is not None and srs.wkt != '': # type: ignore + return rasterio.CRS.from_string(srs.wkt) # type: ignore + return rasterio.CRS.from_string(srs.proj4) # type: ignore except Exception as exception: raise HyBIGValueError(f'Bad input SRS: {str(exception)}') from exception -def is_preferred_crs(crs: CRS) -> bool: - """Returns true if the input CRS is preferred by GIBS.""" +def is_preferred_crs(crs: rasterio.CRS) -> bool: + """Returns true if the input rasterio.CRS is preferred by GIBS.""" if crs.to_string() in PREFERRED_CRS.values(): return True return False -def choose_crs_from_metadata(data_array: DataArray) -> CRS | None: +def choose_crs_from_metadata(src_ds: DatasetReader) -> rasterio.CRS | None: """Determine the best CRS based on input metadata.""" - if is_preferred_crs(data_array.rio.crs): - return data_array.rio.crs - return choose_best_crs_from_metadata(data_array.rio.crs) + if is_preferred_crs(src_ds.crs): + return src_ds.crs + return choose_best_crs_from_metadata(src_ds.crs) -def choose_best_crs_from_metadata(crs: CRS) -> CRS: +def choose_best_crs_from_metadata(crs: rasterio.CRS) -> rasterio.CRS: """Determine the best preferred CRS based on the input CRS. We are targeting GIBS which has three preferred CRSs a Northern Polar @@ -103,12 +103,12 @@ def choose_best_crs_from_metadata(crs: CRS) -> CRS: projection_params = pyCRS(crs).to_dict() if projection_params.get('proj', None) == 'longlat': - return CRS.from_string(PREFERRED_CRS['global']) + return rasterio.CRS.from_string(PREFERRED_CRS['global']) if projection_params.get('lat_0', 0.0) >= 80: - return CRS.from_string(PREFERRED_CRS['north']) + return rasterio.CRS.from_string(PREFERRED_CRS['north']) if projection_params.get('lat_0', 0.0) <= -80: - return CRS.from_string(PREFERRED_CRS['south']) + return rasterio.CRS.from_string(PREFERRED_CRS['south']) - return CRS.from_string(PREFERRED_CRS['global']) + return rasterio.CRS.from_string(PREFERRED_CRS['global']) diff --git a/hybig/sizes.py b/hybig/sizes.py index 603fa24..1c68321 100644 --- a/hybig/sizes.py +++ b/hybig/sizes.py @@ -12,6 +12,7 @@ from typing import TypedDict import numpy as np +import rasterio from affine import Affine from harmony_service_lib.message import Message from harmony_service_lib.message_utility import ( @@ -19,11 +20,9 @@ has_scale_extents, has_scale_sizes, ) - -# pylint: disable-next=no-name-in-module -from rasterio.crs import CRS +from rasterio import DatasetReader from rasterio.transform import AffineTransformer, from_bounds, from_origin -from xarray import DataArray +from rasterio.warp import transform_bounds from hybig.crs import ( choose_target_crs, @@ -35,7 +34,7 @@ class GridParams(TypedDict): height: int width: int - crs: CRS + crs: rasterio.CRS transform: Affine @@ -109,7 +108,7 @@ class Dimensions(TypedDict): epsg_3031_resolutions = epsg_3413_resolutions -def get_target_grid_parameters(message: Message, data_array: DataArray) -> GridParams: +def get_target_grid_parameters(message: Message, src_ds: DatasetReader) -> GridParams: """Get the output image parameters. This computes the target grid of the ouptut image. The grid is defined by @@ -122,16 +121,16 @@ def get_target_grid_parameters(message: Message, data_array: DataArray) -> GridP - Computed parameters attempt to generate GIBS suitable images. """ - target_crs = choose_target_crs(message.format.srs, data_array) - target_scale_extent = choose_scale_extent(message, target_crs, data_array) + target_crs = choose_target_crs(message.format.srs, src_ds) + target_scale_extent = choose_scale_extent(message, target_crs, src_ds) target_dimensions = choose_target_dimensions( - message, data_array, target_scale_extent, target_crs + message, src_ds, target_scale_extent, target_crs ) return get_rasterio_parameters(target_crs, target_scale_extent, target_dimensions) def choose_scale_extent( - message: Message, target_crs: CRS, data_array: DataArray + message: Message, dst_crs: rasterio.CRS, src_ds: DatasetReader ) -> ScaleExtent: """Return the scaleExtent for the target image. @@ -145,26 +144,34 @@ def choose_scale_extent( # These values must be in the target_crs projection. scale_extent = ScaleExtent( { - 'xmin': message.format.scaleExtent.x.min, - 'ymin': message.format.scaleExtent.y.min, - 'xmax': message.format.scaleExtent.x.max, - 'ymax': message.format.scaleExtent.y.max, + 'xmin': message.format.scaleExtent.x.min, # type: ignore + 'ymin': message.format.scaleExtent.y.min, # type: ignore + 'xmax': message.format.scaleExtent.x.max, # type: ignore + 'ymax': message.format.scaleExtent.y.max, # type: ignore } ) else: - left, bottom, right, top = data_array.rio.transform_bounds(target_crs) + left, bottom, right, top = transform_bounds(src_ds.crs, dst_crs, *src_ds.bounds) # Correct for antimeridian crossing. if left > right: right = right + 360 scale_extent = ScaleExtent( - {'xmin': left, 'ymin': bottom, 'xmax': right, 'ymax': top} + { + 'xmin': left, + 'ymin': min(bottom, top), + 'xmax': right, + 'ymax': max(bottom, top), + } ) return scale_extent def choose_target_dimensions( - message: Message, data_array: DataArray, scale_extent: ScaleExtent, target_crs: CRS + message: Message, + src_ds: DatasetReader, + scale_extent: ScaleExtent, + target_crs: rasterio.CRS, ) -> Dimensions: """This selects or computes the target Dimensions. @@ -186,20 +193,22 @@ def choose_target_dimensions( """ if has_dimensions(message): dimensions = Dimensions( - {'height': message.format.height, 'width': message.format.width} + {'height': message.format.height, 'width': message.format.width} # type: ignore ) elif has_scale_sizes(message): dimensions = compute_target_dimensions( - scale_extent, message.format.scaleSize.x, message.format.scaleSize.y + scale_extent, + message.format.scaleSize.x, # type: ignore + message.format.scaleSize.y, # type: ignore ) else: - dimensions = best_guess_target_dimensions(data_array, scale_extent, target_crs) + dimensions = best_guess_target_dimensions(src_ds, scale_extent, target_crs) return dimensions def get_rasterio_parameters( - crs: CRS, scale_extent: ScaleExtent, dimensions: Dimensions + crs: rasterio.CRS, scale_extent: ScaleExtent, dimensions: Dimensions ) -> GridParams: """Convert the grid into rasterio consumable format. @@ -323,12 +332,12 @@ def needs_tiling(grid_parameters: GridParams) -> bool: From discussion, this limit is set to 8192*8192 cells. """ - MAX_UNTILED_GRIDCELLS = 8192 * 8192 - return grid_parameters['height'] * grid_parameters['width'] > MAX_UNTILED_GRIDCELLS + max_untiled_gridcells = 8192 * 8192 + return grid_parameters['height'] * grid_parameters['width'] > max_untiled_gridcells def best_guess_target_dimensions( - data_array: DataArray, scale_extent: ScaleExtent, target_crs: CRS + src_ds: DatasetReader, scale_extent: ScaleExtent, target_crs: rasterio.CRS ) -> Dimensions: """Return best guess for output image dimensions. @@ -344,12 +353,12 @@ def best_guess_target_dimensions( else: resolution_list = epsg_3413_resolutions - x_res, y_res = resolution_in_target_crs_units(data_array, target_crs) + x_res, y_res = resolution_in_target_crs_units(src_ds, target_crs) return guess_dimensions(x_res, y_res, scale_extent, resolution_list) def resolution_in_target_crs_units( - data_array: DataArray, target_crs: CRS + src_ds: DatasetReader, target_crs: rasterio.CRS ) -> tuple[float, float]: """Return the x and y target resolutions @@ -361,17 +370,17 @@ def resolution_in_target_crs_units( the user has not supplied any input parameters and we are trying to determine the dimensions for the output image. """ - if data_array.rio.crs.is_projected == target_crs.is_projected: - x_res = data_array.rio.transform().a - y_res = abs(data_array.rio.transform().e) + if src_ds.crs.is_projected == target_crs.is_projected: + x_res = src_ds.transform.a + y_res = abs(src_ds.transform.e) elif target_crs.is_projected: # transform from latlon to meters - x_res = data_array.rio.transform().a * METERS_PER_DEGREE - y_res = abs(data_array.rio.transform().e) * METERS_PER_DEGREE + x_res = src_ds.transform.a * METERS_PER_DEGREE + y_res = abs(src_ds.transform.e) * METERS_PER_DEGREE else: # transform from meters to lat/lon - x_res = data_array.rio.transform().a / METERS_PER_DEGREE - y_res = abs(data_array.rio.transform().e) / METERS_PER_DEGREE + x_res = src_ds.transform.a / METERS_PER_DEGREE + y_res = abs(src_ds.transform.e) / METERS_PER_DEGREE return x_res, y_res @@ -418,14 +427,14 @@ def compute_target_dimensions( def find_closest_resolution( resolutions: list[float], resolution_info: list[ResolutionInfo] -) -> ResolutionInfo | None: +) -> ResolutionInfo: """Return closest match to GIBS preferred Resolution Info. Cycle through all input resolutions and return the resolution_info that has the smallest absolute difference to any of the input resolutions. """ - best_info = None + best_info = resolution_info[0] smallest_diff = np.inf for res in resolutions: for info in resolution_info: diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..d787271 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +disable_error_code = import-untyped diff --git a/pip_requirements.txt b/pip_requirements.txt index d68169b..b9093cc 100644 --- a/pip_requirements.txt +++ b/pip_requirements.txt @@ -5,4 +5,3 @@ pillow==10.4.0 pyproj==3.6.1 pystac~=1.0.1 rasterio==1.3.10 -rioxarray==0.17.0 diff --git a/pip_requirements_skip_snyk.txt b/pip_requirements_skip_snyk.txt index b4a7e1d..5d967c0 100644 --- a/pip_requirements_skip_snyk.txt +++ b/pip_requirements_skip_snyk.txt @@ -1,3 +1,3 @@ # Because snyk can't install gdal properly during automated scans, we have this # file that includes the GDAL package alone which is not scanned by snyk -GDAL==3.6.2 +GDAL==3.12.1 diff --git a/tests/test_service/test_adapter.py b/tests/test_service/test_adapter.py index 92a9b5e..43621de 100644 --- a/tests/test_service/test_adapter.py +++ b/tests/test_service/test_adapter.py @@ -7,18 +7,18 @@ from unittest.mock import call, patch import numpy as np +import rasterio from harmony_service_lib.exceptions import ForbiddenException from harmony_service_lib.message import Message from harmony_service_lib.util import config from pystac import Catalog from rasterio.transform import array_bounds, from_bounds from rasterio.warp import Resampling -from rioxarray import open_rasterio from harmony_service.adapter import BrowseImageGeneratorAdapter from harmony_service.exceptions import HyBIGServiceError from hybig.browse import ( - convert_mulitband_to_raster, + convert_multiband_to_raster, ) from tests.utilities import Granule, create_stac @@ -194,16 +194,16 @@ def move_tif(*args, **kwargs): # input CRS is preferred 4326, # Scale Extent from icd # dimensions from input data - rio_data_array = open_rasterio(self.red_tif_fixture) + rio_data_array = rasterio.open(self.red_tif_fixture) left, bottom, right, top = array_bounds( - rio_data_array.rio.width, - rio_data_array.rio.height, - rio_data_array.rio.transform(), + rio_data_array.shape[0], + rio_data_array.shape[1], + rio_data_array.transform, ) image_scale_extent = {'xmin': left, 'ymin': bottom, 'xmax': right, 'ymax': top} - expected_width = round((right - left) / rio_data_array.rio.transform().a) - expected_height = round((top - bottom) / -rio_data_array.rio.transform().e) + expected_width = round((right - left) / rio_data_array.transform.a) + expected_height = round((top - bottom) / -rio_data_array.transform.e) expected_transform = from_bounds( image_scale_extent['xmin'], image_scale_extent['ymin'], @@ -216,14 +216,14 @@ def move_tif(*args, **kwargs): expected_params = { 'width': expected_width, 'height': expected_height, - 'crs': rio_data_array.rio.crs, + 'crs': rio_data_array.crs, 'transform': expected_transform, 'driver': 'JPEG', 'dtype': 'uint8', 'dst_nodata': 255, 'count': 3, } - raster = convert_mulitband_to_raster(rio_data_array) + raster = convert_multiband_to_raster(rio_data_array.read()) dest = np.full( (expected_params['height'], expected_params['width']), @@ -235,8 +235,8 @@ def move_tif(*args, **kwargs): call( source=raster[0, :, :], destination=dest, - src_transform=rio_data_array.rio.transform(), - src_crs=rio_data_array.rio.crs, + src_transform=rio_data_array.transform, + src_crs=rio_data_array.crs, dst_transform=expected_params['transform'], dst_crs=expected_params['crs'], dst_nodata=expected_params['dst_nodata'], @@ -245,8 +245,8 @@ def move_tif(*args, **kwargs): call( source=raster[1, :, :], destination=dest, - src_transform=rio_data_array.rio.transform(), - src_crs=rio_data_array.rio.crs, + src_transform=rio_data_array.transform, + src_crs=rio_data_array.crs, dst_transform=expected_params['transform'], dst_crs=expected_params['crs'], dst_nodata=expected_params['dst_nodata'], @@ -255,8 +255,8 @@ def move_tif(*args, **kwargs): call( source=raster[2, :, :], destination=dest, - src_transform=rio_data_array.rio.transform(), - src_crs=rio_data_array.rio.crs, + src_transform=rio_data_array.transform, + src_crs=rio_data_array.crs, dst_transform=expected_params['transform'], dst_crs=expected_params['crs'], dst_nodata=expected_params['dst_nodata'], @@ -429,16 +429,16 @@ def move_tif(*args, **kwargs): # input CRS is preferred 4326, # Scale Extent from icd # dimensions from input data - rio_data_array = open_rasterio(self.red_tif_fixture) + rio_data_array = rasterio.open(self.red_tif_fixture) left, bottom, right, top = array_bounds( - rio_data_array.rio.width, - rio_data_array.rio.height, - rio_data_array.rio.transform(), + rio_data_array.shape[0], + rio_data_array.shape[1], + rio_data_array.transform, ) image_scale_extent = {'xmin': left, 'ymin': bottom, 'xmax': right, 'ymax': top} - expected_width = round((right - left) / rio_data_array.rio.transform().a) - expected_height = round((top - bottom) / -rio_data_array.rio.transform().e) + expected_width = round((right - left) / rio_data_array.transform.a) + expected_height = round((top - bottom) / -rio_data_array.transform.e) expected_transform = from_bounds( image_scale_extent['xmin'], image_scale_extent['ymin'], @@ -451,14 +451,14 @@ def move_tif(*args, **kwargs): expected_params = { 'width': expected_width, 'height': expected_height, - 'crs': rio_data_array.rio.crs, + 'crs': rio_data_array.crs, 'transform': expected_transform, 'driver': 'PNG', 'dtype': 'uint8', 'dst_nodata': 0, 'count': 3, } - raster = convert_mulitband_to_raster(rio_data_array) + raster = convert_multiband_to_raster(rio_data_array.read()) dest = np.full( (expected_params['height'], expected_params['width']), @@ -470,8 +470,8 @@ def move_tif(*args, **kwargs): call( source=raster[band, :, :], destination=dest, - src_transform=rio_data_array.rio.transform(), - src_crs=rio_data_array.rio.crs, + src_transform=rio_data_array.transform, + src_crs=rio_data_array.crs, dst_transform=expected_params['transform'], dst_crs=expected_params['crs'], dst_nodata=expected_params['dst_nodata'], diff --git a/tests/unit/test_browse.py b/tests/unit/test_browse.py index d22470a..04d6fd8 100644 --- a/tests/unit/test_browse.py +++ b/tests/unit/test_browse.py @@ -12,20 +12,16 @@ from harmony_service_lib.message import Source as HarmonySource from numpy.testing import assert_array_equal, assert_equal from osgeo_utils.auxiliary.color_palette import ColorPalette -from PIL import Image from rasterio import Affine from rasterio.crs import CRS from rasterio.io import DatasetReader, DatasetWriter -from rasterio.transform import array_bounds from rasterio.warp import Resampling -from xarray import DataArray from hybig.browse import ( - convert_mulitband_to_raster, + convert_multiband_to_raster, convert_singleband_to_raster, create_browse, create_browse_imagery, - get_color_map_from_image, get_tiled_filename, output_image_file, output_world_file, @@ -149,27 +145,24 @@ def test_create_browse_imagery_with_single_band_raster(self): @patch('hybig.browse.reproject') @patch('rasterio.open') - @patch('hybig.browse.open_rasterio') - def test_create_browse_imagery_with_mocks( - self, rioxarray_open_mock, rasterio_open_mock, reproject_mock - ): + def test_create_browse_imagery_with_mocks(self, rasterio_open_mock, reproject_mock): file_transform = Affine(90.0, 0.0, -180.0, 0.0, -45.0, 90.0) - da_mock = MagicMock(DataArray) - in_dataset_mock = Mock(DatasetReader) - da_mock.rio._manager.acquire.return_value = in_dataset_mock - - dest_write_mock = Mock(DatasetWriter) + ds = Mock(spec=DatasetReader) - da_mock.__getitem__.return_value = self.data - in_dataset_mock.driver = 'GTiff' - da_mock.rio.height = 4 - da_mock.rio.width = 4 - da_mock.rio.transform.return_value = file_transform - da_mock.rio.crs = CRS.from_string('EPSG:4326') - da_mock.rio.count = 1 - in_dataset_mock.colormap = Mock(side_effect=ValueError) + dest_write_mock = Mock(spec=DatasetWriter) - da_mock.rio.transform_bounds.return_value = array_bounds(4, 4, file_transform) + ds.read.return_value = self.data[np.newaxis, :, :] + ds.driver = 'GTiff' + ds.shape = (4, 4) + ds.transform = file_transform + ds.crs = CRS.from_string('EPSG:4326') + ds.count = 1 + ds.colormap = Mock(side_effect=ValueError) + ds.bounds = (-180.0, -90.0, 180.0, 90.0) + ds.window_transform = Mock(return_value=file_transform) + ds.nodatavals = (255,) + ds.scales = (1,) + ds.offsets = (0,) expected_raster = np.array( [ @@ -183,23 +176,32 @@ def test_create_browse_imagery_with_mocks( dtype='uint8', ) - rioxarray_open_mock.return_value.__enter__.side_effect = [ - da_mock, - ] rasterio_open_mock.return_value.__enter__.side_effect = [ + ds, dest_write_mock, ] message = HarmonyMessage({'format': {'mime': 'JPEG'}}) - # Act to run the test - out_file_list = create_browse_imagery( - message, - self.tmp_dir / 'input_file_path', - HarmonySource({}), - None, - self.logger, - ) + # More detailed traceback for this test since it's end-to-end + try: + out_file_list = create_browse_imagery( + message, + str(self.tmp_dir / 'input_file_path'), + HarmonySource({}), + None, + self.logger, + ) + except HyBIGError as e: + import traceback + + print('\n=== Full Traceback ===') + traceback.print_exc() + print('\n=== Exception Chain ===') + print(f'HyBIGError: {e}') + if e.__cause__: + print(f'Caused by: {type(e.__cause__).__name__}: {e.__cause__}') + raise # Ensure tiling logic was not called: self.assertEqual(len(out_file_list), 1) @@ -207,7 +209,7 @@ def test_create_browse_imagery_with_mocks( actual_image, actual_world, actual_aux = out_file_list[0] target_transform = Affine(90.0, 0.0, -180.0, 0.0, -45.0, 90.0) - dest = np.zeros((da_mock.rio.height, da_mock.rio.width), dtype='uint8') + dest = np.zeros((ds.shape[0], ds.shape[1]), dtype='uint8') # For JPEG output with 1-band input, we convert to RGB, so we reproject 3 bands self.assertEqual(reproject_mock.call_count, 3) @@ -218,7 +220,7 @@ def test_create_browse_imagery_with_mocks( source=expected_raster[0, :, :], destination=dest, src_transform=file_transform, - src_crs=da_mock.rio.crs, + src_crs=ds.crs, dst_transform=target_transform, dst_crs=CRS.from_string('EPSG:4326'), dst_nodata=0, # TRANSPARENT for RGB data @@ -228,7 +230,7 @@ def test_create_browse_imagery_with_mocks( source=expected_raster[0, :, :], destination=dest, src_transform=file_transform, - src_crs=da_mock.rio.crs, + src_crs=ds.crs, dst_transform=target_transform, dst_crs=CRS.from_string('EPSG:4326'), dst_nodata=0, # TRANSPARENT for RGB data @@ -238,7 +240,7 @@ def test_create_browse_imagery_with_mocks( source=expected_raster[0, :, :], destination=dest, src_transform=file_transform, - src_crs=da_mock.rio.crs, + src_crs=ds.crs, dst_transform=target_transform, dst_crs=CRS.from_string('EPSG:4326'), dst_nodata=0, # TRANSPARENT for RGB data @@ -295,7 +297,8 @@ def test_convert_singleband_to_raster_without_colortable(self): """Tests scale_grey_1band.""" return_data = np.copy(self.data).astype('float64') return_data[0][1] = np.nan - ds = DataArray(return_data).expand_dims('band') + return_data = return_data[np.newaxis, :, :] + # ds = DataArray(return_data).expand_dims('band') expected_raster = np.array( [ @@ -308,12 +311,10 @@ def test_convert_singleband_to_raster_without_colortable(self): ], dtype='uint8', ) - actual_raster, _ = convert_singleband_to_raster(ds, None) + actual_raster, _, _ = convert_singleband_to_raster(return_data, None) assert_array_equal(expected_raster, actual_raster, strict=True) def test_convert_singleband_to_raster_with_colormap(self): - ds = DataArray(self.data).expand_dims('band') - expected_raster = np.array( [ [ # singleband paletted @@ -334,14 +335,17 @@ def test_convert_singleband_to_raster_with_colormap(self): } # Read down: red, yellow, green, blue image_palette = convert_colormap_to_palette(self.colormap) - actual_raster, actual_palette = convert_singleband_to_raster(ds, image_palette) + # functional equivalent of DataArray().expand_dims("bands") + actual_raster, actual_palette, _ = convert_singleband_to_raster( + self.data[np.newaxis, :, :], image_palette + ) assert_array_equal(expected_raster, actual_raster, strict=True) assert_equal(expected_palette, actual_palette) def test_convert_singleband_to_raster_with_colormap_and_bad_data(self): data_array = np.array(self.data, dtype='float') data_array[0, 0] = np.nan - ds = DataArray(data_array).expand_dims('band') + data_array = data_array[np.newaxis, :, :] nv_color = (10, 20, 30, 40) # Read the image down: red, yellow, green, blue @@ -367,7 +371,9 @@ def test_convert_singleband_to_raster_with_colormap_and_bad_data(self): colormap = {**self.colormap, 'nv': nv_color} image_palette = convert_colormap_to_palette(colormap) - actual_raster, actual_palette = convert_singleband_to_raster(ds, image_palette) + actual_raster, actual_palette, _ = convert_singleband_to_raster( + data_array, image_palette + ) assert_array_equal(expected_raster, actual_raster, strict=True) assert_equal(expected_palette, actual_palette) @@ -376,11 +382,7 @@ def test_convert_uint16_3_multiband_to_raster(self): bad_data = np.copy(self.data).astype('float64') bad_data[1][1] = np.nan bad_data[1][2] = np.nan - ds = DataArray( - np.stack([self.data, bad_data, self.data]), - dims=('band', 'y', 'x'), - ) - ds.encoding = {'dtype': 'uint16'} + data_array = np.stack([self.data, bad_data, self.data]).astype('float64') expected_raster = np.array( [ @@ -412,8 +414,8 @@ def test_convert_uint16_3_multiband_to_raster(self): dtype='uint8', ) - actual_raster = convert_mulitband_to_raster(ds) - assert_array_equal(expected_raster, actual_raster.data, strict=True) + actual_raster = convert_multiband_to_raster(data_array) + assert_array_equal(expected_raster, actual_raster, strict=True) def test_convert_uint8_3_multiband_to_raster(self): """Ensure valid data is unchanged when input is uint8.""" @@ -426,11 +428,7 @@ def test_convert_uint8_3_multiband_to_raster(self): ] ).astype('float32') - ds = DataArray( - np.stack([scale_data, scale_data, scale_data]), - dims=('band', 'y', 'x'), - ) - ds.encoding = {'dtype': 'uint8'} + data_array = np.stack([scale_data, scale_data, scale_data]).astype('float64') expected_data = scale_data.copy() expected_data[1][1] = 0 @@ -451,14 +449,11 @@ def test_convert_uint8_3_multiband_to_raster(self): dtype='uint8', ) - actual_raster = convert_mulitband_to_raster(ds) - assert_array_equal(expected_raster, actual_raster.data, strict=True) + actual_raster = convert_multiband_to_raster(data_array) + assert_array_equal(expected_raster, actual_raster, strict=True) def test_convert_4_multiband_uint8_to_raster(self): """4-band 'uint8' images are returned unchanged.""" - ds = Mock(DataArray) - ds.rio.count = 4 - r_data = np.array( [ [10, 200, 30, 40], @@ -474,20 +469,15 @@ def test_convert_4_multiband_uint8_to_raster(self): a_data = np.ones_like(r_data) * 255 a_data[0, 0] = 0 - to_numpy_result = np.stack([r_data, g_data, b_data, a_data]) - - ds.to_numpy.return_value = to_numpy_result + data_array = np.stack([r_data, g_data, b_data, a_data]) - expected_raster = to_numpy_result + expected_raster = data_array - actual_raster = convert_mulitband_to_raster(ds) - assert_array_equal(expected_raster, actual_raster.data, strict=True) + actual_raster = convert_multiband_to_raster(data_array) + assert_array_equal(expected_raster, actual_raster, strict=True) def test_convert_4_multiband_uint16_to_raster(self): """4-band 'uint16' images are scaled if their range exceeds 255.""" - ds = Mock(DataArray) - ds.rio.count = 4 - r_data = np.array( [ [10, 200, 300, 400], @@ -502,23 +492,19 @@ def test_convert_4_multiband_uint16_to_raster(self): a_data = np.ones_like(self.data) * OPAQUE a_data[0, 0] = TRANSPARENT - to_numpy_result = np.stack([r_data, g_data, b_data, a_data]) - - ds.to_numpy.return_value = to_numpy_result + data_array = np.stack([r_data, g_data, b_data, a_data]) # expect the input data to have the data values from 0 to 400 to be # scaled into the range 0 to 255. expected_raster = np.around( - np.interp(to_numpy_result, (0, 400), (0.0, 1.0)) * 255.0 + np.interp(data_array, (0, 400), (0.0, 1.0)) * 255.0 ).astype('uint8') - actual_raster = convert_mulitband_to_raster(ds) - assert_array_equal(expected_raster, actual_raster.data, strict=True) + actual_raster = convert_multiband_to_raster(data_array) + assert_array_equal(expected_raster, actual_raster, strict=True) def test_convert_4_multiband_masked_to_raster(self): """4-band images are returned with nan -> 0""" - ds = Mock(DataArray) - ds.rio.count = 4 nan = np.nan input_array = np.array( [ @@ -549,7 +535,6 @@ def test_convert_4_multiband_masked_to_raster(self): ], dtype=np.float32, ) - ds.to_numpy.return_value = input_array expected_raster = np.array( [ @@ -581,59 +566,20 @@ def test_convert_4_multiband_masked_to_raster(self): dtype=np.uint8, ) - actual_raster = convert_mulitband_to_raster(ds) - assert_array_equal(expected_raster.data, actual_raster.data, strict=True) + actual_raster = convert_multiband_to_raster(input_array) + assert_array_equal(expected_raster, actual_raster, strict=True) def test_convert_5_multiband_to_raster(self): - ds = Mock(DataArray) - ds.rio.count = 5 - ds.to_numpy.return_value = np.stack( - [self.data, self.data, self.data, self.data, self.data] - ) + data_array = np.stack([self.data, self.data, self.data, self.data, self.data]) with self.assertRaises(HyBIGError) as excepted: - convert_mulitband_to_raster(ds) + convert_multiband_to_raster(data_array) self.assertEqual( excepted.exception.message, 'Cannot create image from 5 band image. Expecting 3 or 4 bands.', ) - def test_get_color_map_from_image(self): - """PIL Image yields a color_map - - A palette from an PIL Image is correctly turned into a colormap - writable by rasterio. - - """ - # random image with values of 0 to 4. - image_data = self.random.integers(5, size=(5, 6), dtype='uint8') - # fmt: off - palette_sequence = [ - 255, 0, 0, 255, - 0, 255, 0, 255, - 0, 0, 255, 255, - 225, 100, 25, 25, - 0, 0, 0, 0 - ] - # fmt: on - test_image = Image.fromarray(image_data) - test_image.putpalette(palette_sequence, rawmode='RGBA') - - expected_color_map = { - **{ - 0: (255, 0, 0, 255), - 1: (0, 255, 0, 255), - 2: (0, 0, 255, 255), - 3: (225, 100, 25, 25), - 4: (0, 0, 0, 0), - }, - **{idx: (0, 0, 0, 255) for idx in range(5, 256)}, - } - - actual_color_map = get_color_map_from_image(test_image) - self.assertDictEqual(expected_color_map, actual_color_map) - def test_get_color_palette_map_exists_source_does_not(self): ds = Mock(DatasetReader) ds.colormap.return_value = self.colormap @@ -706,19 +652,19 @@ def test_get_tiled_filename(self): def test_validate_file_crs_valid(self): """Valid file should return None.""" - da = Mock(DataArray) - da.rio.crs = CRS.from_epsg(4326) + ds = Mock(DatasetReader) + ds.crs = CRS.from_epsg(4326) try: - validate_file_crs(da) + validate_file_crs(ds) except Exception: self.fail('Valid file threw unexpected exception.') def test_validate_file_crs_missing(self): """Invalid file should raise exception.""" - da = Mock(DataArray) - da.rio.crs = None + ds = Mock(DatasetReader) + ds.crs = None with self.assertRaisesRegex(HyBIGError, 'Input geotiff must have defined CRS.'): - validate_file_crs(da) + validate_file_crs(ds) def test_validate_file_type_valid(self): """Validation should not raise exception.""" @@ -790,7 +736,8 @@ def test_scale_paletted_1band_clips_underflow_values(self): [-100, 0, 150, 250], ] ).astype('float64') - ds = DataArray(data_with_underflow).expand_dims('band') + # functional equivalent of DataArray().expand_dims("bands") + data_with_underflow = data_with_underflow[np.newaxis, :, :] # Expected: underflow values (-50, 50, 0, -100) should map to index 0 # which is the lowest color (red at value 100) @@ -807,7 +754,7 @@ def test_scale_paletted_1band_clips_underflow_values(self): ) image_palette = convert_colormap_to_palette(self.colormap) - actual_raster, _ = scale_paletted_1band(ds, image_palette) + actual_raster, _, _ = scale_paletted_1band(data_with_underflow, image_palette) assert_array_equal(expected_raster, actual_raster, strict=True) def test_scale_paletted_1band_clips_overflow_values(self): @@ -824,7 +771,8 @@ def test_scale_paletted_1band_clips_overflow_values(self): [300, 350, 400, 800], ] ).astype('float64') - ds = DataArray(data_with_overflow).expand_dims('band') + # functional equivalent of DataArray().expand_dims("bands") + data_with_overflow = data_with_overflow[np.newaxis, :, :] # Expected: overflow values (500, 600, 1000, 800) should map to index 3 # which is the highest color (blue at value 400) @@ -841,7 +789,7 @@ def test_scale_paletted_1band_clips_overflow_values(self): ) image_palette = convert_colormap_to_palette(self.colormap) - actual_raster, _ = scale_paletted_1band(ds, image_palette) + actual_raster, _, _ = scale_paletted_1band(data_with_overflow, image_palette) assert_array_equal(expected_raster, actual_raster, strict=True) def test_scale_paletted_1band_with_nan_and_clipping(self): @@ -857,7 +805,8 @@ def test_scale_paletted_1band_with_nan_and_clipping(self): [-100, 250, np.nan, 800], ] ).astype('float64') - ds = DataArray(data_mixed).expand_dims('band') + # functional equivalent of DataArray().expand_dims("bands") + data_mixed = data_mixed[np.newaxis, :, :] # Expected: NaN -> 4 (nodata), underflow -> 0, overflow -> 3 expected_raster = np.array( @@ -873,7 +822,9 @@ def test_scale_paletted_1band_with_nan_and_clipping(self): ) image_palette = convert_colormap_to_palette(self.colormap) - actual_raster, actual_palette = scale_paletted_1band(ds, image_palette) + actual_raster, actual_palette, _ = scale_paletted_1band( + data_mixed, image_palette + ) assert_array_equal(expected_raster, actual_raster, strict=True) # Verify nodata color is transparent @@ -891,10 +842,11 @@ def test_scale_paletted_1band_to_rgb_clips_underflow_values(self): [100, 200, 300, 400], ] ).astype('float64') - ds = DataArray(data_with_underflow).expand_dims('band') + # functional equivalent of DataArray().expand_dims("bands") + data_with_underflow = data_with_underflow[np.newaxis, :, :] image_palette = convert_colormap_to_palette(self.colormap) - actual_rgb, _ = scale_paletted_1band_to_rgb(ds, image_palette) + actual_rgb, _ = scale_paletted_1band_to_rgb(data_with_underflow, image_palette) # Values -50 and 50 should get red color (255, 0, 0) # which is the lowest color in the palette @@ -917,10 +869,11 @@ def test_scale_paletted_1band_to_rgb_clips_overflow_values(self): [300, 400, 800, 1500], ] ).astype('float64') - ds = DataArray(data_with_overflow).expand_dims('band') + # functional equivalent of DataArray().expand_dims("bands") + data_with_overflow = data_with_overflow[np.newaxis, :, :] image_palette = convert_colormap_to_palette(self.colormap) - actual_rgb, _ = scale_paletted_1band_to_rgb(ds, image_palette) + actual_rgb, _ = scale_paletted_1band_to_rgb(data_with_overflow, image_palette) # Values 500, 600, 1000, 800, 1500 should get blue color (0, 0, 255) # which is the highest color in the palette @@ -945,10 +898,11 @@ def test_scale_paletted_1band_to_rgb_with_nan_and_clipping(self): [50, 200, np.nan, 1000], ] ).astype('float64') - ds = DataArray(data_mixed).expand_dims('band') + # functional equivalent of DataArray().expand_dims("bands") + data_mixed = data_mixed[np.newaxis, :, :] image_palette = convert_colormap_to_palette(self.colormap) - actual_rgb, _ = scale_paletted_1band_to_rgb(ds, image_palette) + actual_rgb, _ = scale_paletted_1band_to_rgb(data_mixed, image_palette) # NaN should map to nodata color (0, 0, 0) self.assertEqual(actual_rgb[0, 0, 0], 0) # Red for NaN at (0,0) diff --git a/tests/unit/test_color_utility.py b/tests/unit/test_color_utility.py index 732c780..5b404dc 100644 --- a/tests/unit/test_color_utility.py +++ b/tests/unit/test_color_utility.py @@ -213,7 +213,7 @@ def test_get_color_palette_with_item_palette( item_palette = convert_colormap_to_palette(self.colormap) expected_palette = item_palette - actual_palette = get_color_palette({}, {}, item_palette) + actual_palette = get_color_palette(ds, HarmonySource({}), item_palette) self.assertEqual(expected_palette, actual_palette) get_remote_palette_from_source_mock.assert_not_called() ds.colormap.assert_not_called() @@ -327,15 +327,21 @@ def test_get_color_palette_with_ndv(self, get_remote_palette_from_source_mock): # that matches one of our colormap keys ds.get_nodatavals.return_value = (100,) - actual_palette = get_color_palette(ds, None, None) + actual_palette = get_color_palette(ds, HarmonySource({}), None) get_remote_palette_from_source_mock.assert_called_once() ds.colormap.assert_called_once_with(1) ds.get_nodatavals.assert_called_once() + self.assertIsNotNone(actual_palette) # Compare the actual and expected palettes - self.assertEqual(actual_palette.get_color(200), encode_color(0, 255, 0, 255)) - self.assertEqual(actual_palette.get_color('nv'), encode_color(255, 0, 0, 255)) + if actual_palette is not None: + self.assertEqual( + actual_palette.get_color(200), encode_color(0, 255, 0, 255) + ) + self.assertEqual( + actual_palette.get_color('nv'), encode_color(255, 0, 0, 255) + ) def test_convert_colormap_to_palette_3bands(self): input_colormap = { diff --git a/tests/unit/test_crs.py b/tests/unit/test_crs.py index eb23d7a..15a0f70 100644 --- a/tests/unit/test_crs.py +++ b/tests/unit/test_crs.py @@ -5,12 +5,13 @@ from tempfile import TemporaryDirectory from textwrap import dedent from unittest import TestCase -from unittest.mock import patch +from unittest.mock import Mock, patch +import rasterio from affine import Affine from harmony_service_lib.message import SRS from rasterio.crs import CRS -from rioxarray import open_rasterio +from rasterio.io import DatasetReader from hybig.crs import ( PREFERRED_CRS, @@ -84,11 +85,9 @@ class TestCrs(TestCase): """A class that tests the crs module.""" - @classmethod def setUp(self): self.temp_dir = Path(TemporaryDirectory().name) - @classmethod def tearDown(self): if self.temp_dir.exists(): rmtree(self.temp_dir) @@ -97,14 +96,16 @@ def test_choose_target_crs_with_epsg_from_harmony_message(self): """Test SRS has an epsg code.""" expected_CRS = CRS.from_epsg(6932) test_srs = SRS({'epsg': 'EPSG:6932'}) - actual_CRS = choose_target_crs(test_srs, None) + ds = Mock(DatasetReader) + actual_CRS = choose_target_crs(test_srs, ds) self.assertEqual(expected_CRS, actual_CRS) def test_choose_target_crs_with_wkt_from_harmony_message(self): """Test SRS has wkt string.""" expected_CRS = CRS.from_epsg(6050) test_srs = SRS({'wkt': WKT_EPSG_6050}) - actual_CRS = choose_target_crs(test_srs, None) + ds = Mock(DatasetReader) + actual_CRS = choose_target_crs(test_srs, ds) self.assertEqual(expected_CRS, actual_CRS) def test_choose_target_crs_with_proj4_from_harmony_message_and_empty_epsg(self): @@ -119,14 +120,16 @@ def test_choose_target_crs_with_proj4_from_harmony_message_and_empty_epsg(self): 'epsg': '', } ) - actual_CRS = choose_target_crs(test_srs, None) + ds = Mock(DatasetReader) + actual_CRS = choose_target_crs(test_srs, ds) self.assertEqual(expected_CRS, actual_CRS) def test_choose_target_crs_with_invalid_SRS_from_harmony_message(self): """Test SRS does not have epsg, wkt or proj4 string.""" - test_srs_is_json = {'how': 'did this happen?'} + test_srs_is_json = SRS({'how': 'did this happen?'}) + ds = Mock(DatasetReader) with self.assertRaisesRegex(HyBIGValueError, 'Bad input SRS'): - choose_target_crs(test_srs_is_json, None) + choose_target_crs(test_srs_is_json, ds) @patch('hybig.crs.choose_crs_from_metadata') def test_choose_target_harmony_message_has_crs_but_no_srs(self, mock_choose_fxn): @@ -137,10 +140,10 @@ def test_choose_target_harmony_message_has_crs_but_no_srs(self, mock_choose_fxn) """ test_srs = None - in_dataset = {'test': 'object'} + ds = Mock(DatasetReader) - choose_target_crs(test_srs, in_dataset) - mock_choose_fxn.assert_called_once_with(in_dataset) + choose_target_crs(test_srs, ds) + mock_choose_fxn.assert_called_once_with(ds) def test_choose_target_crs_with_preferred_metadata_north(self): """Check that preferred metadata for northern projection is found.""" @@ -152,7 +155,7 @@ def test_choose_target_crs_with_preferred_metadata_north(self): transform=Affine(25000.0, 0.0, -3850000.0, 0.0, -25000.0, 5850000.0), dtype='uint16', ) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: + with rasterio.open(tmp_file) as rio_data_array: actual_CRS = choose_target_crs(None, rio_data_array) self.assertEqual(expected_CRS, actual_CRS) @@ -166,7 +169,7 @@ def test_choose_target_crs_with_preferred_metadata_south(self): transform=Affine.scale(500, 300), dtype='uint16', ) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: + with rasterio.open(tmp_file) as rio_data_array: actual_CRS = choose_target_crs(None, rio_data_array) self.assertEqual(expected_CRS, actual_CRS) @@ -178,7 +181,7 @@ def test_choose_target_crs_with_preferred_metadata_global(self): count=3, crs=CRS.from_proj4('+proj=longlat +datum=WGS84 +no_defs +type=crs'), ) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: + with rasterio.open(tmp_file) as rio_data_array: actual_CRS = choose_target_crs(None, rio_data_array) self.assertEqual(expected_CRS, actual_CRS) @@ -190,7 +193,7 @@ def test_choose_target_crs_with_non_preferred_metadata_north(self): input_CRS = CRS.from_wkt(WKT_EPSG_3411) with rasterio_test_file(crs=input_CRS) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: + with rasterio.open(tmp_file) as rio_data_array: actual_CRS = choose_target_crs(None, rio_data_array) self.assertEqual(expected_CRS, actual_CRS) @@ -201,8 +204,8 @@ def test_choose_target_crs_from_metadata_south(self): input_CRS = CRS.from_string(ease_grid_2_south) with rasterio_test_file(crs=input_CRS) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: - actual_CRS = choose_best_crs_from_metadata(rio_data_array.rio.crs) + with rasterio.open(tmp_file) as rio_data_array: + actual_CRS = choose_best_crs_from_metadata(rio_data_array.crs) self.assertEqual(expected_CRS, actual_CRS) def test_choose_target_crs_from_metadata_global(self): @@ -212,8 +215,8 @@ def test_choose_target_crs_from_metadata_global(self): input_CRS = CRS.from_string(ease_grid_2_global) with rasterio_test_file(crs=input_CRS) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: - actual_CRS = choose_best_crs_from_metadata(rio_data_array.rio.crs) + with rasterio.open(tmp_file) as rio_data_array: + actual_CRS = choose_best_crs_from_metadata(rio_data_array.crs) self.assertEqual(expected_CRS, actual_CRS) def test_multiple_crs_from_metadata(self): diff --git a/tests/unit/test_sizes.py b/tests/unit/test_sizes.py index 8215802..8707fab 100644 --- a/tests/unit/test_sizes.py +++ b/tests/unit/test_sizes.py @@ -2,18 +2,21 @@ from pathlib import Path from unittest import TestCase -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest import rasterio from harmony_service_lib.message import Message from rasterio import Affine from rasterio.crs import CRS -from rioxarray import open_rasterio +from rasterio.io import DatasetReader +from rasterio.transform import array_bounds, from_bounds from hybig.crs import PREFERRED_CRS from hybig.sizes import ( METERS_PER_DEGREE, + Dimensions, + GridParams, ScaleExtent, best_guess_target_dimensions, choose_scale_extent, @@ -116,7 +119,8 @@ def test_grid_parameters_from_harmony_message_has_complete_information(self): 'transform': expected_transform, } - actual_parameters = get_target_grid_parameters(message, None) + ds = Mock(DatasetReader) + actual_parameters = get_target_grid_parameters(message, ds) self.assertDictEqual(expected_parameters, actual_parameters) def test_grid_parameters_from_harmony_no_message_information(self): @@ -128,9 +132,7 @@ def test_grid_parameters_from_harmony_no_message_information(self): sp_seaice_grid['left'], sp_seaice_grid['top'] ) * Affine.scale(sp_seaice_grid['xres'], -1 * sp_seaice_grid['yres']) - left, bottom, right, top = rasterio.transform.array_bounds( - height, width, img_transform - ) + left, bottom, right, top = array_bounds(height, width, img_transform) image_scale_extent = { 'xmin': left, 'ymin': bottom, @@ -172,7 +174,7 @@ def test_grid_parameters_from_harmony_no_message_information(self): transform=img_transform, ) as tmp_file: message = Message({'format': {}}) - with open_rasterio(tmp_file) as rio_data_array: + with rasterio.open(tmp_file) as rio_data_array: actual_parameters = get_target_grid_parameters(message, rio_data_array) self.assertDictEqual(expected_parameters, actual_parameters) @@ -191,8 +193,10 @@ def test_parameters(self): yres = nsidc_np_seaice_grid['yres'] crs = CRS.from_epsg(nsidc_np_seaice_grid['epsg']) - dimensions = {'width': width, 'height': height} - scale_extent = {'xmin': west, 'ymax': north, 'xmax': east, 'ymin': south} + dimensions = Dimensions({'width': width, 'height': height}) + scale_extent = ScaleExtent( + {'xmin': west, 'ymax': north, 'xmax': east, 'ymin': south} + ) expected_parameters = { 'width': width, @@ -219,39 +223,47 @@ def test_needs_tiling(self): """ with self.subTest('Projected, needs tiling'): - grid_parameters = { - 'height': 8192, - 'width': 8193, - 'crs': CRS.from_epsg(nsidc_np_seaice_grid['epsg']), - 'transform': Affine(400, 0.0, -3850000.0, 0.0, 400, 5850000.0), - } + grid_parameters = GridParams( + { + 'height': 8192, + 'width': 8193, + 'crs': CRS.from_epsg(nsidc_np_seaice_grid['epsg']), + 'transform': Affine(400, 0.0, -3850000.0, 0.0, 400, 5850000.0), + } + ) self.assertTrue(needs_tiling(grid_parameters)) with self.subTest('Projected, does not need tiling'): - grid_parameters = { - 'height': 8192, - 'width': 8192, - 'crs': CRS.from_epsg(nsidc_np_seaice_grid['epsg']), - 'transform': Affine(600, 0.0, -3850000.0, 0.0, 600, 5850000.0), - } + grid_parameters = GridParams( + { + 'height': 8192, + 'width': 8192, + 'crs': CRS.from_epsg(nsidc_np_seaice_grid['epsg']), + 'transform': Affine(600, 0.0, -3850000.0, 0.0, 600, 5850000.0), + } + ) self.assertFalse(needs_tiling(grid_parameters)) with self.subTest('Geographic, needs tiling'): - grid_parameters = { - 'height': 180000, - 'width': 360000, - 'crs': CRS.from_epsg(4326), - 'transform': Affine(0.001, 0.0, -180, 0.0, -0.001, 180), - } + grid_parameters = GridParams( + { + 'height': 180000, + 'width': 360000, + 'crs': CRS.from_epsg(4326), + 'transform': Affine(0.001, 0.0, -180, 0.0, -0.001, 180), + } + ) self.assertTrue(needs_tiling(grid_parameters)) with self.subTest('Geographic, does not need tiling'): - grid_parameters = { - 'height': 1800, - 'width': 3600, - 'crs': CRS.from_epsg(4326), - 'transform': Affine(0.1, 0.0, -180, 0.0, -0.1, 180), - } + grid_parameters = GridParams( + { + 'height': 1800, + 'width': 3600, + 'crs': CRS.from_epsg(4326), + 'transform': Affine(0.1, 0.0, -180, 0.0, -0.1, 180), + } + ) self.assertFalse(needs_tiling(grid_parameters)) def test_get_cells_per_tile(self): @@ -283,8 +295,8 @@ def test_compute_tile_boundaries_with_leftovers(self): def test_compute_tile_dimensions_uniform(self): """Test tile dimensions.""" - tile_origins = [0.0, 10.0, 20.0, 30.0, 40.0, 43.0] - expected_dimensions = [10.0, 10.0, 10.0, 10.0, 3.0, 0.0] + tile_origins = [0, 10, 20, 30, 40, 43] + expected_dimensions = [10, 10, 10, 10, 3, 0] actual_dimensions = compute_tile_dimensions(tile_origins) @@ -292,8 +304,8 @@ def test_compute_tile_dimensions_uniform(self): def test_compute_tile_dimensions_nonuniform(self): """Test tile dimensions.""" - tile_origins = [0.0, 20.0, 35.0, 40.0, 43.0] - expected_dimensions = [20.0, 15.0, 5.0, 3.0, 0.0] + tile_origins = [0, 20, 35, 40, 43] + expected_dimensions = [20, 15, 5, 3, 0] actual_dimensions = compute_tile_dimensions(tile_origins) @@ -334,12 +346,14 @@ def test_create_tile_output_parameters( needs_tiling_mock.return_value = True cells_per_tile_mock.return_value = 2800 - grid_parameters = { - 'width': 7200, - 'height': 3600, - 'crs': CRS.from_string(PREFERRED_CRS['global']), - 'transform': Affine(0.05, 0.0, -180.0, 0.0, -0.05, 90.0), - } + grid_parameters = GridParams( + { + 'width': 7200, + 'height': 3600, + 'crs': CRS.from_string(PREFERRED_CRS['global']), + 'transform': Affine(0.05, 0.0, -180.0, 0.0, -0.05, 90.0), + } + ) expected_grid_list = [ { @@ -420,29 +434,26 @@ def test_scale_extent_in_harmony_message(self): 'ymax': 500.0, } crs = None - actual_scale_extent = choose_scale_extent(message, crs, None) + ds = Mock(DatasetReader) + actual_scale_extent = choose_scale_extent(message, crs, ds) self.assertDictEqual(expected_scale_extent, actual_scale_extent) def test_scale_extent_from_input_image_and_no_crs_transformation(self): """Ensure no change of output extent when src_crs == target_crs""" - with open_rasterio( - self.fixtures / 'RGB.byte.small.tif', mode='r', mask_and_scale=True - ) as in_array: - source_crs = in_array.rio.crs - left, bottom, right, top = in_array.rio.bounds() + with rasterio.open(self.fixtures / 'RGB.byte.small.tif') as in_array: + source_crs = in_array.crs + left, bottom, right, top = in_array.bounds expected_scale_extent = ScaleExtent( {'xmin': left, 'ymin': bottom, 'xmax': right, 'ymax': top} ) - actual_scale_extent = choose_scale_extent({}, source_crs, in_array) + actual_scale_extent = choose_scale_extent(Message({}), source_crs, in_array) self.assertEqual(actual_scale_extent, expected_scale_extent) def test_scale_extent_from_input_image_with_crs_transformation(self): """Ensure no change of output extent when src_crs == target_crs""" target_crs = CRS.from_string(PREFERRED_CRS['global']) - with open_rasterio( - self.fixtures / 'RGB.byte.small.tif', mode='r', mask_and_scale=True - ) as in_array: + with rasterio.open(self.fixtures / 'RGB.byte.small.tif') as in_array: left, bottom, right, top = ( -78.95864996539397, 23.568866283727235, @@ -453,7 +464,7 @@ def test_scale_extent_from_input_image_with_crs_transformation(self): {'xmin': left, 'ymin': bottom, 'xmax': right, 'ymax': top} ) - actual_scale_extent = choose_scale_extent({}, target_crs, in_array) + actual_scale_extent = choose_scale_extent(Message({}), target_crs, in_array) assert expected_scale_extent == pytest.approx( actual_scale_extent, rel=1e-12 ) @@ -464,9 +475,7 @@ def test_scale_extent_from_input_image_that_crosses_antimeridian(self): Notice that the xmax value is > 180. """ target_crs = CRS.from_string(PREFERRED_CRS['global']) - with open_rasterio( - self.fixtures / 'split-dateline-sample.tif', mode='r', mask_and_scale=True - ) as in_array: + with rasterio.open(self.fixtures / 'split-dateline-sample.tif') as in_array: expected_scale_extent = ScaleExtent( { 'xmin': 179.25694918947116, @@ -476,7 +485,7 @@ def test_scale_extent_from_input_image_that_crosses_antimeridian(self): } ) - actual_scale_extent = choose_scale_extent({}, target_crs, in_array) + actual_scale_extent = choose_scale_extent(Message({}), target_crs, in_array) assert expected_scale_extent == pytest.approx( actual_scale_extent, rel=1e-12 ) @@ -486,16 +495,21 @@ class TestChooseTargetDimensions(TestCase): def test_message_has_dimensions(self): message = Message({'format': {'height': 30, 'width': 40}}) expected_dimensions = {'height': 30, 'width': 40} - actual_dimensions = choose_target_dimensions(message, None, None, None) + ds = Mock(DatasetReader) + se = Mock(ScaleExtent) + actual_dimensions = choose_target_dimensions(message, ds, se, None) self.assertDictEqual(expected_dimensions, actual_dimensions) def test_message_has_scale_sizes(self): message = Message({'format': {'scaleSize': {'x': 10, 'y': 10}}}) # scaleExtents are already extracted. - scale_extent = {'xmin': 0.0, 'xmax': 2000.0, 'ymin': 0.0, 'ymax': 1000.0} + scale_extent = ScaleExtent( + {'xmin': 0.0, 'xmax': 2000.0, 'ymin': 0.0, 'ymax': 1000.0} + ) expected_dimensions = {'height': 100, 'width': 200} - actual_dimensions = choose_target_dimensions(message, None, scale_extent, None) + ds = Mock(DatasetReader) + actual_dimensions = choose_target_dimensions(message, ds, scale_extent, None) self.assertDictEqual(expected_dimensions, actual_dimensions) @patch('hybig.sizes.best_guess_target_dimensions') @@ -546,13 +560,15 @@ def test_projected_crs(self): transform=Affine(25000.0, 0.0, -9000000.0, 0.0, -25000.0, 9000000.0), dtype='uint8', ) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: - scale_extent = { - 'xmin': -9000000.0, - 'ymin': -9000000.0, - 'xmax': 9000000.0, - 'ymax': 9000000.0, - } + with rasterio.open(tmp_file) as rio_data_array: + scale_extent = ScaleExtent( + { + 'xmin': -9000000.0, + 'ymin': -9000000.0, + 'xmax': 9000000.0, + 'ymax': 9000000.0, + } + ) # in_dataset's height and width expected_target_dimensions = {'height': 720, 'width': 720} @@ -582,13 +598,15 @@ def test_projected_crs_with_high_resolution(self): ), dtype='uint8', ) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: - scale_extent = { - 'xmin': -9000000.0, - 'ymin': -9000000.0, - 'xmax': 9000000.0, - 'ymax': 9000000.0, - } + with rasterio.open(tmp_file) as rio_data_array: + scale_extent = ScaleExtent( + { + 'xmin': -9000000.0, + 'ymin': -9000000.0, + 'xmax': 9000000.0, + 'ymax': 9000000.0, + } + ) # expected resolution is "500m" and the pixel_size is 512m # (9000000 - -9000000 ) / 512 = 35156 @@ -632,13 +650,15 @@ def test_projected_crs_with_high_resolution_to_preferred_area(self): transform=Affine(700.0423, 0.0, -4194304.0, 0.0, 700.0423, 4194304.0), dtype='uint8', ) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: - scale_extent = { - 'xmin': -4194304.0, - 'ymin': -4194304.0, - 'xmax': 4194304.0, - 'ymax': 4194304.0, - } + with rasterio.open(tmp_file) as rio_data_array: + scale_extent = ScaleExtent( + { + 'xmin': -4194304.0, + 'ymin': -4194304.0, + 'xmax': 4194304.0, + 'ymax': 4194304.0, + } + ) expected_target_dimensions = { 'height': epsg_3413_resolutions[2].width, @@ -675,7 +695,7 @@ def test_projected_crs_with_high_resolution_to_preferred_area(self): def test_longlat_crs(self): # 36km Mid-Latitude EASE Grid 2 - ml_test_transform = rasterio.transform.from_bounds( + ml_test_transform = from_bounds( nsidc_ease2_36km_grid['left'], nsidc_ease2_36km_grid['bottom'], nsidc_ease2_36km_grid['right'], @@ -691,13 +711,15 @@ def test_longlat_crs(self): transform=ml_test_transform, dtype='uint8', ) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: - scale_extent = { - 'xmin': -180.0, - 'ymin': -86.0, - 'xmax': 180.0, - 'ymax': 86.0, - } + with rasterio.open(tmp_file) as rio_data_array: + scale_extent = ScaleExtent( + { + 'xmin': -180.0, + 'ymin': -86.0, + 'xmax': 180.0, + 'ymax': 86.0, + } + ) infile_res = 0.31668943359375 expected_height = round((86 - -86) / infile_res) @@ -718,7 +740,7 @@ def test_longlat_crs(self): def test_longlat_crs_with_high_resolution(self): # 360m Mid-Latitude EASE Grid 2 # 360m -> 250m preferred - test_transform = rasterio.transform.from_bounds( + test_transform = from_bounds( nsidc_ease2_36km_grid['left'], nsidc_ease2_36km_grid['bottom'], nsidc_ease2_36km_grid['right'], @@ -733,13 +755,15 @@ def test_longlat_crs_with_high_resolution(self): crs=CRS.from_epsg(6933), dtype='uint8', ) as tmp_file: - with open_rasterio(tmp_file) as rio_data_array: - scale_extent = { - 'xmin': -180.0, - 'ymin': -86.0, - 'xmax': 180.0, - 'ymax': 86.0, - } + with rasterio.open(tmp_file) as rio_data_array: + scale_extent = ScaleExtent( + { + 'xmin': -180.0, + 'ymin': -86.0, + 'xmax': 180.0, + 'ymax': 86.0, + } + ) # resolution is 360 meters, which resolves to 250m preferred. target_resolution = epsg_4326_resolutions[3].pixel_size @@ -784,7 +808,7 @@ class TestResolutionInTargetCRS(TestCase): """Ensure resolution ends up in target_crs units""" def test_dataset_matches_target_crs_meters(self): - ml_test_transform = rasterio.transform.from_bounds( + ml_test_transform = from_bounds( nsidc_ease2_36km_grid['left'], nsidc_ease2_36km_grid['bottom'], nsidc_ease2_36km_grid['right'], @@ -796,7 +820,7 @@ def test_dataset_matches_target_crs_meters(self): crs=CRS.from_epsg(nsidc_ease2_36km_grid['epsg']), transform=ml_test_transform, ) as test_file: - with open_rasterio(test_file) as test_dataarray: + with rasterio.open(test_file) as test_dataarray: target_crs = CRS.from_epsg(3413) expected_x_res = ml_test_transform.a expected_y_res = -ml_test_transform.e @@ -810,14 +834,12 @@ def test_dataset_matches_target_crs_meters(self): def test_dataset_matches_target_crs_degrees(self): """Input dataset and target unprojected.""" - global_one_degree_transform = rasterio.transform.from_bounds( - -180, -90, 180, 90, 360, 180 - ) + global_one_degree_transform = from_bounds(-180, -90, 180, 90, 360, 180) with rasterio_test_file( crs=CRS.from_string(PREFERRED_CRS['global']), transform=global_one_degree_transform, ) as test_file: - with open_rasterio(test_file) as test_dataarray: + with rasterio.open(test_file) as test_dataarray: target_crs = CRS.from_string(PREFERRED_CRS['global']) expected_x_res = global_one_degree_transform.a expected_y_res = -global_one_degree_transform.e @@ -830,7 +852,7 @@ def test_dataset_matches_target_crs_degrees(self): self.assertEqual(expected_y_res, actual_y_res) def test_dataset_meters_target_crs_degrees(self): - ml_test_transform = rasterio.transform.from_bounds( + ml_test_transform = from_bounds( nsidc_ease2_36km_grid['left'], nsidc_ease2_36km_grid['bottom'], nsidc_ease2_36km_grid['right'], @@ -842,7 +864,7 @@ def test_dataset_meters_target_crs_degrees(self): crs=CRS.from_epsg(nsidc_ease2_36km_grid['epsg']), transform=ml_test_transform, ) as test_file: - with open_rasterio(test_file) as test_dataarray: + with rasterio.open(test_file) as test_dataarray: target_crs = CRS.from_epsg(4326) expected_x_res = ml_test_transform.a / METERS_PER_DEGREE expected_y_res = -ml_test_transform.e / METERS_PER_DEGREE @@ -855,14 +877,12 @@ def test_dataset_meters_target_crs_degrees(self): self.assertEqual(expected_y_res, actual_y_res) def test_dataset_degrees_target_crs_meters(self): - global_one_degree_transform = rasterio.transform.from_bounds( - -180, -90, 180, 90, 360, 180 - ) + global_one_degree_transform = from_bounds(-180, -90, 180, 90, 360, 180) with rasterio_test_file( crs=CRS.from_string(PREFERRED_CRS['global']), transform=global_one_degree_transform, ) as test_file: - with open_rasterio(test_file) as test_dataarray: + with rasterio.open(test_file) as test_dataarray: target_crs = CRS.from_string(PREFERRED_CRS['north']) expected_x_res = global_one_degree_transform.a * METERS_PER_DEGREE expected_y_res = -global_one_degree_transform.e * METERS_PER_DEGREE @@ -933,7 +953,7 @@ def test_coarser_than_2km_meters(self): def test_matches_preferred_meters(self): # 500m resolution. - resolution = [128] + resolution = [128.0] expected_resolution = self.south_info[4].pixel_size actual_resolution = find_closest_resolution(resolution, self.south_info) self.assertEqual(expected_resolution, actual_resolution.pixel_size) @@ -941,13 +961,13 @@ def test_matches_preferred_meters(self): def test_resolution_halfway_between_preferred_meters(self): """Exactly half way should choose larger resolution.""" # 48m resolution - resolution = [48] + resolution = [48.0] expected_resolution = self.north_info[5].pixel_size actual_resolution = find_closest_resolution(resolution, self.north_info) self.assertEqual(expected_resolution, actual_resolution.pixel_size) def test_chooses_closest_resolution_meters(self): - resolutions = [65, 540] + resolutions = [65.0, 540.0] expected_resolution = self.south_info[5].pixel_size actual_resolution = find_closest_resolution(resolutions, self.south_info) self.assertEqual(expected_resolution, actual_resolution.pixel_size)