Skip to content

Commit 3013f29

Browse files
committed
o Address comments, make one test file for vector calculus
1 parent 7d287d6 commit 3013f29

File tree

3 files changed

+153
-162
lines changed

3 files changed

+153
-162
lines changed

test/core/test_gradient.py

Lines changed: 0 additions & 145 deletions
This file was deleted.
Lines changed: 140 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,144 @@
55
import numpy.testing as nt
66

77

8-
class TestQuadHex:
8+
# TODO: pytest fixtures
9+
10+
11+
class TestGradientQuadHex:
12+
13+
def test_gradient_output_format(self, gridpath, datasetpath):
14+
"""Tests the output format of gradient functionality"""
15+
uxds = ux.open_dataset(gridpath("ugrid", "quad-hexagon", "grid.nc"), datasetpath("ugrid", "quad-hexagon", "data.nc"))
16+
17+
grad_ds = uxds['t2m'].gradient()
18+
19+
assert isinstance(grad_ds, ux.UxDataset)
20+
assert "zonal_gradient" in grad_ds
21+
assert "meridional_gradient" in grad_ds
22+
assert "gradient" in grad_ds.attrs
23+
assert uxds['t2m'].sizes == grad_ds.sizes
24+
25+
def test_gradient_all_boundary_faces(self, gridpath, datasetpath):
26+
"""Quad hexagon grid has 4 faces, each of which are on the boundary, so the expected gradients are zero for both components"""
27+
uxds = ux.open_dataset(gridpath("ugrid", "quad-hexagon", "grid.nc"), datasetpath("ugrid", "quad-hexagon", "data.nc"))
28+
29+
grad = uxds['t2m'].gradient()
30+
31+
assert np.isnan(grad['meridional_gradient']).all()
32+
assert np.isnan(grad['zonal_gradient']).all()
33+
34+
35+
class TestGradientMPASOcean:
36+
37+
def test_gradient(self, gridpath, datasetpath):
38+
uxds = ux.open_dataset(gridpath("mpas", "QU", "480", "grid.nc"), datasetpath("mpas", "QU", "480", "data.nc"))
39+
40+
grad = uxds['bottomDepth'].gradient()
41+
42+
# There should be some boundary faces
43+
assert np.isnan(grad['meridional_gradient']).any()
44+
assert np.isnan(grad['zonal_gradient']).any()
45+
46+
# Not every face is on the boundary, ensure there are valid values
47+
assert not np.isnan(grad['meridional_gradient']).all()
48+
assert not np.isnan(grad['zonal_gradient']).all()
49+
50+
51+
class TestGradientDyamondSubset:
52+
53+
center_fidx = 153
54+
left_fidx = 100
55+
right_fidx = 164
56+
top_fidx = 154
57+
bottom_fidx = 66
58+
59+
def test_lat_field(self, gridpath, datasetpath):
60+
"""Gradient of a latitude field. All vectors should be pointing east."""
61+
uxds = ux.open_dataset(
62+
gridpath("mpas", "dyamond-30km", "gradient_grid_subset.nc"),
63+
datasetpath("mpas", "dyamond-30km", "gradient_data_subset.nc")
64+
)
65+
grad = uxds['face_lat'].gradient()
66+
zg, mg = grad.zonal_gradient, grad.meridional_gradient
67+
assert mg.max() > zg.max()
68+
69+
assert mg.min() > zg.max()
70+
71+
72+
def test_lon_field(self, gridpath, datasetpath):
73+
"""Gradient of a longitude field. All vectors should be pointing north."""
74+
uxds = ux.open_dataset(
75+
gridpath("mpas", "dyamond-30km", "gradient_grid_subset.nc"),
76+
datasetpath("mpas", "dyamond-30km", "gradient_data_subset.nc")
77+
)
78+
grad = uxds['face_lon'].gradient()
79+
zg, mg = grad.zonal_gradient, grad.meridional_gradient
80+
assert zg.max() > mg.max()
81+
82+
assert zg.min() > mg.max()
83+
84+
def test_gaussian_field(self, gridpath, datasetpath):
85+
"""Gradient of a gaussian field. All vectors should be pointing toward the center"""
86+
uxds = ux.open_dataset(
87+
gridpath("mpas", "dyamond-30km", "gradient_grid_subset.nc"),
88+
datasetpath("mpas", "dyamond-30km", "gradient_data_subset.nc")
89+
)
90+
grad = uxds['gaussian'].gradient()
91+
zg, mg = grad.zonal_gradient, grad.meridional_gradient
92+
mag = np.hypot(zg, mg)
93+
angle = np.arctan2(mg, zg)
94+
95+
# Ensure a valid range for min/max
96+
assert zg.min() < 0
97+
assert zg.max() > 0
98+
assert mg.min() < 0
99+
assert mg.max() > 0
100+
101+
# The Magnitude at the center is less than the corners
102+
assert mag[self.center_fidx] < mag[self.left_fidx]
103+
assert mag[self.center_fidx] < mag[self.right_fidx]
104+
assert mag[self.center_fidx] < mag[self.top_fidx]
105+
assert mag[self.center_fidx] < mag[self.bottom_fidx]
106+
107+
# Pointing Towards Center
108+
assert angle[self.left_fidx] < 0
109+
assert angle[self.right_fidx] > 0
110+
assert angle[self.top_fidx] < 0
111+
assert angle[self.bottom_fidx] > 0
112+
113+
114+
115+
def test_inverse_gaussian_field(self, gridpath, datasetpath):
116+
"""Gradient of an inverse gaussian field. All vectors should be pointing outward from the center."""
117+
uxds = ux.open_dataset(
118+
gridpath("mpas", "dyamond-30km", "gradient_grid_subset.nc"),
119+
datasetpath("mpas", "dyamond-30km", "gradient_data_subset.nc")
120+
)
121+
grad = uxds['inverse_gaussian'].gradient()
122+
zg, mg = grad.zonal_gradient, grad.meridional_gradient
123+
mag = np.hypot(zg, mg)
124+
angle = np.arctan2(mg, zg)
125+
126+
# Ensure a valid range for min/max
127+
assert zg.min() < 0
128+
assert zg.max() > 0
129+
assert mg.min() < 0
130+
assert mg.max() > 0
131+
132+
# The Magnitude at the center is less than the corners
133+
assert mag[self.center_fidx] < mag[self.left_fidx]
134+
assert mag[self.center_fidx] < mag[self.right_fidx]
135+
assert mag[self.center_fidx] < mag[self.top_fidx]
136+
assert mag[self.center_fidx] < mag[self.bottom_fidx]
137+
138+
# Pointing Away from Center
139+
assert angle[self.left_fidx] > 0
140+
assert angle[self.right_fidx] < 0
141+
assert angle[self.top_fidx] > 0
142+
assert angle[self.bottom_fidx] < 0
143+
144+
145+
class TestDivergenceQuadHex:
9146

10147
def test_divergence_output_format(self, gridpath, datasetpath):
11148
"""Tests the output format of divergence functionality"""
@@ -39,7 +176,7 @@ def test_divergence_input_validation(self, gridpath, datasetpath):
39176
# This would require creating data with different dims, so we'll skip for now
40177

41178

42-
class TestMPASOcean:
179+
class TestDivergenceMPASOcean:
43180

44181
def test_divergence_basic(self, gridpath, datasetpath):
45182
"""Basic test of divergence computation"""
@@ -56,7 +193,7 @@ def test_divergence_basic(self, gridpath, datasetpath):
56193
assert np.isfinite(div_field.values).any()
57194

58195

59-
class TestDyamondSubset:
196+
class TestDivergenceDyamondSubset:
60197

61198
def test_divergence_constant_field(self, gridpath, datasetpath):
62199
"""Test divergence of constant vector field (should be zero)"""

uxarray/core/dataarray.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,24 +1397,23 @@ def divergence(self, other: "UxDataArray", **kwargs) -> "UxDataArray":
13971397
# For divergence: div(V) = ∂u/∂x + ∂v/∂y
13981398
# In spherical coordinates: div(V) = (1/cos(lat)) * ∂u/∂lon + ∂v/∂lat
13991399
# We use the zonal gradient (∂/∂lon) of u and meridional gradient (∂/∂lat) of v
1400-
divergence_values = (
1401-
u_gradient["zonal_gradient"].values
1402-
+ v_gradient["meridional_gradient"].values
1403-
)
1404-
1405-
# Create the divergence UxDataArray
1406-
divergence_da = UxDataArray(
1407-
data=divergence_values,
1408-
name="divergence",
1409-
dims=self.dims,
1410-
uxgrid=self.uxgrid,
1411-
attrs={
1400+
u = u_gradient["zonal_gradient"]
1401+
v = v_gradient["meridional_gradient"]
1402+
1403+
# Align DataArrays to ensure coords/dims match, then perform xarray-aware addition
1404+
u, v = xr.align(u, v)
1405+
divergence = u + v
1406+
divergence.name = "divergence"
1407+
divergence.attrs.update(
1408+
{
14121409
"divergence": True,
14131410
"units": "1/s" if "units" not in kwargs else kwargs["units"],
1414-
},
1415-
coords=self.coords,
1411+
}
14161412
)
14171413

1414+
# Wrap result as a UxDataArray while preserving uxgrid and coords
1415+
divergence_da = UxDataArray(divergence, uxgrid=self.uxgrid)
1416+
14181417
return divergence_da
14191418

14201419
def difference(self, destination: Optional[str] = "edge"):

0 commit comments

Comments
 (0)