Skip to content

Commit f7b3ad4

Browse files
authored
Merge pull request #1345 from UXARRAY/rajeeja/zonal-mean-analytic-band
conservative zonal means
2 parents 460eb23 + 703356c commit f7b3ad4

File tree

6 files changed

+3290
-108
lines changed

6 files changed

+3290
-108
lines changed

docs/user-guide/zonal-average.ipynb

Lines changed: 2783 additions & 54 deletions
Large diffs are not rendered by default.

test/grid/integrate/test_zonal.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,114 @@ def test_mismatched_dims():
9292

9393
assert za.shape == (10, 19, 5)
9494
assert za.dims[1] == "latitudes"
95+
96+
97+
class TestConservativeZonalMean:
98+
"""Test conservative zonal mean functionality."""
99+
100+
def test_conservative_zonal_mean_basic(self, gridpath, datasetpath):
101+
"""Test basic conservative zonal mean with bands."""
102+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
103+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
104+
uxds = ux.open_dataset(grid_path, data_path)
105+
106+
# Test with explicit bands
107+
bands = np.array([-90, -30, 0, 30, 90])
108+
result = uxds["psi"].zonal_mean(lat=bands, conservative=True)
109+
110+
# Should have one less value than bands (4 bands from 5 edges)
111+
assert result.shape == (len(bands) - 1,)
112+
assert np.all(np.isfinite(result.values))
113+
114+
def test_conservative_float_step_size(self, gridpath, datasetpath):
115+
"""Test conservative zonal mean with float step sizes."""
116+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
117+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
118+
uxds = ux.open_dataset(grid_path, data_path)
119+
120+
# Test with float step size (5.5 degrees)
121+
result = uxds["psi"].zonal_mean(lat=(-90, 90, 0.05), conservative=True)
122+
123+
# Should get valid results
124+
assert len(result) > 0
125+
assert np.all(np.isfinite(result.values))
126+
127+
# Test with reasonable float step size (no warning)
128+
result = uxds["psi"].zonal_mean(lat=(-90, 90, 5.5), conservative=True)
129+
expected_n_bands = int(np.ceil(180 / 5.5))
130+
assert result.shape[0] == expected_n_bands
131+
assert np.all(np.isfinite(result.values))
132+
133+
def test_conservative_near_pole(self, gridpath, datasetpath):
134+
"""Test conservative zonal mean with bands near the poles."""
135+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
136+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
137+
uxds = ux.open_dataset(grid_path, data_path)
138+
139+
# Test near north pole with float step
140+
bands_north = np.array([85.0, 87.5, 90.0])
141+
result_north = uxds["psi"].zonal_mean(lat=bands_north, conservative=True)
142+
assert result_north.shape == (2,)
143+
assert np.all(np.isfinite(result_north.values))
144+
145+
# Test near south pole with float step
146+
bands_south = np.array([-90.0, -87.5, -85.0])
147+
result_south = uxds["psi"].zonal_mean(lat=bands_south, conservative=True)
148+
assert result_south.shape == (2,)
149+
assert np.all(np.isfinite(result_south.values))
150+
151+
# Test spanning pole with non-integer step
152+
bands_span = np.array([88.5, 89.25, 90.0])
153+
result_span = uxds["psi"].zonal_mean(lat=bands_span, conservative=True)
154+
assert result_span.shape == (2,)
155+
assert np.all(np.isfinite(result_span.values))
156+
157+
def test_conservative_step_size_validation(self, gridpath, datasetpath):
158+
"""Test that step size validation works correctly."""
159+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
160+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
161+
uxds = ux.open_dataset(grid_path, data_path)
162+
163+
# Test negative step size
164+
with pytest.raises(ValueError, match="Step size must be positive"):
165+
uxds["psi"].zonal_mean(lat=(-90, 90, -10), conservative=True)
166+
167+
# Test zero step size
168+
with pytest.raises(ValueError, match="Step size must be positive"):
169+
uxds["psi"].zonal_mean(lat=(-90, 90, 0), conservative=True)
170+
171+
def test_conservative_full_sphere_conservation(self, gridpath, datasetpath):
172+
"""Test that single band covering entire sphere conserves global mean."""
173+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
174+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
175+
uxds = ux.open_dataset(grid_path, data_path)
176+
177+
# Single band covering entire sphere
178+
bands = np.array([-90, 90])
179+
result = uxds["psi"].zonal_mean(lat=bands, conservative=True)
180+
181+
# Compare with global mean
182+
global_mean = uxds["psi"].mean()
183+
184+
assert result.shape == (1,)
185+
assert result.values[0] == pytest.approx(global_mean.values, rel=0.01)
186+
187+
def test_conservative_vs_nonconservative_comparison(self, gridpath, datasetpath):
188+
"""Compare conservative and non-conservative methods."""
189+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
190+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
191+
uxds = ux.open_dataset(grid_path, data_path)
192+
193+
# Non-conservative at band centers
194+
lat_centers = np.array([-60, 0, 60])
195+
non_conservative = uxds["psi"].zonal_mean(lat=lat_centers)
196+
197+
# Conservative with bands
198+
bands = np.array([-90, -30, 30, 90])
199+
conservative = uxds["psi"].zonal_mean(lat=bands, conservative=True)
200+
201+
# Results should be similar but not identical
202+
assert non_conservative.shape == conservative.shape
203+
# Check they are in the same ballpark
204+
assert np.all(np.abs(conservative.values - non_conservative.values) <
205+
np.abs(non_conservative.values) * 0.5)

test/test_plot.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,19 @@ def test_dataarray_methods(gridpath, datasetpath):
7676
# plot.scatter() is an xarray method
7777
assert hasattr(uxds.plot, 'scatter')
7878

79+
import hvplot.xarray # registers .hvplot accessor
80+
7981
def test_line(gridpath):
8082
mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
8183
uxds = ux.open_dataset(mesh_path, mesh_path)
82-
_plot_line = uxds['bottomDepth'].zonal_average().plot.line()
84+
_plot_line = uxds['bottomDepth'].zonal_average().hvplot.line()
8385
assert isinstance(_plot_line, hv.Curve)
8486

8587
def test_scatter(gridpath):
8688
mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
8789
uxds = ux.open_dataset(mesh_path, mesh_path)
88-
_plot_line = uxds['bottomDepth'].zonal_average().plot.scatter()
89-
assert isinstance(_plot_line, hv.Scatter)
90+
_plot_scatter = uxds['bottomDepth'].zonal_average().hvplot.scatter()
91+
assert isinstance(_plot_scatter, hv.Scatter)
9092

9193

9294

uxarray/core/dataarray.py

Lines changed: 114 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
_compute_gradient,
2121
)
2222
from uxarray.core.utils import _map_dims_to_ugrid
23-
from uxarray.core.zonal import _compute_non_conservative_zonal_mean
23+
from uxarray.core.zonal import (
24+
_compute_conservative_zonal_mean_bands,
25+
_compute_non_conservative_zonal_mean,
26+
)
2427
from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
2528
from uxarray.formatting_html import array_repr
2629
from uxarray.grid import Grid
@@ -510,16 +513,20 @@ def integrate(
510513

511514
return uxda
512515

513-
def zonal_mean(self, lat=(-90, 90, 10), **kwargs):
514-
"""Compute averages along lines of constant latitude.
516+
def zonal_mean(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
517+
"""Compute non-conservative or conservative averages along lines of constant latitude or latitude bands.
515518
516519
Parameters
517520
----------
518521
lat : tuple, float, or array-like, default=(-90, 90, 10)
519-
Latitude values in degrees. Can be specified as:
520-
- tuple (start, end, step): Computes means at intervals of `step` in range [start, end]
521-
- float: Computes mean for a single latitude
522-
- array-like: Computes means for each specified latitude
522+
Latitude specification:
523+
- tuple (start, end, step): For non-conservative, computes means at intervals of `step`.
524+
For conservative, creates band edges via np.arange(start, end+step, step).
525+
- float: Single latitude for non-conservative averaging
526+
- array-like: For non-conservative, latitudes to sample. For conservative, band edges.
527+
conservative : bool, default=False
528+
If True, performs conservative (area-weighted) zonal averaging over latitude bands.
529+
If False, performs traditional (non-conservative) averaging at latitude lines.
523530
524531
Returns
525532
-------
@@ -529,62 +536,125 @@ def zonal_mean(self, lat=(-90, 90, 10), **kwargs):
529536
530537
Examples
531538
--------
532-
# All latitudes from -90° to 90° at 10° intervals
539+
# Non-conservative averaging from -90° to 90° at 10° intervals by default
533540
>>> uxds["var"].zonal_mean()
534541
535-
# Single latitude at 30°
542+
# Single latitude (non-conservative) over 30° latitude
536543
>>> uxds["var"].zonal_mean(lat=30.0)
537544
538-
# Range from -60° to 60° at 10° intervals
539-
>>> uxds["var"].zonal_mean(lat=(-60, 60, 10))
545+
# Conservative averaging over latitude bands
546+
>>> uxds["var"].zonal_mean(lat=(-60, 60, 10), conservative=True)
547+
548+
# Conservative with explicit band edges
549+
>>> uxds["var"].zonal_mean(lat=[-90, -30, 0, 30, 90], conservative=True)
540550
541551
Notes
542552
-----
543-
Only supported for face-centered data variables. Candidate faces are determined
544-
using spherical bounding boxes - faces whose bounds contain the target latitude
545-
are included in calculations.
553+
Only supported for face-centered data variables.
554+
555+
Conservative averaging preserves integral quantities and is recommended for
556+
physical analysis. Non-conservative averaging samples at latitude lines.
546557
"""
547558
if not self._face_centered():
548559
raise ValueError(
549560
"Zonal mean computations are currently only supported for face-centered data variables."
550561
)
551562

552-
if isinstance(lat, tuple):
553-
# zonal mean over a range of latitudes
554-
latitudes = np.arange(lat[0], lat[1] + lat[2], lat[2])
555-
latitudes = np.clip(latitudes, -90, 90)
556-
elif isinstance(lat, (float, int)):
557-
# zonal mean over a single latitude
558-
latitudes = [lat]
559-
elif isinstance(lat, (list, np.ndarray)):
560-
# zonal mean over an array of arbitrary latitudes
561-
latitudes = np.asarray(lat)
562-
else:
563-
raise ValueError(
564-
"Invalid value for 'lat' provided. Must either be a single scalar value, tuple (min_lat, max_lat, step), or array-like."
563+
face_axis = self.dims.index("n_face")
564+
565+
if not conservative:
566+
# Non-conservative (traditional) zonal averaging
567+
if isinstance(lat, tuple):
568+
start, end, step = lat
569+
if step <= 0:
570+
raise ValueError("Step size must be positive.")
571+
if step < 0.1:
572+
warnings.warn(
573+
f"Very small step size ({step}°) may lead to performance issues...",
574+
UserWarning,
575+
stacklevel=2,
576+
)
577+
num_points = int(round((end - start) / step)) + 1
578+
latitudes = np.linspace(start, end, num_points)
579+
latitudes = np.clip(latitudes, -90, 90)
580+
elif isinstance(lat, (float, int)):
581+
latitudes = [lat]
582+
elif isinstance(lat, (list, np.ndarray)):
583+
latitudes = np.asarray(lat)
584+
else:
585+
raise ValueError(
586+
"Invalid value for 'lat' provided. Must be a scalar, tuple (min_lat, max_lat, step), or array-like."
587+
)
588+
589+
res = _compute_non_conservative_zonal_mean(
590+
uxda=self, latitudes=latitudes, **kwargs
565591
)
566592

567-
res = _compute_non_conservative_zonal_mean(
568-
uxda=self, latitudes=latitudes, **kwargs
569-
)
593+
dims = list(self.dims)
594+
dims[face_axis] = "latitudes"
595+
596+
return xr.DataArray(
597+
res,
598+
dims=dims,
599+
coords={"latitudes": latitudes},
600+
name=self.name + "_zonal_mean"
601+
if self.name is not None
602+
else "zonal_mean",
603+
attrs={"zonal_mean": True, "conservative": False},
604+
)
570605

571-
face_axis = self.dims.index("n_face")
572-
dims = list(self.dims)
573-
dims[face_axis] = "latitudes"
606+
else:
607+
# Conservative zonal averaging
608+
if isinstance(lat, tuple):
609+
start, end, step = lat
610+
if step <= 0:
611+
raise ValueError(
612+
"Step size must be positive for conservative averaging."
613+
)
614+
if step < 0.1:
615+
warnings.warn(
616+
f"Very small step size ({step}°) may lead to performance issues...",
617+
UserWarning,
618+
stacklevel=2,
619+
)
620+
num_points = int(round((end - start) / step)) + 1
621+
edges = np.linspace(start, end, num_points)
622+
edges = np.clip(edges, -90, 90)
623+
elif isinstance(lat, (list, np.ndarray)):
624+
edges = np.asarray(lat, dtype=float)
625+
else:
626+
raise ValueError(
627+
"For conservative averaging, 'lat' must be a tuple (start, end, step) or array-like band edges."
628+
)
574629

575-
uxda = UxDataArray(
576-
res,
577-
uxgrid=self.uxgrid,
578-
dims=dims,
579-
coords={"latitudes": latitudes},
580-
name=self.name + "_zonal_mean" if self.name is not None else "zonal_mean",
581-
attrs={"zonal_mean": True},
582-
)
630+
if edges.ndim != 1 or edges.size < 2:
631+
raise ValueError("Band edges must be 1D with at least two values")
583632

584-
return uxda
633+
res = _compute_conservative_zonal_mean_bands(self, edges)
634+
635+
# Use band centers as coordinate values
636+
centers = 0.5 * (edges[:-1] + edges[1:])
637+
638+
dims = list(self.dims)
639+
dims[face_axis] = "latitudes"
640+
641+
return xr.DataArray(
642+
res,
643+
dims=dims,
644+
coords={"latitudes": centers},
645+
name=self.name + "_zonal_mean"
646+
if self.name is not None
647+
else "zonal_mean",
648+
attrs={
649+
"zonal_mean": True,
650+
"conservative": True,
651+
"lat_band_edges": edges,
652+
},
653+
)
585654

586-
# Alias for 'zonal_mean', since this name is also commonly used.
587-
zonal_average = zonal_mean
655+
def zonal_average(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
656+
"""Alias of zonal_mean; prefer `zonal_mean` for primary API."""
657+
return self.zonal_mean(lat=lat, conservative=conservative, **kwargs)
588658

589659
def azimuthal_mean(
590660
self,

0 commit comments

Comments
 (0)