Commit 44f027c

Move Hazard xarray reader into separate module
1 parent f312a58 commit 44f027c
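
This commit replaces the inline xarray parsing in climada/hazard/io.py with calls into a new HazardXarrayReader class imported from the new climada.hazard.xarray module. Below is a minimal usage sketch, assuming the public classmethods keep the call pattern shown in the docstring examples inside the diff; the "" arguments stand in for hazard type and intensity unit as in those examples, and chunks="auto" is just an illustrative open keyword taken from the removed code path.

import xarray as xr

from climada.hazard import Hazard

# File-based reading now goes through HazardXarrayReader.from_file() internally
hazard = Hazard.from_xarray_raster_file("path/to/file.nc", "", "")

# Dataset-based reading builds a HazardXarrayReader and feeds
# reader.get_hazard_kwargs() through Hazard._check_and_cast_attrs()
with xr.open_dataset("path/to/file.nc", chunks="auto") as dset:
    hazard = Hazard.from_xarray_raster(dset, "", "")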

File tree

3 files changed: +424 -327 lines

climada/hazard/io.py

Lines changed: 16 additions & 326 deletions
@@ -19,29 +19,28 @@
 Define Hazard IO Methods.
 """

-import copy
 import datetime as dt
 import itertools
 import logging
 import pathlib
 import warnings
 from collections.abc import Collection
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union

 import h5py
 import numpy as np
 import pandas as pd
 import rasterio
-import sparse as sp
 import xarray as xr
 from scipy import sparse

 import climada.util.constants as u_const
 import climada.util.coordinates as u_coord
-import climada.util.dates_times as u_dt
 import climada.util.hdf5_handler as u_hdf5
 from climada.hazard.centroids.centr import Centroids

+from .xarray import HazardXarrayReader
+
 LOGGER = logging.getLogger(__name__)

 DEF_VAR_EXCEL = {
@@ -86,12 +85,6 @@
 }
 """MATLAB variable names"""

-DEF_COORDS = dict(event="time", longitude="longitude", latitude="latitude")
-"""Default coordinates when reading Hazard data from an xarray Dataset"""
-
-DEF_DATA_VARS = ["fraction", "frequency", "event_id", "event_name", "date"]
-"""Default keys for optional Hazard attributes when reading from an xarray Dataset"""
-

 # pylint: disable=no-member

@@ -320,8 +313,9 @@ def from_xarray_raster_file(
         >>> with xarray.open_dataset("path/to/file.nc", **open_kwargs) as dset:
         ...     hazard = Hazard.from_xarray_raster(dset, "", "")
         """
-        with xr.open_dataset(filepath, chunks="auto") as dset:
-            return cls.from_xarray_raster(dset, *args, **kwargs)
+        reader = HazardXarrayReader.from_file(filepath, *args, **kwargs)
+        kwargs = reader.get_hazard_kwargs()
+        return cls(**cls._check_and_cast_attrs(kwargs))

     @classmethod
     def from_xarray_raster(
@@ -574,7 +568,6 @@ def from_xarray_raster(
         >>> dset = dset.expand_dims(time=[numpy.datetime64("2000-01-01")])
         >>> hazard = Hazard.from_xarray_raster(dset, "", "")
         """
-        # Check data type for better error message
         if not isinstance(data, xr.Dataset):
             if isinstance(data, (pathlib.Path, str)):
                 raise TypeError(
@@ -584,321 +577,18 @@ def from_xarray_raster(

             raise TypeError("This method only supports xarray.Dataset as input data")

-        # Initialize Hazard object
-        hazard_kwargs = dict(haz_type=hazard_type, units=intensity_unit)
-
-        # Update coordinate identifiers
-        coords = copy.deepcopy(DEF_COORDS)
-        coordinate_vars = coordinate_vars if coordinate_vars is not None else {}
-        unknown_coords = [co for co in coordinate_vars if co not in coords]
-        if unknown_coords:
-            raise ValueError(
-                f"Unknown coordinates passed: '{unknown_coords}'. Supported "
-                f"coordinates are {list(coords.keys())}."
-            )
-        coords.update(coordinate_vars)
-
-        # Retrieve dimensions of coordinates
-        try:
-            dims = dict(
-                event=data[coords["event"]].dims,
-                longitude=data[coords["longitude"]].dims,
-                latitude=data[coords["latitude"]].dims,
-            )
-        # Handle KeyError for better error message
-        except KeyError as err:
-            key = err.args[0]
-            raise RuntimeError(
-                f"Dataset is missing dimension/coordinate: {key}. Dataset dimensions: "
-                f"{list(data.dims.keys())}"
-            ) from err
-
-        # Try promoting single-value coordinates to dimensions
-        for key, val in dims.items():
-            if not val:
-                coord = coords[key]
-                LOGGER.debug("Promoting Dataset coordinate '%s' to dimension", coord)
-                data = data.expand_dims(coord)
-                dims[key] = data[coord].dims
-
-        # Try to rechunk the data to optimize the stack operation afterwards.
-        if rechunk:
-            # We want one event to be contained in one chunk
-            chunks = {dim: -1 for dim in dims["longitude"]}
-            chunks.update({dim: -1 for dim in dims["latitude"]})
-
-            # Chunks can be auto-sized along the event dimensions
-            chunks.update({dim: "auto" for dim in dims["event"]})
-            data = data.chunk(chunks=chunks)
-
-        # Stack (vectorize) the entire dataset into 2D (time, lat/lon)
-        # NOTE: We want the set union of the dimensions, but Python 'set' does not
-        # preserve order. However, we want longitude to run faster than latitude.
-        # So we use 'dict' without values, as 'dict' preserves insertion order
-        # (dict keys behave like a set).
-        data = data.stack(
-            event=dims["event"],
-            lat_lon=dict.fromkeys(dims["latitude"] + dims["longitude"]),
-        )
-
-        # Transform coordinates into centroids
-        centroids = Centroids(
-            lat=data[coords["latitude"]].values,
-            lon=data[coords["longitude"]].values,
+        reader = HazardXarrayReader(
+            data=data,
+            hazard_type=hazard_type,
+            intensity_unit=intensity_unit,
+            intensity=intensity,
+            coordinate_vars=coordinate_vars,
+            data_vars=data_vars,
             crs=crs,
+            rechunk=rechunk,
         )
-
-        def to_csr_matrix(array: xr.DataArray) -> sparse.csr_matrix:
-            """Store a numpy array as sparse matrix, optimizing storage space
-
-            The CSR matrix stores NaNs explicitly, so we set them to zero.
-            """
-            array = array.where(array.notnull(), 0)
-            array = xr.apply_ufunc(
-                sp.COO.from_numpy,
-                array,
-                dask="parallelized",
-                output_dtypes=[array.dtype],
-            )
-            sparse_coo = array.compute().data  # Load into memory
-            return sparse_coo.tocsr()  # Convert sparse.COO to scipy.sparse.csr_matrix
-
-        # Read the intensity data
-        LOGGER.debug("Loading Hazard intensity from DataArray '%s'", intensity)
-        intensity_matrix = to_csr_matrix(data[intensity])
-
-        # Define accessors for xarray DataArrays
-        def default_accessor(array: xr.DataArray) -> np.ndarray:
-            """Take a DataArray and return its numpy representation"""
-            return array.values
-
-        def strict_positive_int_accessor(array: xr.DataArray) -> np.ndarray:
-            """Take a positive int DataArray and return its numpy representation
-
-            Raises
-            ------
-            TypeError
-                If the underlying data type is not integer
-            ValueError
-                If any value is zero or less
-            """
-            if not np.issubdtype(array.dtype, np.integer):
-                raise TypeError(f"'{array.name}' data array must be integers")
-            if not (array > 0).all():
-                raise ValueError(f"'{array.name}' data must be larger than zero")
-            return array.values
-
-        def date_to_ordinal_accessor(
-            array: xr.DataArray, strict: bool = True
-        ) -> np.ndarray:
-            """Take a DataArray and transform it into ordinals"""
-            try:
-                if np.issubdtype(array.dtype, np.integer):
-                    # Assume that data is ordinals
-                    return strict_positive_int_accessor(array)
-
-                # Try transforming to ordinals
-                return np.array(u_dt.datetime64_to_ordinal(array.values))
-
-            # Handle access errors
-            except (ValueError, TypeError, AttributeError) as err:
-                if strict:
-                    raise err
-
-                LOGGER.warning(
-                    "Failed to read values of '%s' as dates or ordinals. Hazard.date "
-                    "will be ones only",
-                    array.name,
-                )
-                return np.ones(array.shape)
-
-        def year_month_day_accessor(
-            array: xr.DataArray, strict: bool = True
-        ) -> np.ndarray:
-            """Take an array and return am array of YYYY-MM-DD strings"""
-            try:
-                return array.dt.strftime("%Y-%m-%d").values
-
-            # Handle access errors
-            except (ValueError, TypeError, AttributeError) as err:
-                if strict:
-                    raise err
-
-                LOGGER.warning(
-                    "Failed to read values of '%s' as dates. Hazard.event_name will be "
-                    "empty strings",
-                    array.name,
-                )
-                return np.full(array.shape, "")
-
-        def maybe_repeat(values: np.ndarray, times: int) -> np.ndarray:
-            """Return the array or repeat a single-valued array
-
-            If ``values`` has size 1, return an array that repeats this value ``times``
-            times. If the size is different, just return the array.
-            """
-            if values.size == 1:
-                return np.array(list(itertools.repeat(values.flat[0], times)))
-
-            return values
-
-        # Create a DataFrame storing access information for each of data_vars
-        # NOTE: Each row will be passed as arguments to
-        # `load_from_xarray_or_return_default`, see its docstring for further
-        # explanation of the DataFrame columns / keywords.
-        num_events = data.sizes["event"]
-        data_ident = pd.DataFrame(
-            data=dict(
-                # The attribute of the Hazard class where the data will be stored
-                hazard_attr=DEF_DATA_VARS,
-                # The identifier and default key used in this method
-                default_key=DEF_DATA_VARS,
-                # The key assigned by the user
-                user_key=None,
-                # The default value for each attribute
-                default_value=[
-                    None,
-                    np.ones(num_events),
-                    np.array(range(num_events), dtype=int) + 1,
-                    list(
-                        year_month_day_accessor(
-                            data[coords["event"]], strict=False
-                        ).flat
-                    ),
-                    date_to_ordinal_accessor(data[coords["event"]], strict=False),
-                ],
-                # The accessor for the data in the Dataset
-                accessor=[
-                    to_csr_matrix,
-                    lambda x: maybe_repeat(default_accessor(x), num_events),
-                    strict_positive_int_accessor,
-                    lambda x: list(maybe_repeat(default_accessor(x), num_events).flat),
-                    lambda x: maybe_repeat(date_to_ordinal_accessor(x), num_events),
-                ],
-            )
-        )
-
-        # Check for unexpected keys
-        data_vars = data_vars if data_vars is not None else {}
-        default_keys = data_ident["default_key"]
-        unknown_keys = [
-            key for key in data_vars.keys() if not default_keys.str.contains(key).any()
-        ]
-        if unknown_keys:
-            raise ValueError(
-                f"Unknown data variables passed: '{unknown_keys}'. Supported "
-                f"data variables are {list(default_keys)}."
-            )
-
-        # Update with keys provided by the user
-        # NOTE: Keys in 'default_keys' missing from 'data_vars' will be set to 'None'
-        # (which is exactly what we want) and the result is written into
-        # 'user_key'. 'default_keys' is not modified.
-        data_ident["user_key"] = default_keys.map(data_vars)
-
-        def load_from_xarray_or_return_default(
-            user_key: Optional[str],
-            default_key: str,
-            hazard_attr: str,
-            accessor: Callable[[xr.DataArray], Any],
-            default_value: Any,
-        ) -> Any:
-            """Load data for a single Hazard attribute or return the default value
-
-            Does the following based on the ``user_key``:
-            * If the key is an empty string, return the default value
-            * If the key is a non-empty string, load the data for that key and return it.
-            * If the key is ``None``, look for the ``default_key`` in the data. If it
-              exists, return that data. If not, return the default value.
-
-            Parameters
-            ----------
-            user_key : str or None
-                The key set by the user to identify the DataArray to read data from.
-            default_key : str
-                The default key identifying the DataArray to read data from.
-            hazard_attr : str
-                The name of the attribute of ``Hazard`` where the data will be stored in.
-            accessor : Callable
-                A callable that takes the DataArray as argument and returns the data
-                structure that is required by the ``Hazard`` attribute.
-            default_value
-                The default value/array to return in case the data could not be found.
-
-            Returns
-            -------
-            The object that will be stored in the ``Hazard`` attribute ``hazard_attr``.
-
-            Raises
-            ------
-            KeyError
-                If ``user_key`` was a non-empty string but no such key was found in the
-                data
-            RuntimeError
-                If the data structure loaded has a different shape than the default data
-                structure
-            """
-            # User does not want to read data
-            if user_key == "":
-                LOGGER.debug(
-                    "Using default values for Hazard.%s per user request", hazard_attr
-                )
-                return default_value
-
-            if not pd.isna(user_key):
-                # Read key exclusively
-                LOGGER.debug(
-                    "Reading data for Hazard.%s from DataArray '%s'",
-                    hazard_attr,
-                    user_key,
-                )
-                val = accessor(data[user_key])
-            else:
-                # Try default key
-                try:
-                    val = accessor(data[default_key])
-                    LOGGER.debug(
-                        "Reading data for Hazard.%s from DataArray '%s'",
-                        hazard_attr,
-                        default_key,
-                    )
-                except KeyError:
-                    LOGGER.debug(
-                        "Using default values for Hazard.%s. No data found", hazard_attr
-                    )
-                    return default_value
-
-            def vshape(array):
-                """Return a shape tuple for any array-like type we use"""
-                if isinstance(array, list):
-                    return len(array)
-                if isinstance(array, sparse.csr_matrix):
-                    return array.get_shape()
-                return array.shape
-
-            # Check size for read data
-            if default_value is not None and not np.array_equal(
-                vshape(val), vshape(default_value)
-            ):
-                raise RuntimeError(
-                    f"'{user_key if user_key else default_key}' must have shape "
-                    f"{vshape(default_value)}, but shape is {vshape(val)}"
-                )
-
-            # Return the data
-            return val
-
-        # Set the Hazard attributes
-        for _, ident in data_ident.iterrows():
-            hazard_kwargs[ident["hazard_attr"]] = load_from_xarray_or_return_default(
-                **ident
-            )
-
-        hazard_kwargs = cls._check_and_cast_attrs(hazard_kwargs)
-
-        # Done!
-        LOGGER.debug("Hazard successfully loaded. Number of events: %i", num_events)
-        return cls(centroids=centroids, intensity=intensity_matrix, **hazard_kwargs)
+        kwargs = reader.get_hazard_kwargs()
+        return cls(**cls._check_and_cast_attrs(kwargs))

     @staticmethod
     def _check_and_cast_attrs(attrs: Dict[str, Any]) -> Dict[str, Any]:
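
The new climada.hazard.xarray module belongs to one of the other changed files and is not part of this excerpt. For orientation, here is a rough, hypothetical outline of the interface that io.py now depends on, inferred only from the call sites in the diff above; parameter order, defaults, and method bodies are assumptions.

from typing import Any, Dict, Optional

import xarray as xr


class HazardXarrayReader:
    """Assumed interface: turns an xarray Dataset into Hazard constructor kwargs."""

    def __init__(
        self,
        data: xr.Dataset,
        hazard_type: str,
        intensity_unit: str,
        intensity: str,
        coordinate_vars: Optional[Dict[str, str]],
        data_vars: Optional[Dict[str, str]],
        crs: str,
        rechunk: bool,
    ) -> None:
        # The coordinate/data-variable handling removed from io.py presumably lives here.
        ...

    @classmethod
    def from_file(cls, filepath, *args, **kwargs) -> "HazardXarrayReader":
        # Open the file lazily and construct the reader from the resulting Dataset.
        dset = xr.open_dataset(filepath, chunks="auto")
        return cls(dset, *args, **kwargs)

    def get_hazard_kwargs(self) -> Dict[str, Any]:
        # Returns the keyword arguments (centroids, intensity, event data, ...) that
        # io.py passes through Hazard._check_and_cast_attrs() into the Hazard constructor.
        ...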
