Skip to content

Commit 021647f

Browse files
authored
Merge branch 'main' into rajeeja/cumulative_integrate
2 parents 4791058 + d7e05cf commit 021647f

File tree

3 files changed

+68
-15
lines changed

3 files changed

+68
-15
lines changed

docs/user-guide/subset.ipynb

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
"1. Nearest Neighbor\n",
1616
"2. Bounding Box\n",
1717
"3. Bounding Circle\n",
18-
"4. Constant Latitude/Longtude \n"
18+
"4. Constant Latitude/Longitude\n",
19+
"5. Interval Latitude/Longitude\n"
1920
]
2021
},
2122
{
@@ -501,7 +502,7 @@
501502
"clat_subset.plot(\n",
502503
" rasterize=True,\n",
503504
" clim=clim,\n",
504-
" title=\"Constant Longitude Subset\",\n",
505+
" title=\"Constant Latitude Subset\",\n",
505506
" global_extent=True,\n",
506507
" **plot_opts,\n",
507508
") * features"
@@ -512,9 +513,7 @@
512513
{
513514
"cell_type": "markdown",
514515
"metadata": {},
515-
"source": [
516-
"### Constant Longitude Interval"
517-
]
516+
"source": "### Longitude Interval"
518517
},
519518
{
520519
"cell_type": "code",
@@ -527,7 +526,7 @@
527526
"clon_int_subset.plot(\n",
528527
" rasterize=True,\n",
529528
" clim=clim,\n",
530-
" title=\"Constant Latitude Interval Subset\",\n",
529+
" title=\"Longitude Interval Subset\",\n",
531530
" global_extent=True,\n",
532531
" **plot_opts,\n",
533532
") * features"
@@ -538,9 +537,7 @@
538537
{
539538
"cell_type": "markdown",
540539
"metadata": {},
541-
"source": [
542-
"### Constant Latitude Interval"
543-
]
540+
"source": "### Latitude Interval"
544541
},
545542
{
546543
"cell_type": "code",
@@ -553,7 +550,7 @@
553550
"clat_int_subset.plot(\n",
554551
" rasterize=True,\n",
555552
" clim=clim,\n",
556-
" title=\"Constant Latitude Interval Subset\",\n",
553+
" title=\"Latitude Interval Subset\",\n",
557554
" global_extent=True,\n",
558555
" **plot_opts,\n",
559556
") * features"

test/grid/integrate/test_zonal.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import dask.array as da
33
import numpy as np
44
import pytest
5+
import warnings
56

67
import numpy.testing as nt
78

@@ -84,6 +85,44 @@ def test_lat_inputs(self, gridpath, datasetpath):
8485
assert len(uxds['psi'].zonal_mean(lat=1)) == 1
8586
assert len(uxds['psi'].zonal_mean(lat=(-90, 90, 1))) == 181
8687

88+
def test_zonal_mean_missing_latitudes_nan(self, gridpath, datasetpath):
89+
"""Zonal mean should return NaN (not zeros) when no faces intersect a latitude."""
90+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
91+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
92+
uxds = ux.open_dataset(grid_path, data_path)
93+
94+
# Restrict to a narrow band so most requested latitudes have no coverage
95+
narrow = uxds["psi"].subset.bounding_box(lon_bounds=(-20, 20), lat_bounds=(0, 10))
96+
97+
with warnings.catch_warnings():
98+
warnings.filterwarnings("error", category=RuntimeWarning)
99+
res = narrow.zonal_mean(lat=(-90, 90, 10))
100+
101+
below_band = res.sel(latitudes=res.latitudes < 0)
102+
assert np.all(np.isnan(below_band))
103+
assert np.isfinite(res.sel(latitudes=0).item())
104+
105+
with warnings.catch_warnings():
106+
warnings.filterwarnings("error", category=RuntimeWarning)
107+
res_cons = narrow.zonal_mean(lat=(-90, 90, 10), conservative=True)
108+
109+
below_band_cons = res_cons.sel(latitudes=res_cons.latitudes < 0)
110+
assert np.all(np.isnan(below_band_cons))
111+
assert np.isfinite(res_cons.sel(latitudes=5).item())
112+
113+
def test_zonal_mean_int_data_promotes_dtype(self):
114+
"""Integer inputs should be promoted so NaNs can be stored."""
115+
grid = ux.Grid.from_healpix(zoom=0)
116+
faces = np.where(grid.face_lat > 0)[0] # only northern hemisphere
117+
uxda = ux.UxDataArray(
118+
np.ones(grid.n_face, dtype=np.int32), dims=["n_face"], uxgrid=grid
119+
).isel(n_face=faces)
120+
121+
res = uxda.zonal_mean(lat=(-90, 90, 30))
122+
123+
assert np.issubdtype(res.dtype, np.floating)
124+
assert np.isnan(res.sel(latitudes=-90)).item()
125+
87126
def test_mismatched_dims():
88127
uxgrid = ux.Grid.from_healpix(zoom=0)
89128
uxda = ux.UxDataArray(np.ones((10, uxgrid.n_face, 5)), dims=['a', 'n_face', 'b'], uxgrid=uxgrid)

uxarray/core/zonal.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,19 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal
2626
shape = list(uxda.shape)
2727
shape[face_axis] = len(latitudes)
2828

29+
if np.issubdtype(uxda.dtype, np.integer) or np.issubdtype(uxda.dtype, np.bool_):
30+
# Promote integers/bools so we can represent NaNs
31+
result_dtype = np.float64
32+
else:
33+
# Preserve existing float/complex dtype
34+
result_dtype = uxda.dtype
35+
2936
if isinstance(uxda.data, da.Array):
30-
# Create a Dask array for storing results
31-
result = da.zeros(shape, dtype=uxda.dtype)
37+
# Pre-fill with NaNs so empty slices stay missing without extra work
38+
result = da.full(shape, np.nan, dtype=result_dtype)
3239
else:
3340
# Create a NumPy array for storing results
34-
result = np.zeros(shape, dtype=uxda.dtype)
41+
result = np.full(shape, np.nan, dtype=result_dtype)
3542

3643
faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array(
3744
uxgrid.face_node_connectivity.values,
@@ -46,6 +53,14 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal
4653

4754
for i, lat in enumerate(latitudes):
4855
face_indices = uxda.uxgrid.get_faces_at_constant_latitude(lat)
56+
57+
idx = [slice(None)] * result.ndim
58+
idx[face_axis] = i
59+
60+
if face_indices.size == 0:
61+
# No intersecting faces for this latitude
62+
continue
63+
4964
z = np.sin(np.deg2rad(lat))
5065

5166
fe = faces_edge_nodes_xyz[face_indices]
@@ -59,14 +74,16 @@ def _compute_non_conservative_zonal_mean(uxda, latitudes, use_robust_weights=Fal
5974

6075
total = w.sum()
6176

77+
if total == 0.0 or not np.isfinite(total):
78+
# If weights collapse to zero, keep the pre-filled NaNs
79+
continue
80+
6281
data_slice = uxda.isel(n_face=face_indices, ignore_grid=True).data
6382
w_shape = [1] * data_slice.ndim
6483
w_shape[face_axis] = w.size
6584
w_reshaped = w.reshape(w_shape)
6685
weighted = (data_slice * w_reshaped).sum(axis=face_axis) / total
6786

68-
idx = [slice(None)] * result.ndim
69-
idx[face_axis] = i
7087
result[tuple(idx)] = weighted
7188

7289
return result

0 commit comments

Comments
 (0)