Skip to content

Commit 0a82070

Browse files
committed
Add test cases from @gmacgilchrist
1 parent 30d0ecc commit 0a82070

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

gsw/tests/test_xarray.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,52 @@ def test_dask_chunking():
9292

9393
sa_numpy = gsw.SA_from_SP(np.full(shape, 35.0), p.values, lon, lat)
9494
assert_allclose(sa_dask, sa_numpy)
95+
96+
97+
# Additional tests from Graeme MacGilchrist
98+
# https://nbviewer.jupyter.org/github/gmacgilchrist/wmt_bgc/blob/master/notebooks/test_gsw-xarray.ipynb
99+
100+
# Define dimensions and coordinates
101+
dims = ['y','z','t']
102+
# 2x2x2
103+
y = np.arange(0,2)
104+
z = np.arange(0,2)
105+
t = np.arange(0,2)
106+
# Define numpy arrays of salinity, temperature and pressure
107+
SA_vals = np.array([[[34.7,34.8],[34.9,35]],[[35.1,35.2],[35.3,35.4]]])
108+
CT_vals = np.array([[[7,8],[9,10]],[[11,12],[13,14]]])
109+
p_vals = np.array([10,20])
110+
lat_vals = np.array([0,10])
111+
# Plug in to xarray objects
112+
SA = xr.DataArray(SA_vals,dims=dims,coords={'y':y,'z':z,'t':t})
113+
CT = xr.DataArray(CT_vals,dims=dims,coords={'y':y,'z':z,'t':t})
114+
p = xr.DataArray(p_vals,dims=['z'],coords={'z':z})
115+
lat = xr.DataArray(lat_vals,dims=['y'],coords={'y':y})
116+
117+
118+
def test_xarray_with_coords():
119+
pytest.importorskip('dask')
120+
SA_chunk = SA.chunk(chunks={'y':1,'t':1})
121+
CT_chunk = CT.chunk(chunks={'y':1,'t':1})
122+
lat_chunk = lat.chunk(chunks={'y':1})
123+
124+
# Dimensions and cordinates match:
125+
expected = gsw.sigma0(SA_vals, CT_vals)
126+
xarray = gsw.sigma0(SA, CT)
127+
chunked = gsw.sigma0(SA_chunk, CT_chunk)
128+
assert_allclose(xarray, expected)
129+
assert_allclose(chunked, expected)
130+
131+
# Broadcasting along dimension required (dimensions known)
132+
expected = gsw.alpha(SA_vals, CT_vals, p_vals[np.newaxis, :, np.newaxis])
133+
xarray = gsw.alpha(SA, CT, p)
134+
chunked = gsw.alpha(SA_chunk, CT_chunk, p)
135+
assert_allclose(xarray, expected)
136+
assert_allclose(chunked, expected)
137+
138+
# Broadcasting along dimension required (dimensions unknown/exclusive)
139+
expected = gsw.z_from_p(p_vals[:, np.newaxis], lat_vals[np.newaxis, :])
140+
xarray = gsw.z_from_p(p, lat)
141+
chunked = gsw.z_from_p(p,lat_chunk)
142+
assert_allclose(xarray, expected)
143+
assert_allclose(chunked, expected)

0 commit comments

Comments
 (0)