|
| 1 | +""" |
| 2 | +Tests functions with xarray inputs. |
| 3 | +
|
| 4 | +This version is a copy of the original test_check_functions but with |
| 5 | +an import of xarray, and conversion of the 3 main check cast arrays |
| 6 | +into DataArray objects. |
| 7 | +
|
| 8 | +An additional xarray-dask test is added. |
| 9 | +""" |
| 10 | + |
| 11 | +import os |
| 12 | +import pytest |
| 13 | + |
| 14 | +import numpy as np |
| 15 | +from numpy.testing import assert_allclose |
| 16 | + |
| 17 | +import gsw |
| 18 | +from gsw._utilities import Bunch |
| 19 | +from check_functions import parse_check_functions |
| 20 | + |
| 21 | +xr = pytest.importorskip('xarray') |
| 22 | + |
| 23 | +# Most of the tests have some nan values, so we need to suppress the warning. |
| 24 | +# Any more careful fix would likely require considerable effort. |
| 25 | +np.seterr(invalid='ignore') |
| 26 | + |
| 27 | +root_path = os.path.abspath(os.path.dirname(__file__)) |
| 28 | + |
| 29 | +# Function checks that we can't handle automatically yet. |
| 30 | +blacklist = ['deltaSA_atlas', # the test is complicated; doesn't fit the pattern. |
| 31 | + 'geostrophic_velocity', # test elsewhere; we changed the API |
| 32 | + #'CT_from_entropy', # needs prior entropy_from_CT; don't have it in C |
| 33 | + #'CT_first_derivatives', # passes, but has trouble in "details"; |
| 34 | + # see check_functions.py |
| 35 | + #'entropy_second_derivatives', # OK now; handling extra parens. |
| 36 | + #'melting_ice_into_seawater', # OK now; fixed nargs mismatch. |
| 37 | + ] |
| 38 | + |
| 39 | +# We get an overflow from ct_from_enthalpy_exact, but the test passes. |
| 40 | +cv = Bunch(np.load(os.path.join(root_path, 'gsw_cv_v3_0.npz'))) |
| 41 | + |
| 42 | +# Substitute new check values for the pchip interpolation version. |
| 43 | +cv.geo_strf_dyn_height = np.load(os.path.join(root_path,'geo_strf_dyn_height.npy')) |
| 44 | +cv.geo_strf_velocity = np.load(os.path.join(root_path,'geo_strf_velocity.npy')) |
| 45 | + |
| 46 | +for name in ['SA_chck_cast', 't_chck_cast', 'p_chck_cast']: |
| 47 | + cv[name] = xr.DataArray(cv[name]) |
| 48 | + |
| 49 | +cf = Bunch() |
| 50 | + |
| 51 | +d = dir(gsw) |
| 52 | +funcnames = [name for name in d if '__' not in name] |
| 53 | + |
| 54 | +mfuncs = parse_check_functions(os.path.join(root_path, 'gsw_check_functions_save.m')) |
| 55 | +mfuncs = [mf for mf in mfuncs if mf.name in d and mf.name not in blacklist] |
| 56 | +mfuncnames = [mf.name for mf in mfuncs] |
| 57 | + |
| 58 | + |
| 59 | +@pytest.fixture(scope='session', params=mfuncs) |
| 60 | +def cfcf(request): |
| 61 | + return cv, cf, request.param |
| 62 | + |
| 63 | + |
| 64 | +def test_check_function(cfcf): |
| 65 | + cv, cf, mfunc = cfcf |
| 66 | + mfunc.run(locals()) |
| 67 | + if mfunc.exception is not None or not mfunc.passed: |
| 68 | + print('\n', mfunc.name) |
| 69 | + print(' ', mfunc.runline) |
| 70 | + print(' ', mfunc.testline) |
| 71 | + if mfunc.exception is None: |
| 72 | + mfunc.exception = ValueError('Calculated values are different from the expected matlab results.') |
| 73 | + raise mfunc.exception |
| 74 | + else: |
| 75 | + print(mfunc.name) |
| 76 | + assert mfunc.passed |
| 77 | + |
| 78 | + |
| 79 | +def test_dask_chunking(): |
| 80 | + dsa = pytest.importorskip('dask.array') |
| 81 | + |
| 82 | + # define some input data |
| 83 | + shape = (100, 1000) |
| 84 | + chunks = (100, 200) |
| 85 | + sp = xr.DataArray(dsa.full(shape, 35., chunks=chunks), dims=['time', 'depth']) |
| 86 | + p = xr.DataArray(np.arange(shape[1]), dims=['depth']) |
| 87 | + lon = 0 |
| 88 | + lat = 45 |
| 89 | + |
| 90 | + sa = gsw.SA_from_SP(sp, p, lon, lat) |
| 91 | + sa_dask = sa.compute() |
| 92 | + |
| 93 | + sa_numpy = gsw.SA_from_SP(np.full(shape, 35.0), p.values, lon, lat) |
| 94 | + assert_allclose(sa_dask, sa_numpy) |
| 95 | + |
| 96 | + |
| 97 | +# Additional tests from Graeme MacGilchrist |
| 98 | +# https://nbviewer.jupyter.org/github/gmacgilchrist/wmt_bgc/blob/master/notebooks/test_gsw-xarray.ipynb |
| 99 | + |
| 100 | +# Define dimensions and coordinates |
| 101 | +dims = ['y','z','t'] |
| 102 | +# 2x2x2 |
| 103 | +y = np.arange(0,2) |
| 104 | +z = np.arange(0,2) |
| 105 | +t = np.arange(0,2) |
| 106 | +# Define numpy arrays of salinity, temperature and pressure |
| 107 | +SA_vals = np.array([[[34.7,34.8],[34.9,35]],[[35.1,35.2],[35.3,35.4]]]) |
| 108 | +CT_vals = np.array([[[7,8],[9,10]],[[11,12],[13,14]]]) |
| 109 | +p_vals = np.array([10,20]) |
| 110 | +lat_vals = np.array([0,10]) |
| 111 | +# Plug in to xarray objects |
| 112 | +SA = xr.DataArray(SA_vals,dims=dims,coords={'y':y,'z':z,'t':t}) |
| 113 | +CT = xr.DataArray(CT_vals,dims=dims,coords={'y':y,'z':z,'t':t}) |
| 114 | +p = xr.DataArray(p_vals,dims=['z'],coords={'z':z}) |
| 115 | +lat = xr.DataArray(lat_vals,dims=['y'],coords={'y':y}) |
| 116 | + |
| 117 | + |
| 118 | +def test_xarray_with_coords(): |
| 119 | + pytest.importorskip('dask') |
| 120 | + SA_chunk = SA.chunk(chunks={'y':1,'t':1}) |
| 121 | + CT_chunk = CT.chunk(chunks={'y':1,'t':1}) |
| 122 | + lat_chunk = lat.chunk(chunks={'y':1}) |
| 123 | + |
| 124 | + # Dimensions and cordinates match: |
| 125 | + expected = gsw.sigma0(SA_vals, CT_vals) |
| 126 | + xarray = gsw.sigma0(SA, CT) |
| 127 | + chunked = gsw.sigma0(SA_chunk, CT_chunk) |
| 128 | + assert_allclose(xarray, expected) |
| 129 | + assert_allclose(chunked, expected) |
| 130 | + |
| 131 | + # Broadcasting along dimension required (dimensions known) |
| 132 | + expected = gsw.alpha(SA_vals, CT_vals, p_vals[np.newaxis, :, np.newaxis]) |
| 133 | + xarray = gsw.alpha(SA, CT, p) |
| 134 | + chunked = gsw.alpha(SA_chunk, CT_chunk, p) |
| 135 | + assert_allclose(xarray, expected) |
| 136 | + assert_allclose(chunked, expected) |
| 137 | + |
| 138 | + # Broadcasting along dimension required (dimensions unknown/exclusive) |
| 139 | + expected = gsw.z_from_p(p_vals[:, np.newaxis], lat_vals[np.newaxis, :]) |
| 140 | + xarray = gsw.z_from_p(p, lat) |
| 141 | + chunked = gsw.z_from_p(p,lat_chunk) |
| 142 | + assert_allclose(xarray, expected) |
| 143 | + assert_allclose(chunked, expected) |
0 commit comments