Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions test/grid/integrate/test_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dask.array as da
import numpy as np
import pytest
import warnings

import numpy.testing as nt

Expand Down Expand Up @@ -84,6 +85,44 @@ def test_lat_inputs(self, gridpath, datasetpath):
assert len(uxds['psi'].zonal_mean(lat=1)) == 1
assert len(uxds['psi'].zonal_mean(lat=(-90, 90, 1))) == 181

def test_zonal_mean_missing_latitudes_nan(self, gridpath, datasetpath):
"""Zonal mean should return NaN (not zeros) when no faces intersect a latitude."""
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
uxds = ux.open_dataset(grid_path, data_path)

# Restrict to a narrow band so most requested latitudes have no coverage
narrow = uxds["psi"].subset.bounding_box(lon_bounds=(-20, 20), lat_bounds=(0, 10))

with warnings.catch_warnings():
warnings.filterwarnings("error", category=RuntimeWarning)
res = narrow.zonal_mean(lat=(-90, 90, 10))

below_band = res.sel(latitudes=res.latitudes < 0)
assert np.all(np.isnan(below_band))
assert np.isfinite(res.sel(latitudes=0).item())

with warnings.catch_warnings():
warnings.filterwarnings("error", category=RuntimeWarning)
res_cons = narrow.zonal_mean(lat=(-90, 90, 10), conservative=True)

below_band_cons = res_cons.sel(latitudes=res_cons.latitudes < 0)
assert np.all(np.isnan(below_band_cons))
assert np.isfinite(res_cons.sel(latitudes=5).item())

def test_zonal_mean_int_data_promotes_dtype(self):
"""Integer inputs should be promoted so NaNs can be stored."""
grid = ux.Grid.from_healpix(zoom=0)
faces = np.where(grid.face_lat > 0)[0] # only northern hemisphere
uxda = ux.UxDataArray(
np.ones(grid.n_face, dtype=np.int32), dims=["n_face"], uxgrid=grid
).isel(n_face=faces)

res = uxda.zonal_mean(lat=(-90, 90, 30))

assert np.issubdtype(res.dtype, np.floating)
assert np.isnan(res.sel(latitudes=-90)).item()

def test_mismatched_dims():
uxgrid = ux.Grid.from_healpix(zoom=0)
uxda = ux.UxDataArray(np.ones((10, uxgrid.n_face, 5)), dims=['a', 'n_face', 'b'], uxgrid=uxgrid)
Expand Down
27 changes: 22 additions & 5 deletions uxarray/core/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,19 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal
shape = list(uxda.shape)
shape[face_axis] = len(latitudes)

if np.issubdtype(uxda.dtype, np.integer) or np.issubdtype(uxda.dtype, np.bool_):
# Promote integers/bools so we can represent NaNs
result_dtype = np.float64
else:
# Preserve existing float/complex dtype
result_dtype = uxda.dtype

if isinstance(uxda.data, da.Array):
# Create a Dask array for storing results
result = da.zeros(shape, dtype=uxda.dtype)
# Pre-fill with NaNs so empty slices stay missing without extra work
result = da.full(shape, np.nan, dtype=result_dtype)
else:
# Create a NumPy array for storing results
result = np.zeros(shape, dtype=uxda.dtype)
result = np.full(shape, np.nan, dtype=result_dtype)

faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array(
uxgrid.face_node_connectivity.values,
Expand All @@ -46,6 +53,14 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal

for i, lat in enumerate(latitudes):
face_indices = uxda.uxgrid.get_faces_at_constant_latitude(lat)

idx = [slice(None)] * result.ndim
idx[face_axis] = i

if face_indices.size == 0:
# No intersecting faces for this latitude
continue

z = np.sin(np.deg2rad(lat))

fe = faces_edge_nodes_xyz[face_indices]
Expand All @@ -59,14 +74,16 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal

total = w.sum()

if total == 0.0 or not np.isfinite(total):
# If weights collapse to zero, keep the pre-filled NaNs
continue

data_slice = uxda.isel(n_face=face_indices, ignore_grid=True).data
w_shape = [1] * data_slice.ndim
w_shape[face_axis] = w.size
w_reshaped = w.reshape(w_shape)
weighted = (data_slice * w_reshaped).sum(axis=face_axis) / total

idx = [slice(None)] * result.ndim
idx[face_axis] = i
result[tuple(idx)] = weighted

return result
Expand Down