|
2 | 2 | import dask.array as da |
3 | 3 | import numpy as np |
4 | 4 | import pytest |
| 5 | +import warnings |
5 | 6 |
|
6 | 7 | import numpy.testing as nt |
7 | 8 |
|
@@ -84,6 +85,44 @@ def test_lat_inputs(self, gridpath, datasetpath): |
84 | 85 | assert len(uxds['psi'].zonal_mean(lat=1)) == 1 |
85 | 86 | assert len(uxds['psi'].zonal_mean(lat=(-90, 90, 1))) == 181 |
86 | 87 |
|
| 88 | + def test_zonal_mean_missing_latitudes_nan(self, gridpath, datasetpath): |
| 89 | + """Zonal mean should return NaN (not zeros) when no faces intersect a latitude.""" |
| 90 | + grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug") |
| 91 | + data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc") |
| 92 | + uxds = ux.open_dataset(grid_path, data_path) |
| 93 | + |
| 94 | + # Restrict to a narrow band so most requested latitudes have no coverage |
| 95 | + narrow = uxds["psi"].subset.bounding_box(lon_bounds=(-20, 20), lat_bounds=(0, 10)) |
| 96 | + |
| 97 | + with warnings.catch_warnings(): |
| 98 | + warnings.filterwarnings("error", category=RuntimeWarning) |
| 99 | + res = narrow.zonal_mean(lat=(-90, 90, 10)) |
| 100 | + |
| 101 | + below_band = res.sel(latitudes=res.latitudes < 0) |
| 102 | + assert np.all(np.isnan(below_band)) |
| 103 | + assert np.isfinite(res.sel(latitudes=0).item()) |
| 104 | + |
| 105 | + with warnings.catch_warnings(): |
| 106 | + warnings.filterwarnings("error", category=RuntimeWarning) |
| 107 | + res_cons = narrow.zonal_mean(lat=(-90, 90, 10), conservative=True) |
| 108 | + |
| 109 | + below_band_cons = res_cons.sel(latitudes=res_cons.latitudes < 0) |
| 110 | + assert np.all(np.isnan(below_band_cons)) |
| 111 | + assert np.isfinite(res_cons.sel(latitudes=5).item()) |
| 112 | + |
| 113 | + def test_zonal_mean_int_data_promotes_dtype(self): |
| 114 | + """Integer inputs should be promoted so NaNs can be stored.""" |
| 115 | + grid = ux.Grid.from_healpix(zoom=0) |
| 116 | + faces = np.where(grid.face_lat > 0)[0] # only northern hemisphere |
| 117 | + uxda = ux.UxDataArray( |
| 118 | + np.ones(grid.n_face, dtype=np.int32), dims=["n_face"], uxgrid=grid |
| 119 | + ).isel(n_face=faces) |
| 120 | + |
| 121 | + res = uxda.zonal_mean(lat=(-90, 90, 30)) |
| 122 | + |
| 123 | + assert np.issubdtype(res.dtype, np.floating) |
| 124 | + assert np.isnan(res.sel(latitudes=-90)).item() |
| 125 | + |
87 | 126 | def test_mismatched_dims(): |
88 | 127 | uxgrid = ux.Grid.from_healpix(zoom=0) |
89 | 128 | uxda = ux.UxDataArray(np.ones((10, uxgrid.n_face, 5)), dims=['a', 'n_face', 'b'], uxgrid=uxgrid) |
|
0 commit comments