diff --git a/src/earthkit/data/core/fieldlist.py b/src/earthkit/data/core/fieldlist.py index 20c858ae5..7a1be1c0d 100644 --- a/src/earthkit/data/core/fieldlist.py +++ b/src/earthkit/data/core/fieldlist.py @@ -24,6 +24,7 @@ from earthkit.data.core.index import MultiIndex from earthkit.data.decorators import cached_method from earthkit.data.decorators import detect_out_filename +from earthkit.data.utils.compute import wrap_maths from earthkit.data.utils.metadata.args import metadata_argument @@ -88,9 +89,8 @@ def index(self, key): return self.user_indices[key] +@wrap_maths class Field(Base): - r"""Represent a Field.""" - @property def array_backend(self): r""":obj:`ArrayBackend`: Return the array backend of the field.""" @@ -931,7 +931,27 @@ def _array_matches(self, array, flatten=False, dtype=None): shape = self._required_shape(flatten) return shape == array.shape and (dtype is None or dtype == array.dtype) + def _unary_op(self, oper): + v = oper(self.values) + r = self.clone(values=v) + return r + + def _binary_op(self, oper, y): + from earthkit.data.wrappers import get_wrapper + + y = get_wrapper(y) + if isinstance(y, FieldList): + x = FieldList.from_fields([self]) + return x._binary_op(oper, y) + + vx = self.values + vy = y.values + v = oper(vx, vy) + r = self.clone(values=v) + return r + +@wrap_maths class FieldList(Index): r"""Represent a list of :obj:`Field` \s.""" @@ -1750,6 +1770,18 @@ def _cache_diag(self): return metadata_cache_diag(self) + def _unary_op(self, oper): + from earthkit.data.utils.compute import get_method + + method = "loop" + return get_method(method).unary_op(oper, self) + + def _binary_op(self, oper, y): + from earthkit.data.utils.compute import get_method + + method = "loop" + return get_method(method).binary_op(oper, self, y) + class MaskFieldList(FieldList, MaskIndex): def __init__(self, *args, **kwargs): diff --git a/src/earthkit/data/utils/compute.py b/src/earthkit/data/utils/compute.py new file mode 100644 index 000000000..0ae8012e9 
# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

"""Element-wise arithmetic support for fields and fieldlists.

The module provides:

* ``COMP_UNARY`` / ``COMP_BINARY``: tables mapping method names to callables
  operating on raw value arrays;
* :func:`wrap_maths`: a class decorator installing those names as methods that
  forward to the ``_unary_op`` / ``_binary_op`` hooks of the decorated class;
* :func:`apply_ufunc`: applies a user function to the values of its arguments;
* :class:`LoopCompute`: the field-by-field ("loop") implementation.
"""

import math
from abc import ABCMeta
from abc import abstractmethod
from itertools import repeat

from earthkit.utils.array import array_namespace

from earthkit.data.wrappers import get_wrapper

# Method name -> callable taking one array. The array namespace is resolved
# from the operand itself so the same table serves every array backend.
COMP_UNARY = {
    "__neg__": lambda x: -x,
    "__pos__": lambda x: +x,
    "asin": lambda x: array_namespace(x).asin(x),
    "acos": lambda x: array_namespace(x).acos(x),
    "atan": lambda x: array_namespace(x).atan(x),
    "arcsin": lambda x: array_namespace(x).asin(x),
    "arccos": lambda x: array_namespace(x).acos(x),
    "arctan": lambda x: array_namespace(x).atan(x),
    "cos": lambda x: array_namespace(x).cos(x),
    "cosh": lambda x: array_namespace(x).cosh(x),
    "exp": lambda x: array_namespace(x).exp(x),
    "floor": lambda x: array_namespace(x).floor(x),
    "log": lambda x: array_namespace(x).log(x),
    "log10": lambda x: array_namespace(x).log10(x),
    "round": lambda x: array_namespace(x).round(x),
    "sign": lambda x: array_namespace(x).sign(x),
    "sin": lambda x: array_namespace(x).sin(x),
    "sinh": lambda x: array_namespace(x).sinh(x),
    "tan": lambda x: array_namespace(x).tan(x),
    "tanh": lambda x: array_namespace(x).tanh(x),
    "sqrt": lambda x: array_namespace(x).sqrt(x),
    "trunc": lambda x: array_namespace(x).trunc(x),
}

# Method name -> callable taking (x, y), where x holds the values of the
# object the method is called on. The reflected variants swap the operands.
COMP_BINARY = {
    "__add__": lambda x, y: x + y,
    "__radd__": lambda x, y: y + x,
    "__sub__": lambda x, y: x - y,
    "__rsub__": lambda x, y: y - x,
    "__mul__": lambda x, y: x * y,
    "__rmul__": lambda x, y: y * x,
    "__truediv__": lambda x, y: x / y,
    "__rtruediv__": lambda x, y: y / x,
    "__floordiv__": lambda x, y: x // y,
    "__rfloordiv__": lambda x, y: y // x,
    "__mod__": lambda x, y: x % y,
    "__rmod__": lambda x, y: y % x,
    "__pow__": lambda x, y: x**y,
    "__rpow__": lambda x, y: y**x,
    "__gt__": lambda x, y: x > y,
    "__lt__": lambda x, y: x < y,
    "__ge__": lambda x, y: x >= y,
    "__le__": lambda x, y: x <= y,
    # NOTE(review): "__eq__" is deliberately not installed; presumably to keep
    # the default identity-based equality and hashing -- confirm before enabling.
    # "__eq__": lambda x, y: x == y,
    "__ne__": lambda x, y: x != y,
}


def wrap_maths(cls):
    """Install the ``COMP_BINARY`` and ``COMP_UNARY`` operations on ``cls``.

    Each generated method forwards to ``self._binary_op(op, ...)`` or
    ``self._unary_op(op, ...)``, so ``cls`` must implement these two hooks.

    Parameters
    ----------
    cls: type
        The class to decorate.

    Returns
    -------
    type
        The same class object, with the methods set on it.
    """

    def _named(wrapper, name):
        # Give each generated method its real name so that tracebacks and
        # introspection do not report every operator as "wrapper".
        wrapper.__name__ = name
        wrapper.__qualname__ = f"{cls.__qualname__}.{name}"
        return wrapper

    def wrap_unary_method(name, op):
        def wrapper(self, *args, **kwargs):
            return self._unary_op(op, *args, **kwargs)

        return _named(wrapper, name)

    def wrap_binary_method(name, op):
        def wrapper(self, *args, **kwargs):
            return self._binary_op(op, *args, **kwargs)

        return _named(wrapper, name)

    for name, op in COMP_BINARY.items():
        setattr(cls, name, wrap_binary_method(name, op))
    for name, op in COMP_UNARY.items():
        setattr(cls, name, wrap_unary_method(name, op))
    return cls


def apply_ufunc(func, *args):
    """Apply ``func`` to the values of ``args``.

    Parameters
    ----------
    func: callable
        Called with one positional array argument per item in ``args``.
    *args:
        Objects accepted by ``get_wrapper`` (fieldlists, fields, arrays,
        numbers, ...).

    Returns
    -------
    object
        * one argument: the result of its ``_unary_op(func)``;
        * at least one fieldlist: a fieldlist computed field by field, using
          the longest fieldlist (the first one on ties) as reference;
        * otherwise, at least one field: a fieldlist holding a single field;
        * otherwise, when every argument exposes ``values``: whatever ``func``
          returns for those values.

    Raises
    ------
    ValueError
        If none of the cases above applies.
    """
    from earthkit.data.core.fieldlist import Field
    from earthkit.data.core.fieldlist import FieldList

    x = [get_wrapper(a) for a in args]

    if len(x) == 1:
        return x[0]._unary_op(func)

    # The longest non-empty fieldlist drives the iteration.
    ref = None
    num = 0
    for a in x:
        if isinstance(a, FieldList):
            n = len(a)
            if n > num:
                num = n
                ref = a
    if ref is not None:
        return get_method("loop").apply_ufunc(func, ref, *x)

    # No fieldlist: promote the first field to a single-item fieldlist.
    for a in x:
        if isinstance(a, Field):
            ref = FieldList.from_fields([a])
            r = get_method("loop").apply_ufunc(func, ref, *x)
            assert len(r) == 1
            return r

    # Plain values only. They are unpacked into positional arguments, the same
    # calling convention LoopCompute.apply_ufunc uses for ``func``.
    if x and all(hasattr(a, "values") for a in x):
        return func(*(a.values for a in x))

    raise ValueError("Cannot find a suitable object to apply ufunc")


class Compute(metaclass=ABCMeta):
    """Interface of a computation method.

    Implementations are used through the class object itself (see
    ``get_method``), hence the static methods.
    """

    @staticmethod
    @abstractmethod
    def unary_op(oper, x):
        """Apply the unary callable ``oper`` to ``x``."""

    @staticmethod
    @abstractmethod
    def binary_op(oper, x, y):
        """Apply the binary callable ``oper`` to ``x`` and ``y``."""


class LoopCompute(Compute):
    """Computation method iterating over the fields one by one."""

    @staticmethod
    def create_fieldlist(ref, x):
        """Convert ``x`` into a fieldlist usable together with ``ref``.

        Parameters
        ----------
        ref: FieldList
            Reference fieldlist; its length and the shape of its first field
            decide how an array-like ``x`` is interpreted.
        x:
            A fieldlist (returned as is), a field, or any object whose wrapper
            exposes ``values``.

        Returns
        -------
        FieldList

        Raises
        ------
        ValueError
            If ``x`` has no ``values`` or its shape does not fit ``ref``.
        """
        from earthkit.data.core.fieldlist import Field
        from earthkit.data.core.fieldlist import FieldList

        x = get_wrapper(x)

        if isinstance(x, FieldList):
            return x

        if isinstance(x, Field):
            return FieldList.from_fields([x])

        if not hasattr(x, "values"):
            raise ValueError(f"y type={type(x)} cannot be used with x type={type(ref)}")

        from earthkit.data.sources.array_list import from_array
        from earthkit.data.utils.metadata.dict import UserMetadata

        x_val = x.values
        x_val = array_namespace(x_val).asarray(x_val)
        x_shape = tuple(x_val.shape)

        # Count the elements from the shape: ``size`` is an attribute on some
        # array types but a method on others (e.g. torch tensors).
        if math.prod(x_shape) == 1:
            # single value, to be broadcast against every field of ref
            return from_array([x_val], [UserMetadata()])

        # multiple values
        ref_field_shape = tuple(ref[0].shape)
        field_size = math.prod(ref_field_shape)
        if len(x_shape) > 1:
            # leading axis = fields, remaining axes = values of one field
            if field_size == math.prod(x_shape[1:]):
                return from_array(x_val, [UserMetadata() for _ in range(x_shape[0])])
        elif field_size == math.prod(x_shape):
            # the values of one single field
            return from_array([x_val], [UserMetadata()])
        elif x_shape[0] == len(ref):
            # one value per field of ref
            return from_array(x_val, [UserMetadata() for _ in range(x_shape[0])])

        assumed_ref_shape = (len(ref), *ref_field_shape)
        raise ValueError(f"y shape={x_shape} cannot be used with x shape={assumed_ref_shape}")

    @staticmethod
    def unary_op(oper, x):
        """Apply ``oper`` to every field of ``x`` and return a new fieldlist."""
        return x.from_fields([f._unary_op(oper) for f in x])

    @staticmethod
    def binary_op(oper, x, y):
        """Apply ``oper`` to the fields of ``x`` and ``y`` pairwise.

        ``y`` is first converted with :meth:`create_fieldlist`. When the
        lengths differ, the operand of length 1 is repeated.

        Raises
        ------
        ValueError
            If ``y`` is empty, or the lengths differ and neither is 1.
        """
        from earthkit.data.core.fieldlist import FieldList

        assert isinstance(x, FieldList)

        y = LoopCompute.create_fieldlist(x, y)
        assert isinstance(y, FieldList)

        if len(y) == 0:
            raise ValueError("FieldList y must not be empty")
        if len(x) != len(y):
            if len(x) == 1:
                x = repeat(x[0])
            elif len(y) == 1:
                y = repeat(y[0])
            else:
                raise ValueError("FieldLists must have the same length or one of them must be 1")

        return FieldList.from_fields([f1._binary_op(oper, f2) for f1, f2 in zip(x, y)])

    @staticmethod
    def apply_ufunc(func, ref, *args, template=None):
        """Apply ``func`` field by field, cloning the fields of ``ref``.

        Parameters
        ----------
        func: callable
            Called with one values array per item in ``args``.
        ref: FieldList
            Reference fieldlist. It controls the number of results and each
            result is created with ``clone(values=...)`` on its fields.
        *args:
            Operands passed to ``func``, in order. Any operand other than
            ``ref`` itself is converted with :meth:`create_fieldlist`; an
            operand of length 1 is repeated for every field of ``ref``.
        template:
            Not used by this implementation.

        Raises
        ------
        ValueError
            If an operand is empty or its length is neither 1 nor ``len(ref)``.
        """
        from earthkit.data.core.fieldlist import FieldList

        operands = []
        for i, a in enumerate(get_wrapper(a) for a in args):
            if a is not ref:
                a = LoopCompute.create_fieldlist(ref, a)
                if len(a) == 0:
                    raise ValueError(f"FieldList {a} at index={i} must not be empty")
                if len(ref) != len(a):
                    if len(a) == 1:
                        a = repeat(a[0])
                    else:
                        raise ValueError("FieldLists must have the same length or one of them must be 1")
            operands.append(a)

        r = []
        for f_ref, *fields in zip(ref, *operands):
            values = func(*(f.values for f in fields))
            r.append(f_ref.clone(values=values))
        return FieldList.from_fields(r)


# Registry of the available computation methods.
methods = {"loop": LoopCompute}


def get_method(method):
    """Return the computation method registered under the name ``method``.

    Raises
    ------
    ValueError
        If ``method`` is not a key of ``methods``.
    """
    m = methods.get(method)
    if m is None:
        raise ValueError(f"Unknown method: {method}")
    return m
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import math
import os
import sys

import numpy as np
import pytest

from earthkit.data.utils.compute import apply_ufunc

here = os.path.dirname(__file__)
sys.path.insert(0, here)
from grib_fixtures import FL_NUMPY  # noqa: E402
from grib_fixtures import FL_TYPES  # noqa: E402
from grib_fixtures import load_grib_data  # noqa: E402


class ComputeOperand:
    """Base of the operand factories.

    ``val()`` returns a pair: the operand to use in the expression under test
    and the matching reference to combine with ``ds.values``.
    """

    def __init__(self, ds, array_backend=None):
        self.ds = ds
        self.array_backend = array_backend


class SingleValueOperand(ComputeOperand):
    """A plain Python number."""

    def val(self):
        return 10, 10


class ArraySingleValueOperand(ComputeOperand):
    """An array holding one number per field."""

    def val(self):
        xp = self.array_backend.namespace
        scalars = [(i + 1) * 10 for i in range(len(self.ds))]
        v = xp.asarray(scalars)

        first = self.ds[0].values
        size = first.size
        if callable(size):
            # ``size`` is a method on this array type: derive it from the shape
            size = math.prod(first.shape)

        base = xp.zeros(size)
        v_ref = xp.stack([base + xp.asarray(s) for s in scalars])
        return v, v_ref


class ArrayFieldOperand(ComputeOperand):
    """The values array of the first field."""

    def val(self):
        xp = self.array_backend.namespace
        first = self.ds[0].values
        return first, xp.asarray([first] * len(self.ds))


class ArrayFieldListOperand(ComputeOperand):
    """The values array of the whole fieldlist."""

    def val(self):
        values = self.ds.values
        return values, values


class FieldOperand(ComputeOperand):
    """The first field."""

    def val(self):
        field = self.ds[0]
        return field, field.values


class SingleFieldListOperand(ComputeOperand):
    """A fieldlist made of the first field only."""

    def val(self):
        f = self.ds.from_fields([self.ds[0]])
        return f, f.values


class FieldListOperand(ComputeOperand):
    """The whole fieldlist."""

    def val(self):
        return self.ds, self.ds.values


class FieldUnaryOperand(ComputeOperand):
    """The first field, as operand of a unary operation."""

    def val(self):
        field = self.ds[0]
        return field, field.values


class FieldListUnaryOperand(ComputeOperand):
    """The whole fieldlist, as operand of a unary operation."""

    def val(self):
        return self.ds, self.ds.values


RIGHT_OPERANDS = [
    SingleValueOperand,
    ArraySingleValueOperand,
    ArrayFieldOperand,
    ArrayFieldListOperand,
    FieldOperand,
    SingleFieldListOperand,
    FieldListOperand,
]

# arrays cannot be left operands
LEFT_OPERANDS = [SingleValueOperand, FieldOperand, SingleFieldListOperand, FieldListOperand]

UNARY_OPERANDS = [
    FieldUnaryOperand,
    FieldListUnaryOperand,
]


def _assert_close(array_backend, res, ref):
    """Check the values of ``res`` against the reference array ``ref``."""
    assert array_backend.allclose(res.values, ref, equal_nan=True)


def _check_right(fl_type, operand, op):
    """Check ``op(ds, operand)`` against ``op(ds.values, reference)``."""
    ds, array_backend = load_grib_data("test.grib", fl_type)
    rval, rval_ref = operand(ds, array_backend).val()
    _assert_close(array_backend, op(ds, rval), op(ds.values, rval_ref))


def _check_left(fl_type, operand, op):
    """Check ``op(operand, ds)`` against ``op(reference, ds.values)``."""
    ds, array_backend = load_grib_data("test.grib", fl_type)
    lval, lval_ref = operand(ds, array_backend).val()
    _assert_close(array_backend, op(lval, ds), op(lval_ref, ds.values))


def _load_unary(fl_type, operand):
    """Return the array backend, the unary operand and its reference values."""
    ds, array_backend = load_grib_data("test.grib", fl_type)
    val, val_ref = operand(ds).val()
    return array_backend, val, val_ref


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", RIGHT_OPERANDS)
def test_grib_compute_add(fl_type, operand):
    _check_right(fl_type, operand, lambda a, b: a + b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", RIGHT_OPERANDS)
def test_grib_compute_sub(fl_type, operand):
    _check_right(fl_type, operand, lambda a, b: a - b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", RIGHT_OPERANDS)
def test_grib_compute_mul(fl_type, operand):
    _check_right(fl_type, operand, lambda a, b: a * b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", RIGHT_OPERANDS)
def test_grib_compute_div(fl_type, operand):
    _check_right(fl_type, operand, lambda a, b: a / b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", RIGHT_OPERANDS)
def test_grib_compute_floordiv(fl_type, operand):
    _check_right(fl_type, operand, lambda a, b: a // b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", RIGHT_OPERANDS)
def test_grib_compute_mod(fl_type, operand):
    _check_right(fl_type, operand, lambda a, b: a % b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", RIGHT_OPERANDS)
def test_grib_compute_pow(fl_type, operand):
    _check_right(fl_type, operand, lambda a, b: a**b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", LEFT_OPERANDS)
def test_grib_compute_radd(fl_type, operand):
    _check_left(fl_type, operand, lambda a, b: a + b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", LEFT_OPERANDS)
def test_grib_compute_rsub(fl_type, operand):
    _check_left(fl_type, operand, lambda a, b: a - b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", LEFT_OPERANDS)
def test_grib_compute_rmul(fl_type, operand):
    _check_left(fl_type, operand, lambda a, b: a * b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", LEFT_OPERANDS)
def test_grib_compute_rdiv(fl_type, operand):
    _check_left(fl_type, operand, lambda a, b: a / b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", LEFT_OPERANDS)
def test_grib_compute_rfloordiv(fl_type, operand):
    _check_left(fl_type, operand, lambda a, b: a // b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", LEFT_OPERANDS)
def test_grib_compute_rmod(fl_type, operand):
    _check_left(fl_type, operand, lambda a, b: a % b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", LEFT_OPERANDS)
def test_grib_compute_rpow(fl_type, operand):
    _check_left(fl_type, operand, lambda a, b: a**b)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", UNARY_OPERANDS)
def test_grib_compute_pos(fl_type, operand):
    array_backend, val, val_ref = _load_unary(fl_type, operand)
    _assert_close(array_backend, +val, val_ref)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", UNARY_OPERANDS)
def test_grib_compute_neg(fl_type, operand):
    array_backend, val, val_ref = _load_unary(fl_type, operand)
    _assert_close(array_backend, -val, -val_ref)


@pytest.mark.parametrize("fl_type", FL_NUMPY)
@pytest.mark.parametrize("operand", UNARY_OPERANDS)
def test_grib_compute_ufunc(fl_type, operand):
    array_backend, val, val_ref = _load_unary(fl_type, operand)

    def func(x, y):
        return np.sin(x) + y * 2

    res = apply_ufunc(func, val, val + 1)
    ref = func(val_ref, val_ref + 1)
    _assert_close(array_backend, res, ref)


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("operand", UNARY_OPERANDS)
def test_grib_compute_sin(fl_type, operand):
    array_backend, val, val_ref = _load_unary(fl_type, operand)
    xp = array_backend.namespace
    _assert_close(array_backend, val.sin(), xp.sin(val_ref))