Skip to content

Commit e62d2a2

Browse files
committed
o Add support for Floting-Point Bands and
new tests for the same. Use np.linspace instead of np.arange for latitude bands as it prevents the accumulation of precision errors.
1 parent 222c6d0 commit e62d2a2

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

test/grid/integrate/test_zonal.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,63 @@ def test_conservative_zonal_mean_basic(self, gridpath, datasetpath):
111111
assert result.shape == (len(bands) - 1,)
112112
assert np.all(np.isfinite(result.values))
113113

114+
def test_conservative_float_step_size(self, gridpath, datasetpath):
115+
"""Test conservative zonal mean with float step sizes."""
116+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
117+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
118+
uxds = ux.open_dataset(grid_path, data_path)
119+
120+
# Test with float step size (5.5 degrees)
121+
result = uxds["psi"].zonal_mean(lat=(-90, 90, 0.05), conservative=True)
122+
123+
# Should get valid results
124+
assert len(result) > 0
125+
assert np.all(np.isfinite(result.values))
126+
127+
# Test with reasonable float step size (no warning)
128+
result = uxds["psi"].zonal_mean(lat=(-90, 90, 5.5), conservative=True)
129+
expected_n_bands = int(np.ceil(180 / 5.5))
130+
assert result.shape[0] == expected_n_bands
131+
assert np.all(np.isfinite(result.values))
132+
133+
def test_conservative_near_pole(self, gridpath, datasetpath):
134+
"""Test conservative zonal mean with bands near the poles."""
135+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
136+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
137+
uxds = ux.open_dataset(grid_path, data_path)
138+
139+
# Test near north pole with float step
140+
bands_north = np.array([85.0, 87.5, 90.0])
141+
result_north = uxds["psi"].zonal_mean(lat=bands_north, conservative=True)
142+
assert result_north.shape == (2,)
143+
assert np.all(np.isfinite(result_north.values))
144+
145+
# Test near south pole with float step
146+
bands_south = np.array([-90.0, -87.5, -85.0])
147+
result_south = uxds["psi"].zonal_mean(lat=bands_south, conservative=True)
148+
assert result_south.shape == (2,)
149+
assert np.all(np.isfinite(result_south.values))
150+
151+
# Test spanning pole with non-integer step
152+
bands_span = np.array([88.5, 89.25, 90.0])
153+
result_span = uxds["psi"].zonal_mean(lat=bands_span, conservative=True)
154+
assert result_span.shape == (2,)
155+
assert np.all(np.isfinite(result_span.values))
156+
157+
def test_conservative_step_size_validation(self, gridpath, datasetpath):
158+
"""Test that step size validation works correctly."""
159+
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")
160+
data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc")
161+
uxds = ux.open_dataset(grid_path, data_path)
162+
163+
# Test negative step size
164+
with pytest.raises(ValueError, match="Step size must be positive"):
165+
uxds["psi"].zonal_mean(lat=(-90, 90, -10), conservative=True)
166+
167+
# Test zero step size
168+
with pytest.raises(ValueError, match="Step size must be positive"):
169+
uxds["psi"].zonal_mean(lat=(-90, 90, 0), conservative=True)
170+
114171
def test_conservative_full_sphere_conservation(self, gridpath, datasetpath):
115172
"""Test that single band covering entire sphere conserves global mean."""
116173
grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug")

uxarray/core/dataarray.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,11 @@ def zonal_mean(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
565565
if not conservative:
566566
# Non-conservative (traditional) zonal averaging
567567
if isinstance(lat, tuple):
568-
latitudes = np.arange(lat[0], lat[1] + lat[2], lat[2])
568+
start, end, step = lat
569+
if step <= 0:
570+
raise ValueError("Step size must be positive.")
571+
num_points = int(round((end - start) / step)) + 1
572+
latitudes = np.linspace(start, end, num_points)
569573
latitudes = np.clip(latitudes, -90, 90)
570574
elif isinstance(lat, (float, int)):
571575
latitudes = [lat]
@@ -601,7 +605,13 @@ def zonal_mean(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
601605
else:
602606
# Conservative zonal averaging
603607
if isinstance(lat, tuple):
604-
edges = np.arange(lat[0], lat[1] + lat[2], lat[2])
608+
start, end, step = lat
609+
if step <= 0:
610+
raise ValueError(
611+
"Step size must be positive for conservative averaging."
612+
)
613+
num_points = int(round((end - start) / step)) + 1
614+
edges = np.linspace(start, end, num_points)
605615
edges = np.clip(edges, -90, 90)
606616
elif isinstance(lat, (list, np.ndarray)):
607617
edges = np.asarray(lat, dtype=float)

0 commit comments

Comments
 (0)