|
| 1 | +""" |
| 2 | +Tests functions with xarray inputs. |
| 3 | +
|
| 4 | +This version is a copy of the original test_check_functions but with |
| 5 | +an import of xarray, and conversion of the 3 main check cast arrays |
| 6 | +into DataArray objects. |
| 7 | +
|
| 8 | +An additional xarray-dask test is added. |
| 9 | +""" |
| 10 | + |
| 11 | +import os |
| 12 | +import pytest |
| 13 | + |
| 14 | +import numpy as np |
| 15 | +from numpy.testing import assert_allclose |
| 16 | + |
| 17 | +import gsw |
| 18 | +from gsw._utilities import Bunch |
| 19 | +from check_functions import parse_check_functions |
| 20 | + |
| 21 | +xr = pytest.importorskip('xarray') |
| 22 | + |
| 23 | +# Most of the tests have some nan values, so we need to suppress the warning. |
| 24 | +# Any more careful fix would likely require considerable effort. |
| 25 | +np.seterr(invalid='ignore') |
| 26 | + |
| 27 | +root_path = os.path.abspath(os.path.dirname(__file__)) |
| 28 | + |
| 29 | +# Function checks that we can't handle automatically yet. |
| 30 | +blacklist = ['deltaSA_atlas', # the test is complicated; doesn't fit the pattern. |
| 31 | + 'geostrophic_velocity', # test elsewhere; we changed the API |
| 32 | + #'CT_from_entropy', # needs prior entropy_from_CT; don't have it in C |
| 33 | + #'CT_first_derivatives', # passes, but has trouble in "details"; |
| 34 | + # see check_functions.py |
| 35 | + #'entropy_second_derivatives', # OK now; handling extra parens. |
| 36 | + #'melting_ice_into_seawater', # OK now; fixed nargs mismatch. |
| 37 | + ] |
| 38 | + |
| 39 | +# We get an overflow from ct_from_enthalpy_exact, but the test passes. |
| 40 | +cv = Bunch(np.load(os.path.join(root_path, 'gsw_cv_v3_0.npz'))) |
| 41 | + |
| 42 | +# Substitute new check values for the pchip interpolation version. |
| 43 | +cv.geo_strf_dyn_height = np.load(os.path.join(root_path,'geo_strf_dyn_height.npy')) |
| 44 | +cv.geo_strf_velocity = np.load(os.path.join(root_path,'geo_strf_velocity.npy')) |
| 45 | + |
| 46 | +for name in ['SA_chck_cast', 't_chck_cast', 'p_chck_cast']: |
| 47 | + cv[name] = xr.DataArray(cv[name]) |
| 48 | + |
| 49 | +cf = Bunch() |
| 50 | + |
| 51 | +d = dir(gsw) |
| 52 | +funcnames = [name for name in d if '__' not in name] |
| 53 | + |
| 54 | +mfuncs = parse_check_functions(os.path.join(root_path, 'gsw_check_functions_save.m')) |
| 55 | +mfuncs = [mf for mf in mfuncs if mf.name in d and mf.name not in blacklist] |
| 56 | +mfuncnames = [mf.name for mf in mfuncs] |
| 57 | + |
| 58 | + |
| 59 | +@pytest.fixture(scope='session', params=mfuncs) |
| 60 | +def cfcf(request): |
| 61 | + return cv, cf, request.param |
| 62 | + |
| 63 | + |
| 64 | +def test_check_function(cfcf): |
| 65 | + cv, cf, mfunc = cfcf |
| 66 | + mfunc.run(locals()) |
| 67 | + if mfunc.exception is not None or not mfunc.passed: |
| 68 | + print('\n', mfunc.name) |
| 69 | + print(' ', mfunc.runline) |
| 70 | + print(' ', mfunc.testline) |
| 71 | + if mfunc.exception is None: |
| 72 | + mfunc.exception = ValueError('Calculated values are different from the expected matlab results.') |
| 73 | + raise mfunc.exception |
| 74 | + else: |
| 75 | + print(mfunc.name) |
| 76 | + assert mfunc.passed |
| 77 | + |
| 78 | + |
| 79 | +def test_dask_chunking(): |
| 80 | + dsa = pytest.importorskip('dask.array') |
| 81 | + |
| 82 | + # define some input data |
| 83 | + shape = (100, 1000) |
| 84 | + chunks = (100, 200) |
| 85 | + sp = xr.DataArray(dsa.full(shape, 35., chunks=chunks), dims=['time', 'depth']) |
| 86 | + p = xr.DataArray(np.arange(shape[1]), dims=['depth']) |
| 87 | + lon = 0 |
| 88 | + lat = 45 |
| 89 | + |
| 90 | + sa = gsw.SA_from_SP(sp, p, lon, lat) |
| 91 | + sa_dask = sa.compute() |
| 92 | + |
| 93 | + sa_numpy = gsw.SA_from_SP(np.full(shape, 35.0), p.values, lon, lat) |
| 94 | + assert_allclose(sa_dask, sa_numpy) |
0 commit comments