Skip to content

Commit 5cdf333

Browse files
authored
o Fix non conservative zonal mean bug - got zero instead of nan due to initialization, added two test cases (#1423)
1 parent 5578e12 commit 5cdf333

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

test/grid/integrate/test_zonal.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import dask.array as da
33
import numpy as np
44
import pytest
5+
import warnings
56

67
import numpy.testing as nt
78

@@ -84,6 +85,44 @@ def test_lat_inputs(self, gridpath, datasetpath):
8485
assert len(uxds['psi'].zonal_mean(lat=1)) == 1
8586
assert len(uxds['psi'].zonal_mean(lat=(-90, 90, 1))) == 181
8687

88+
def test_zonal_mean_missing_latitudes_nan(self, gridpath, datasetpath):
89+
"""Zonal mean should return NaN (not zeros) when no faces intersect a latitude."""
90+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
91+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
92+
uxds = ux.open_dataset(grid_path, data_path)
93+
94+
# Restrict to a narrow band so most requested latitudes have no coverage
95+
narrow = uxds["psi"].subset.bounding_box(lon_bounds=(-20, 20), lat_bounds=(0, 10))
96+
97+
with warnings.catch_warnings():
98+
warnings.filterwarnings("error", category=RuntimeWarning)
99+
res = narrow.zonal_mean(lat=(-90, 90, 10))
100+
101+
below_band = res.sel(latitudes=res.latitudes < 0)
102+
assert np.all(np.isnan(below_band))
103+
assert np.isfinite(res.sel(latitudes=0).item())
104+
105+
with warnings.catch_warnings():
106+
warnings.filterwarnings("error", category=RuntimeWarning)
107+
res_cons = narrow.zonal_mean(lat=(-90, 90, 10), conservative=True)
108+
109+
below_band_cons = res_cons.sel(latitudes=res_cons.latitudes < 0)
110+
assert np.all(np.isnan(below_band_cons))
111+
assert np.isfinite(res_cons.sel(latitudes=5).item())
112+
113+
def test_zonal_mean_int_data_promotes_dtype(self):
114+
"""Integer inputs should be promoted so NaNs can be stored."""
115+
grid = ux.Grid.from_healpix(zoom=0)
116+
faces = np.where(grid.face_lat > 0)[0] # only northern hemisphere
117+
uxda = ux.UxDataArray(
118+
np.ones(grid.n_face, dtype=np.int32), dims=["n_face"], uxgrid=grid
119+
).isel(n_face=faces)
120+
121+
res = uxda.zonal_mean(lat=(-90, 90, 30))
122+
123+
assert np.issubdtype(res.dtype, np.floating)
124+
assert np.isnan(res.sel(latitudes=-90)).item()
125+
87126
def test_mismatched_dims():
88127
uxgrid = ux.Grid.from_healpix(zoom=0)
89128
uxda = ux.UxDataArray(np.ones((10, uxgrid.n_face, 5)), dims=['a', 'n_face', 'b'], uxgrid=uxgrid)

uxarray/core/zonal.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,19 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal
2626
shape = list(uxda.shape)
2727
shape[face_axis] = len(latitudes)
2828

29+
if np.issubdtype(uxda.dtype, np.integer) or np.issubdtype(uxda.dtype, np.bool_):
30+
# Promote integers/bools so we can represent NaNs
31+
result_dtype = np.float64
32+
else:
33+
# Preserve existing float/complex dtype
34+
result_dtype = uxda.dtype
35+
2936
if isinstance(uxda.data, da.Array):
30-
# Create a Dask array for storing results
31-
result = da.zeros(shape, dtype=uxda.dtype)
37+
# Pre-fill with NaNs so empty slices stay missing without extra work
38+
result = da.full(shape, np.nan, dtype=result_dtype)
3239
else:
3340
# Create a NumPy array for storing results
34-
result = np.zeros(shape, dtype=uxda.dtype)
41+
result = np.full(shape, np.nan, dtype=result_dtype)
3542

3643
faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array(
3744
uxgrid.face_node_connectivity.values,
@@ -46,6 +53,14 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal
4653

4754
for i, lat in enumerate(latitudes):
4855
face_indices = uxda.uxgrid.get_faces_at_constant_latitude(lat)
56+
57+
idx = [slice(None)] * result.ndim
58+
idx[face_axis] = i
59+
60+
if face_indices.size == 0:
61+
# No intersecting faces for this latitude
62+
continue
63+
4964
z = np.sin(np.deg2rad(lat))
5065

5166
fe = faces_edge_nodes_xyz[face_indices]
@@ -59,14 +74,16 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal
5974

6075
total = w.sum()
6176

77+
if total == 0.0 or not np.isfinite(total):
78+
# If weights collapse to zero, keep the pre-filled NaNs
79+
continue
80+
6281
data_slice = uxda.isel(n_face=face_indices, ignore_grid=True).data
6382
w_shape = [1] * data_slice.ndim
6483
w_shape[face_axis] = w.size
6584
w_reshaped = w.reshape(w_shape)
6685
weighted = (data_slice * w_reshaped).sum(axis=face_axis) / total
6786

68-
idx = [slice(None)] * result.ndim
69-
idx[face_axis] = i
7087
result[tuple(idx)] = weighted
7188

7289
return result

0 commit comments

Comments
 (0)