Skip to content

Commit 4caf685

Browse files
committed
Add xarray testing, fix xarray bug
1 parent 6de9f79 commit 4caf685

File tree

3 files changed

+99
-2
lines changed

3 files changed

+99
-2
lines changed

gsw/geostrophy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def geo_strf_dyn_height(SA, CT, p, p_ref=0, axis=0, max_dp=1.0,
7373
dh = np.empty(SA.shape, dtype=float)
7474
dh.fill(np.nan)
7575

76-
order = 'F' if SA.flags.fortran else 'C'
76+
try:
77+
order = 'F' if SA.flags.fortran else 'C'
78+
except AttributeError:
79+
order = 'C' # e.g., xarray DataArray doesn't have flags
7780
for ind in indexer(SA.shape, axis, order=order):
7881
igood = goodmask[ind]
7982
# If p_ref is below the deepest value, skip the profile.

gsw/tests/check_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def find(x):
5555
"""
5656
Numpy equivalent to Matlab find.
5757
"""
58-
return np.nonzero(x.flatten())[0]
58+
return np.nonzero(np.asarray(x).flatten())[0]
5959

6060

6161
def group_or(line):

gsw/tests/test_xarray.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""
2+
Tests functions with xarray inputs.
3+
4+
This version is a copy of the original test_check_functions but with
5+
an import of xarray, and conversion of the 3 main check cast arrays
6+
into DataArray objects.
7+
8+
An additional xarray-dask test is added.
9+
"""
10+
11+
import os
12+
import pytest
13+
14+
import numpy as np
15+
from numpy.testing import assert_allclose
16+
17+
import gsw
18+
from gsw._utilities import Bunch
19+
from check_functions import parse_check_functions
20+
21+
xr = pytest.importorskip('xarray')
22+
23+
# Most of the tests have some nan values, so we need to suppress the warning.
24+
# Any more careful fix would likely require considerable effort.
25+
np.seterr(invalid='ignore')
26+
27+
root_path = os.path.abspath(os.path.dirname(__file__))
28+
29+
# Function checks that we can't handle automatically yet.
30+
blacklist = ['deltaSA_atlas', # the test is complicated; doesn't fit the pattern.
31+
'geostrophic_velocity', # test elsewhere; we changed the API
32+
#'CT_from_entropy', # needs prior entropy_from_CT; don't have it in C
33+
#'CT_first_derivatives', # passes, but has trouble in "details";
34+
# see check_functions.py
35+
#'entropy_second_derivatives', # OK now; handling extra parens.
36+
#'melting_ice_into_seawater', # OK now; fixed nargs mismatch.
37+
]
38+
39+
# We get an overflow from ct_from_enthalpy_exact, but the test passes.
40+
cv = Bunch(np.load(os.path.join(root_path, 'gsw_cv_v3_0.npz')))
41+
42+
# Substitute new check values for the pchip interpolation version.
43+
cv.geo_strf_dyn_height = np.load(os.path.join(root_path,'geo_strf_dyn_height.npy'))
44+
cv.geo_strf_velocity = np.load(os.path.join(root_path,'geo_strf_velocity.npy'))
45+
46+
for name in ['SA_chck_cast', 't_chck_cast', 'p_chck_cast']:
47+
cv[name] = xr.DataArray(cv[name])
48+
49+
cf = Bunch()
50+
51+
d = dir(gsw)
52+
funcnames = [name for name in d if '__' not in name]
53+
54+
mfuncs = parse_check_functions(os.path.join(root_path, 'gsw_check_functions_save.m'))
55+
mfuncs = [mf for mf in mfuncs if mf.name in d and mf.name not in blacklist]
56+
mfuncnames = [mf.name for mf in mfuncs]
57+
58+
59+
@pytest.fixture(scope='session', params=mfuncs)
60+
def cfcf(request):
61+
return cv, cf, request.param
62+
63+
64+
def test_check_function(cfcf):
65+
cv, cf, mfunc = cfcf
66+
mfunc.run(locals())
67+
if mfunc.exception is not None or not mfunc.passed:
68+
print('\n', mfunc.name)
69+
print(' ', mfunc.runline)
70+
print(' ', mfunc.testline)
71+
if mfunc.exception is None:
72+
mfunc.exception = ValueError('Calculated values are different from the expected matlab results.')
73+
raise mfunc.exception
74+
else:
75+
print(mfunc.name)
76+
assert mfunc.passed
77+
78+
79+
def test_dask_chunking():
80+
dsa = pytest.importorskip('dask.array')
81+
82+
# define some input data
83+
shape = (100, 1000)
84+
chunks = (100, 200)
85+
sp = xr.DataArray(dsa.full(shape, 35., chunks=chunks), dims=['time', 'depth'])
86+
p = xr.DataArray(np.arange(shape[1]), dims=['depth'])
87+
lon = 0
88+
lat = 45
89+
90+
sa = gsw.SA_from_SP(sp, p, lon, lat)
91+
sa_dask = sa.compute()
92+
93+
sa_numpy = gsw.SA_from_SP(np.full(shape, 35.0), p.values, lon, lat)
94+
assert_allclose(sa_dask, sa_numpy)

0 commit comments

Comments
 (0)