Skip to content

Add direct pint/pint-xarray support #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- uses: psf/black@stable
with:
options: "--check --verbose"
version: "22.1.0"
version: "22.3.0"

build:
name: build (${{ matrix.os }}, ${{ matrix.python-version }})
Expand Down Expand Up @@ -63,7 +63,7 @@ jobs:
poetry-version: ${{ matrix.poetry-version }}
- name: Install dependencies
run: |
poetry install
poetry install -E pint
poetry run pip install cf_units==3.0.1
- name: Test with pytest
run: |
Expand Down
11 changes: 10 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
Changelog
=========

v0.3.0 - Unreleased
-------------------
This release will focus on supporting Pint quantities.

Highlights
``````````
* If (at least 1) arguments are xarray.DataArray wrapped by pint-xarray, the result is wrapped into a pint-xarray quantity.

v0.2.1 - 2022-03-22
-------------------
Despite all the checking, we missed a bad bug.
Expand Down Expand Up @@ -37,4 +46,4 @@ Breaking Changes
v0.1.0 - 2021-12-15
-------------------
* Original release, was basically a proof of concept.
Only a few functions were wrapped.
Only a few functions were wrapped.
63 changes: 63 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,69 @@ Outputs

<class 'numpy.ndarray'> [-5.08964499 2.1101098 9.28348219]


We support (but don't yet validate) the usage of pint.Quantities and the usage of xarray wrapped Quantities.
Support for pint requires the installation of two optional dependencies: ``pint`` and ``pint-xarray``.
If any of the inputs to a gsw function are Quantities, the returned object will also be a Quantity belonging to the same UnitRegistry.

.. warning::

Quantities must all belong to the same pint.UnitRegistry, a ValueError will be thrown if there are mixed registries.

.. code:: python

import pint_xarray
import gsw_xarray as gsw

# Create a xarray.Dataset
import numpy as np
import xarray as xr
ds = xr.Dataset()
id = np.arange(3)
ds['id'] = xr.DataArray(id, coords={'id':id})
ds['CT'] = ds['id'] * 10
# make sure there are unit attrs this time
ds['CT'].attrs = {'standard_name':'sea_water_conservative_temperature', 'units': 'degC'}
ds['SA'] = ds['id'] * 0.1 + 34
ds['SA'].attrs = {'standard_name':'sea_water_absolute_salinity', 'units': 'g/kg'}

# use the pint accessor to quantify things
ds = ds.pint.quantify()

# Apply gsw functions
sigma0 = gsw.sigma0(SA=ds['SA'], CT=ds['CT'])
# outputs are now quantities!
print(sigma0)

Outputs

::

<xarray.DataArray 'sigma0' (id: 3)>
<Quantity([27.17191038 26.12820162 24.03930887], 'kilogram / meter ** 3')>
Coordinates:
* id (id) int64 0 1 2
Attributes:
standard_name: sea_water_sigma_t

The usage of xarray wrapped Quantities is not required, you can use pint directly (though the ``pint-xarray`` dep still needs to be installed).

.. code:: python

import gsw_xarray as gsw
import pint
ureg = pint.UnitRegistry()
SA = ureg.Quantity(35, ureg("g/kg"))
CT = ureg.Quantity(10, ureg.degC)
sigma0 = gsw.sigma0(SA=SA, CT=CT)
print(sigma0)

Outputs

::

26.824644457868317 kilogram / meter ** 3

Installation
------------
Pip
Expand Down
75 changes: 75 additions & 0 deletions gsw_xarray/_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import wraps
from itertools import chain

import gsw
import xarray as xr
Expand All @@ -7,24 +8,98 @@
from ._names import _names
from ._check_funcs import _check_funcs

try:
import pint_xarray
import pint

except ImportError:
pint_xarray = None


def add_attrs(rv, attrs, name):
if isinstance(rv, xr.DataArray):
rv.name = name
rv.attrs = attrs


def quantify(rv, attrs, unit_registry=None):
if unit_registry is None:
return rv

if isinstance(rv, xr.DataArray):
rv = rv.pint.quantify(unit_registry=unit_registry)
else:
if attrs is not None:
# Necessary to use the Q_ and not simply multiplication with ureg unit because of temperature
# see https://pint.readthedocs.io/en/latest/nonmult.html
rv = unit_registry.Quantity(rv, attrs["units"])
return rv


def pint_compat(args, kwargs):
if pint_xarray is None:
return args, kwargs, None

using_pint = False
new_args = []
new_kwargs = {}
registries = []
for arg in args:
if isinstance(arg, xr.DataArray):
if arg.pint.units is not None:
new_args.append(arg.pint.dequantify())
registries.append(arg.pint.registry)
else:
new_args.append(arg)
elif isinstance(arg, pint.Quantity):
new_args.append(arg.magnitude)
registries.append(arg._REGISTRY)
else:
new_args.append(arg)

for kw, arg in kwargs.items():
if isinstance(arg, xr.DataArray):
if arg.pint.units is not None:
new_kwargs[kw] = arg.pint.dequantify()
registries.append(arg.pint.registry)
else:
new_kwargs[kw] = arg
elif isinstance(arg, pint.Quantity):
new_kwargs[kw] = arg.magnitude
registries.append(arg._REGISTRY)
else:
new_kwargs[kw] = arg

registries = set(registries)
if len(registries) > 1:
raise ValueError("Quantity arguments must all belong to the same unit registry")
elif len(registries) == 0:
registries = None
else:
(registries,) = registries
return new_args, new_kwargs, registries


def cf_attrs(attrs, name, check_func):
def cf_attrs_decorator(func):
@wraps(func)
def cf_attrs_wrapper(*args, **kwargs):
args, kwargs, unit_registry = pint_compat(args, kwargs)
rv = func(*args, **kwargs)
attrs_checked = check_func(attrs, args, kwargs)
if isinstance(rv, tuple):
rv_updated = []
for (i, da) in enumerate(rv):
add_attrs(da, attrs_checked[i], name[i])
rv_updated.append(
quantify(da, attrs_checked[i], unit_registry=unit_registry)
)

rv = tuple(rv_updated)

else:
add_attrs(rv, attrs_checked, name)
rv = quantify(rv, attrs_checked, unit_registry=unit_registry)
return rv

return cf_attrs_wrapper
Expand Down
14 changes: 12 additions & 2 deletions gsw_xarray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,23 @@ def ds():
id = np.arange(3)
ds["id"] = xr.DataArray(id, coords={"id": id})
ds["CT"] = ds["id"] * 10
ds["CT"].attrs = {"standard_name": "sea_water_conservative_temperature"}
ds["CT"].attrs = {
"standard_name": "sea_water_conservative_temperature",
"units": "degC",
}
ds["SA"] = ds["id"] * 0.1 + 34
ds["SA"].attrs = {"standard_name": "sea_water_absolute_salinity"}
ds["SA"].attrs = {"standard_name": "sea_water_absolute_salinity", "units": "g/kg"}
return ds


@pytest.fixture(scope="session")
def ureg():
pint = pytest.importorskip("pint")
return pint.UnitRegistry()


@pytest.fixture
def ds_pint(ds, ureg):
pytest.importorskip("pint_xarray")

return ds.pint.quantify()
83 changes: 83 additions & 0 deletions gsw_xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
import pytest

import gsw_xarray

from .test_imports import gsw_base
from gsw_xarray._attributes import _func_attrs

Expand Down Expand Up @@ -41,3 +43,84 @@ def test_unit_cf_units(func_name):
for a in attrs:
print(a["units"])
cf_units.Unit(a["units"])


def test_xarray_quantity(ds_pint):
pint_xarray = pytest.importorskip("pint_xarray")
sigma0 = gsw_xarray.sigma0(SA=ds_pint.SA, CT=ds_pint.CT)
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")


@pytest.mark.parametrize("SA_type", ["unit", "ds"])
@pytest.mark.parametrize("CT_type", ["unit", "ds"])
def test_xarray_quantity_or_ds(ds, ds_pint, SA_type, CT_type):
"""If at least 1 of the inputs is quantity, the result should be quantity"""
pint_xarray = pytest.importorskip("pint_xarray")
if SA_type == "unit":
SA = ds_pint.SA
elif SA_type == "ds":
SA = ds.SA

if CT_type == "unit":
CT = ds_pint.CT
elif CT_type == "ds":
CT = ds.CT

sigma0 = gsw_xarray.sigma0(SA=SA, CT=CT)
if SA_type == "unit" or CT_type == "unit":
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
else:
assert sigma0.pint.units is None
assert sigma0.pint.quantify().pint.units == pint_xarray.unit_registry(
"kg / m^3"
)


def test_func_return_tuple_quantity(ds_pint):
pint_xarray = pytest.importorskip("pint_xarray")
(CT_SA, CT_pt) = gsw_xarray.CT_first_derivatives(ds_pint.SA, 1)
assert CT_SA.pint.units == pint_xarray.unit_registry("K/(g/kg)")


def test_pint_quantity_xarray(ds):
"""If input is mixed between xr.DataArray and pint quantity it should return pint-xarray wrapped quantity"""
pint_xarray = pytest.importorskip("pint_xarray")

ureg = pint_xarray.unit_registry
Q_ = ureg.Quantity
sigma0 = gsw_xarray.sigma0(SA=ds.SA, CT=Q_(25.4, ureg.degC))
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")


def test_pint_quantity():
"""If input is pint quantity should return a quantity"""
pint_xarray = pytest.importorskip("pint_xarray")
import pint

ureg = pint_xarray.unit_registry
CT = gsw_xarray.CT_from_pt(SA=35 * ureg("g / kg"), pt=10)
assert isinstance(CT, pint.Quantity)


def test_pint_quantity_tuple():
"""If input is pint quantity should return a quantity"""
pint_xarray = pytest.importorskip("pint_xarray")
import pint

ureg = pint_xarray.unit_registry
(a, b) = gsw_xarray.CT_first_derivatives(35 * ureg("g / kg"), pt=1)
assert isinstance(a, pint.Quantity)
assert isinstance(b, pint.Quantity)


def test_mixed_unit_regestiries():
"""If input quantities are from different registries, it should fail"""
pint_xarray = pytest.importorskip("pint_xarray")
import pint

ureg_a = pint.UnitRegistry()
ureg_b = pint.UnitRegistry()
with pytest.raises(ValueError):
gsw_xarray.CT_first_derivatives(
35 * ureg_a("g / kg"), pt=ureg_b.Quantity(1, ureg_b.degC)
)
Loading