diff --git a/test/grid/integrate/test_zonal.py b/test/grid/integrate/test_zonal.py index 3fd1c2e10..9fc400d59 100644 --- a/test/grid/integrate/test_zonal.py +++ b/test/grid/integrate/test_zonal.py @@ -2,6 +2,7 @@ import dask.array as da import numpy as np import pytest +import warnings import numpy.testing as nt @@ -84,6 +85,44 @@ def test_lat_inputs(self, gridpath, datasetpath): assert len(uxds['psi'].zonal_mean(lat=1)) == 1 assert len(uxds['psi'].zonal_mean(lat=(-90, 90, 1))) == 181 + def test_zonal_mean_missing_latitudes_nan(self, gridpath, datasetpath): + """Zonal mean should return NaN (not zeros) when no faces intersect a latitude.""" + grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug") + data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc") + uxds = ux.open_dataset(grid_path, data_path) + + # Restrict to a narrow band so most requested latitudes have no coverage + narrow = uxds["psi"].subset.bounding_box(lon_bounds=(-20, 20), lat_bounds=(0, 10)) + + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=RuntimeWarning) + res = narrow.zonal_mean(lat=(-90, 90, 10)) + + below_band = res.sel(latitudes=res.latitudes < 0) + assert np.all(np.isnan(below_band)) + assert np.isfinite(res.sel(latitudes=0).item()) + + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=RuntimeWarning) + res_cons = narrow.zonal_mean(lat=(-90, 90, 10), conservative=True) + + below_band_cons = res_cons.sel(latitudes=res_cons.latitudes < 0) + assert np.all(np.isnan(below_band_cons)) + assert np.isfinite(res_cons.sel(latitudes=5).item()) + + def test_zonal_mean_int_data_promotes_dtype(self): + """Integer inputs should be promoted so NaNs can be stored.""" + grid = ux.Grid.from_healpix(zoom=0) + faces = np.where(grid.face_lat > 0)[0] # only northern hemisphere + uxda = ux.UxDataArray( + np.ones(grid.n_face, dtype=np.int32), dims=["n_face"], uxgrid=grid + ).isel(n_face=faces) + + res = uxda.zonal_mean(lat=(-90, 90, 30)) + + assert np.issubdtype(res.dtype, np.floating) + assert np.isnan(res.sel(latitudes=-90)).item() + def test_mismatched_dims(): uxgrid = ux.Grid.from_healpix(zoom=0) uxda = ux.UxDataArray(np.ones((10, uxgrid.n_face, 5)), dims=['a', 'n_face', 'b'], uxgrid=uxgrid) diff --git a/uxarray/core/zonal.py b/uxarray/core/zonal.py index 334b0052a..173bb3f30 100644 --- a/uxarray/core/zonal.py +++ b/uxarray/core/zonal.py @@ -26,12 +26,19 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal shape = list(uxda.shape) shape[face_axis] = len(latitudes) + if np.issubdtype(uxda.dtype, np.integer) or np.issubdtype(uxda.dtype, np.bool_): + # Promote integers/bools so we can represent NaNs + result_dtype = np.float64 + else: + # Preserve existing float/complex dtype + result_dtype = uxda.dtype + if isinstance(uxda.data, da.Array): - # Create a Dask array for storing results - result = da.zeros(shape, dtype=uxda.dtype) + # Pre-fill with NaNs so empty slices stay missing without extra work + result = da.full(shape, np.nan, dtype=result_dtype) else: # Create a NumPy array for storing results - result = np.zeros(shape, dtype=uxda.dtype) + result = np.full(shape, np.nan, dtype=result_dtype) faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array( uxgrid.face_node_connectivity.values, @@ -46,6 +53,14 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal for i, lat in enumerate(latitudes): face_indices = uxda.uxgrid.get_faces_at_constant_latitude(lat) + + idx = [slice(None)] * result.ndim + idx[face_axis] = i + + if face_indices.size == 0: + # No intersecting faces for this latitude + continue + z = np.sin(np.deg2rad(lat)) fe = faces_edge_nodes_xyz[face_indices] @@ -59,14 +74,16 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal total = w.sum() + if total == 0.0 or not np.isfinite(total): + # If weights collapse to zero, keep the pre-filled NaNs + continue + data_slice = uxda.isel(n_face=face_indices, ignore_grid=True).data w_shape = [1] * data_slice.ndim w_shape[face_axis] = w.size w_reshaped = w.reshape(w_shape) weighted = (data_slice * w_reshaped).sum(axis=face_axis) / total - idx = [slice(None)] * result.ndim - idx[face_axis] = i result[tuple(idx)] = weighted return result