Skip to content

Commit 1fc29fc

Browse files
committed
Merge branch 'main' into rajeeja/vector_calc_div
2 parents 1df4cdb + f7b3ad4 commit 1fc29fc

File tree

6 files changed

+3290
-108
lines changed

6 files changed

+3290
-108
lines changed

docs/user-guide/zonal-average.ipynb

Lines changed: 2783 additions & 54 deletions
Large diffs are not rendered by default.

test/grid/integrate/test_zonal.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,114 @@ def test_mismatched_dims():
9292

9393
assert za.shape == (10, 19, 5)
9494
assert za.dims[1] == "latitudes"
95+
96+
97+
class TestConservativeZonalMean:
98+
"""Test conservative zonal mean functionality."""
99+
100+
def test_conservative_zonal_mean_basic(self, gridpath, datasetpath):
101+
"""Test basic conservative zonal mean with bands."""
102+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
103+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
104+
uxds = ux.open_dataset(grid_path, data_path)
105+
106+
# Test with explicit bands
107+
bands = np.array([-90, -30, 0, 30, 90])
108+
result = uxds["psi"].zonal_mean(lat=bands, conservative=True)
109+
110+
# Should have one less value than bands (4 bands from 5 edges)
111+
assert result.shape == (len(bands) - 1,)
112+
assert np.all(np.isfinite(result.values))
113+
114+
def test_conservative_float_step_size(self, gridpath, datasetpath):
115+
"""Test conservative zonal mean with float step sizes."""
116+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
117+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
118+
uxds = ux.open_dataset(grid_path, data_path)
119+
120+
# Test with float step size (5.5 degrees)
121+
result = uxds["psi"].zonal_mean(lat=(-90, 90, 0.05), conservative=True)
122+
123+
# Should get valid results
124+
assert len(result) > 0
125+
assert np.all(np.isfinite(result.values))
126+
127+
# Test with reasonable float step size (no warning)
128+
result = uxds["psi"].zonal_mean(lat=(-90, 90, 5.5), conservative=True)
129+
expected_n_bands = int(np.ceil(180 / 5.5))
130+
assert result.shape[0] == expected_n_bands
131+
assert np.all(np.isfinite(result.values))
132+
133+
def test_conservative_near_pole(self, gridpath, datasetpath):
134+
"""Test conservative zonal mean with bands near the poles."""
135+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
136+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
137+
uxds = ux.open_dataset(grid_path, data_path)
138+
139+
# Test near north pole with float step
140+
bands_north = np.array([85.0, 87.5, 90.0])
141+
result_north = uxds["psi"].zonal_mean(lat=bands_north, conservative=True)
142+
assert result_north.shape == (2,)
143+
assert np.all(np.isfinite(result_north.values))
144+
145+
# Test near south pole with float step
146+
bands_south = np.array([-90.0, -87.5, -85.0])
147+
result_south = uxds["psi"].zonal_mean(lat=bands_south, conservative=True)
148+
assert result_south.shape == (2,)
149+
assert np.all(np.isfinite(result_south.values))
150+
151+
# Test spanning pole with non-integer step
152+
bands_span = np.array([88.5, 89.25, 90.0])
153+
result_span = uxds["psi"].zonal_mean(lat=bands_span, conservative=True)
154+
assert result_span.shape == (2,)
155+
assert np.all(np.isfinite(result_span.values))
156+
157+
def test_conservative_step_size_validation(self, gridpath, datasetpath):
158+
"""Test that step size validation works correctly."""
159+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
160+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
161+
uxds = ux.open_dataset(grid_path, data_path)
162+
163+
# Test negative step size
164+
with pytest.raises(ValueError, match="Step size must be positive"):
165+
uxds["psi"].zonal_mean(lat=(-90, 90, -10), conservative=True)
166+
167+
# Test zero step size
168+
with pytest.raises(ValueError, match="Step size must be positive"):
169+
uxds["psi"].zonal_mean(lat=(-90, 90, 0), conservative=True)
170+
171+
def test_conservative_full_sphere_conservation(self, gridpath, datasetpath):
172+
"""Test that single band covering entire sphere conserves global mean."""
173+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
174+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
175+
uxds = ux.open_dataset(grid_path, data_path)
176+
177+
# Single band covering entire sphere
178+
bands = np.array([-90, 90])
179+
result = uxds["psi"].zonal_mean(lat=bands, conservative=True)
180+
181+
# Compare with global mean
182+
global_mean = uxds["psi"].mean()
183+
184+
assert result.shape == (1,)
185+
assert result.values[0] == pytest.approx(global_mean.values, rel=0.01)
186+
187+
def test_conservative_vs_nonconservative_comparison(self, gridpath, datasetpath):
188+
"""Compare conservative and non-conservative methods."""
189+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
190+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
191+
uxds = ux.open_dataset(grid_path, data_path)
192+
193+
# Non-conservative at band centers
194+
lat_centers = np.array([-60, 0, 60])
195+
non_conservative = uxds["psi"].zonal_mean(lat=lat_centers)
196+
197+
# Conservative with bands
198+
bands = np.array([-90, -30, 30, 90])
199+
conservative = uxds["psi"].zonal_mean(lat=bands, conservative=True)
200+
201+
# Results should be similar but not identical
202+
assert non_conservative.shape == conservative.shape
203+
# Check they are in the same ballpark
204+
assert np.all(np.abs(conservative.values - non_conservative.values) <
205+
np.abs(non_conservative.values) * 0.5)

test/test_plot.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,19 @@ def test_dataarray_methods(gridpath, datasetpath):
7676
# plot.scatter() is an xarray method
7777
assert hasattr(uxds.plot, 'scatter')
7878

79+
import hvplot.xarray # registers .hvplot accessor
80+
7981
def test_line(gridpath):
8082
mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
8183
uxds = ux.open_dataset(mesh_path, mesh_path)
82-
_plot_line = uxds['bottomDepth'].zonal_average().plot.line()
84+
_plot_line = uxds['bottomDepth'].zonal_average().hvplot.line()
8385
assert isinstance(_plot_line, hv.Curve)
8486

8587
def test_scatter(gridpath):
8688
mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
8789
uxds = ux.open_dataset(mesh_path, mesh_path)
88-
_plot_line = uxds['bottomDepth'].zonal_average().plot.scatter()
89-
assert isinstance(_plot_line, hv.Scatter)
90+
_plot_scatter = uxds['bottomDepth'].zonal_average().hvplot.scatter()
91+
assert isinstance(_plot_scatter, hv.Scatter)
9092

9193

9294

uxarray/core/dataarray.py

Lines changed: 114 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
)
2222
from uxarray.core.vector_calculus import _calculate_divergence
2323
from uxarray.core.utils import _map_dims_to_ugrid
24-
from uxarray.core.zonal import _compute_non_conservative_zonal_mean
24+
from uxarray.core.zonal import (
25+
_compute_conservative_zonal_mean_bands,
26+
_compute_non_conservative_zonal_mean,
27+
)
2528
from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
2629
from uxarray.formatting_html import array_repr
2730
from uxarray.grid import Grid
@@ -511,16 +514,20 @@ def integrate(
511514

512515
return uxda
513516

514-
def zonal_mean(self, lat=(-90, 90, 10), **kwargs):
515-
"""Compute averages along lines of constant latitude.
517+
def zonal_mean(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
518+
"""Compute non-conservative or conservative averages along lines of constant latitude or latitude bands.
516519
517520
Parameters
518521
----------
519522
lat : tuple, float, or array-like, default=(-90, 90, 10)
520-
Latitude values in degrees. Can be specified as:
521-
- tuple (start, end, step): Computes means at intervals of `step` in range [start, end]
522-
- float: Computes mean for a single latitude
523-
- array-like: Computes means for each specified latitude
523+
Latitude specification:
524+
- tuple (start, end, step): For non-conservative, computes means at intervals of `step`.
525+
For conservative, creates band edges via np.arange(start, end+step, step).
526+
- float: Single latitude for non-conservative averaging
527+
- array-like: For non-conservative, latitudes to sample. For conservative, band edges.
528+
conservative : bool, default=False
529+
If True, performs conservative (area-weighted) zonal averaging over latitude bands.
530+
If False, performs traditional (non-conservative) averaging at latitude lines.
524531
525532
Returns
526533
-------
@@ -530,62 +537,125 @@ def zonal_mean(self, lat=(-90, 90, 10), **kwargs):
530537
531538
Examples
532539
--------
533-
# All latitudes from -90° to 90° at 10° intervals
540+
# Non-conservative averaging from -90° to 90° at 10° intervals by default
534541
>>> uxds["var"].zonal_mean()
535542
536-
# Single latitude at 30°
543+
# Single latitude (non-conservative) over 30° latitude
537544
>>> uxds["var"].zonal_mean(lat=30.0)
538545
539-
# Range from -60° to 60° at 10° intervals
540-
>>> uxds["var"].zonal_mean(lat=(-60, 60, 10))
546+
# Conservative averaging over latitude bands
547+
>>> uxds["var"].zonal_mean(lat=(-60, 60, 10), conservative=True)
548+
549+
# Conservative with explicit band edges
550+
>>> uxds["var"].zonal_mean(lat=[-90, -30, 0, 30, 90], conservative=True)
541551
542552
Notes
543553
-----
544-
Only supported for face-centered data variables. Candidate faces are determined
545-
using spherical bounding boxes - faces whose bounds contain the target latitude
546-
are included in calculations.
554+
Only supported for face-centered data variables.
555+
556+
Conservative averaging preserves integral quantities and is recommended for
557+
physical analysis. Non-conservative averaging samples at latitude lines.
547558
"""
548559
if not self._face_centered():
549560
raise ValueError(
550561
"Zonal mean computations are currently only supported for face-centered data variables."
551562
)
552563

553-
if isinstance(lat, tuple):
554-
# zonal mean over a range of latitudes
555-
latitudes = np.arange(lat[0], lat[1] + lat[2], lat[2])
556-
latitudes = np.clip(latitudes, -90, 90)
557-
elif isinstance(lat, (float, int)):
558-
# zonal mean over a single latitude
559-
latitudes = [lat]
560-
elif isinstance(lat, (list, np.ndarray)):
561-
# zonal mean over an array of arbitrary latitudes
562-
latitudes = np.asarray(lat)
563-
else:
564-
raise ValueError(
565-
"Invalid value for 'lat' provided. Must either be a single scalar value, tuple (min_lat, max_lat, step), or array-like."
564+
face_axis = self.dims.index("n_face")
565+
566+
if not conservative:
567+
# Non-conservative (traditional) zonal averaging
568+
if isinstance(lat, tuple):
569+
start, end, step = lat
570+
if step <= 0:
571+
raise ValueError("Step size must be positive.")
572+
if step < 0.1:
573+
warnings.warn(
574+
f"Very small step size ({step}°) may lead to performance issues...",
575+
UserWarning,
576+
stacklevel=2,
577+
)
578+
num_points = int(round((end - start) / step)) + 1
579+
latitudes = np.linspace(start, end, num_points)
580+
latitudes = np.clip(latitudes, -90, 90)
581+
elif isinstance(lat, (float, int)):
582+
latitudes = [lat]
583+
elif isinstance(lat, (list, np.ndarray)):
584+
latitudes = np.asarray(lat)
585+
else:
586+
raise ValueError(
587+
"Invalid value for 'lat' provided. Must be a scalar, tuple (min_lat, max_lat, step), or array-like."
588+
)
589+
590+
res = _compute_non_conservative_zonal_mean(
591+
uxda=self, latitudes=latitudes, **kwargs
566592
)
567593

568-
res = _compute_non_conservative_zonal_mean(
569-
uxda=self, latitudes=latitudes, **kwargs
570-
)
594+
dims = list(self.dims)
595+
dims[face_axis] = "latitudes"
596+
597+
return xr.DataArray(
598+
res,
599+
dims=dims,
600+
coords={"latitudes": latitudes},
601+
name=self.name + "_zonal_mean"
602+
if self.name is not None
603+
else "zonal_mean",
604+
attrs={"zonal_mean": True, "conservative": False},
605+
)
571606

572-
face_axis = self.dims.index("n_face")
573-
dims = list(self.dims)
574-
dims[face_axis] = "latitudes"
607+
else:
608+
# Conservative zonal averaging
609+
if isinstance(lat, tuple):
610+
start, end, step = lat
611+
if step <= 0:
612+
raise ValueError(
613+
"Step size must be positive for conservative averaging."
614+
)
615+
if step < 0.1:
616+
warnings.warn(
617+
f"Very small step size ({step}°) may lead to performance issues...",
618+
UserWarning,
619+
stacklevel=2,
620+
)
621+
num_points = int(round((end - start) / step)) + 1
622+
edges = np.linspace(start, end, num_points)
623+
edges = np.clip(edges, -90, 90)
624+
elif isinstance(lat, (list, np.ndarray)):
625+
edges = np.asarray(lat, dtype=float)
626+
else:
627+
raise ValueError(
628+
"For conservative averaging, 'lat' must be a tuple (start, end, step) or array-like band edges."
629+
)
575630

576-
uxda = UxDataArray(
577-
res,
578-
uxgrid=self.uxgrid,
579-
dims=dims,
580-
coords={"latitudes": latitudes},
581-
name=self.name + "_zonal_mean" if self.name is not None else "zonal_mean",
582-
attrs={"zonal_mean": True},
583-
)
631+
if edges.ndim != 1 or edges.size < 2:
632+
raise ValueError("Band edges must be 1D with at least two values")
584633

585-
return uxda
634+
res = _compute_conservative_zonal_mean_bands(self, edges)
635+
636+
# Use band centers as coordinate values
637+
centers = 0.5 * (edges[:-1] + edges[1:])
638+
639+
dims = list(self.dims)
640+
dims[face_axis] = "latitudes"
641+
642+
return xr.DataArray(
643+
res,
644+
dims=dims,
645+
coords={"latitudes": centers},
646+
name=self.name + "_zonal_mean"
647+
if self.name is not None
648+
else "zonal_mean",
649+
attrs={
650+
"zonal_mean": True,
651+
"conservative": True,
652+
"lat_band_edges": edges,
653+
},
654+
)
586655

587-
# Alias for 'zonal_mean', since this name is also commonly used.
588-
zonal_average = zonal_mean
656+
def zonal_average(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
657+
"""Alias of zonal_mean; prefer `zonal_mean` for primary API."""
658+
return self.zonal_mean(lat=lat, conservative=conservative, **kwargs)
589659

590660
def azimuthal_mean(
591661
self,

0 commit comments

Comments
 (0)