@@ -92,3 +92,52 @@ def test_dask_chunking():
9292
9393 sa_numpy = gsw .SA_from_SP (np .full (shape , 35.0 ), p .values , lon , lat )
9494 assert_allclose (sa_dask , sa_numpy )
95+
96+
97+ # Additional tests from Graeme MacGilchrist
98+ # https://nbviewer.jupyter.org/github/gmacgilchrist/wmt_bgc/blob/master/notebooks/test_gsw-xarray.ipynb
99+
100+ # Define dimensions and coordinates
101+ dims = ['y' ,'z' ,'t' ]
102+ # 2x2x2
103+ y = np .arange (0 ,2 )
104+ z = np .arange (0 ,2 )
105+ t = np .arange (0 ,2 )
106+ # Define numpy arrays of salinity, temperature and pressure
107+ SA_vals = np .array ([[[34.7 ,34.8 ],[34.9 ,35 ]],[[35.1 ,35.2 ],[35.3 ,35.4 ]]])
108+ CT_vals = np .array ([[[7 ,8 ],[9 ,10 ]],[[11 ,12 ],[13 ,14 ]]])
109+ p_vals = np .array ([10 ,20 ])
110+ lat_vals = np .array ([0 ,10 ])
111+ # Plug in to xarray objects
112+ SA = xr .DataArray (SA_vals ,dims = dims ,coords = {'y' :y ,'z' :z ,'t' :t })
113+ CT = xr .DataArray (CT_vals ,dims = dims ,coords = {'y' :y ,'z' :z ,'t' :t })
114+ p = xr .DataArray (p_vals ,dims = ['z' ],coords = {'z' :z })
115+ lat = xr .DataArray (lat_vals ,dims = ['y' ],coords = {'y' :y })
116+
117+
118+ def test_xarray_with_coords ():
119+ pytest .importorskip ('dask' )
120+ SA_chunk = SA .chunk (chunks = {'y' :1 ,'t' :1 })
121+ CT_chunk = CT .chunk (chunks = {'y' :1 ,'t' :1 })
122+ lat_chunk = lat .chunk (chunks = {'y' :1 })
123+
124+ # Dimensions and cordinates match:
125+ expected = gsw .sigma0 (SA_vals , CT_vals )
126+ xarray = gsw .sigma0 (SA , CT )
127+ chunked = gsw .sigma0 (SA_chunk , CT_chunk )
128+ assert_allclose (xarray , expected )
129+ assert_allclose (chunked , expected )
130+
131+ # Broadcasting along dimension required (dimensions known)
132+ expected = gsw .alpha (SA_vals , CT_vals , p_vals [np .newaxis , :, np .newaxis ])
133+ xarray = gsw .alpha (SA , CT , p )
134+ chunked = gsw .alpha (SA_chunk , CT_chunk , p )
135+ assert_allclose (xarray , expected )
136+ assert_allclose (chunked , expected )
137+
138+ # Broadcasting along dimension required (dimensions unknown/exclusive)
139+ expected = gsw .z_from_p (p_vals [:, np .newaxis ], lat_vals [np .newaxis , :])
140+ xarray = gsw .z_from_p (p , lat )
141+ chunked = gsw .z_from_p (p ,lat_chunk )
142+ assert_allclose (xarray , expected )
143+ assert_allclose (chunked , expected )
0 commit comments