Skip to content

Commit 4791058

Browse files
committed
o Add cumulative_integrate added in xarray 05.2023 Fixes #98
1 parent 5578e12 commit 4791058

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

test/test_cross_sections.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,27 @@ def test_cross_section(gridpath, datasetpath):
226226
_ = uxds['RELHUM'].cross_section(start=(45, 45))
227227
_ = uxds['RELHUM'].cross_section(lon=45, end=(45, 45))
228228
_ = uxds['RELHUM'].cross_section()
229+
230+
231+
def test_cross_section_cumulative_integrate(gridpath, datasetpath):
232+
uxds = ux.open_dataset(gridpath("scrip", "ne30pg2", "grid.nc"), datasetpath("scrip", "ne30pg2", "data.nc"))
233+
cs = uxds['RELHUM'].cross_section(start=(-45, -45), end=(45, 45), steps=6)
234+
cs = cs.assign_coords(distance=("steps", np.linspace(0.0, 1.0, cs.sizes["steps"])))
235+
236+
cs_ux = ux.UxDataArray(cs, uxgrid=uxds.uxgrid)
237+
238+
result = cs_ux.cumulative_integrate(coord="distance")
239+
expected = cs.cumulative_integrate(coord="distance")
240+
241+
assert isinstance(result, ux.UxDataArray)
242+
assert result.uxgrid == cs_ux.uxgrid
243+
xr.testing.assert_allclose(result.to_xarray(), expected)
244+
245+
246+
def test_cumulative_integrate_requires_coord(gridpath, datasetpath):
247+
uxds = ux.open_dataset(gridpath("scrip", "ne30pg2", "grid.nc"), datasetpath("scrip", "ne30pg2", "data.nc"))
248+
cs = uxds['RELHUM'].cross_section(start=(-45, -45), end=(45, 45), steps=3)
249+
cs_ux = ux.UxDataArray(cs, uxgrid=uxds.uxgrid)
250+
251+
with pytest.raises(ValueError, match="Coordinate .* must be specified"):
252+
cs_ux.cumulative_integrate()

uxarray/core/dataarray.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from html import escape
5-
from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional
5+
from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional, Sequence
66
from warnings import warn
77

88
import numpy as np
@@ -513,6 +513,40 @@ def integrate(
513513

514514
return uxda
515515

516+
def cumulative_integrate(
517+
self,
518+
coord: Hashable | Sequence[Hashable] | None = None,
519+
datetime_unit: Optional[str] = None,
520+
) -> "UxDataArray":
521+
"""
522+
Integrate cumulatively along the given coordinate using the trapezoidal rule.
523+
524+
Mirrors :py:meth:`xarray.DataArray.cumulative_integrate` while preserving
525+
``uxgrid`` on the result.
526+
527+
Parameters
528+
----------
529+
coord : Hashable or sequence of Hashable
530+
Coordinate(s) used for the integration. This must be provided.
531+
datetime_unit : str, optional
532+
Unit to use when integrating over datetime coordinates.
533+
534+
Returns
535+
-------
536+
UxDataArray
537+
The cumulative integral along the specified coordinate.
538+
"""
539+
if coord is None:
540+
raise ValueError(
541+
"Coordinate ('coord') must be specified for cumulative_integrate."
542+
)
543+
544+
integrated = super().cumulative_integrate(
545+
coord=coord, datetime_unit=datetime_unit
546+
)
547+
548+
return UxDataArray(integrated, uxgrid=self.uxgrid)
549+
516550
def zonal_mean(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
517551
"""Compute non-conservative or conservative averages of a face-centered variable along lines of constant latitude or latitude bands.
518552

0 commit comments

Comments
 (0)