Skip to content

Commit 201246f

Browse files
seismanweiji14
andauthored
GMTDataArrayAccessor: Support applying grid operations on the current xarray.DataArray object (#3854)
Co-authored-by: Wei Ji <[email protected]>
1 parent 12b789b commit 201246f

File tree

2 files changed

+108
-1
lines changed

2 files changed

+108
-1
lines changed

pygmt/tests/test_xarray_accessor.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,19 @@
1414
from pygmt.datasets import load_earth_relief
1515
from pygmt.enums import GridRegistration, GridType
1616
from pygmt.exceptions import GMTValueError
17+
from pygmt.helpers.testing import load_static_earth_relief
1718

1819
_HAS_NETCDF4 = bool(importlib.util.find_spec("netCDF4"))
1920

2021

22+
@pytest.fixture(scope="module", name="grid")
23+
def fixture_grid():
24+
"""
25+
Load the grid data from the sample earth_relief file.
26+
"""
27+
return load_static_earth_relief()
28+
29+
2130
def test_xarray_accessor_gridline_cartesian():
2231
"""
2332
Check that the accessor returns the correct registration and gtype values for a
@@ -169,3 +178,44 @@ def test_xarray_accessor_tiled_grid_slice_and_add():
169178
added_grid.gmt.gtype = GridType.GEOGRAPHIC
170179
assert added_grid.gmt.registration is GridRegistration.PIXEL
171180
assert added_grid.gmt.gtype is GridType.GEOGRAPHIC
181+
182+
183+
def test_xarray_accessor_clip(grid):
184+
"""
185+
Check that the accessor has the clip method and that it works correctly.
186+
187+
This test is adapted from the `test_grdclip_no_outgrid` test.
188+
"""
189+
clipped_grid = grid.gmt.clip(
190+
below=[550, -1000], above=[700, 1000], region=[-53, -49, -19, -16]
191+
)
192+
193+
expected_clipped_grid = xr.DataArray(
194+
data=[
195+
[1000.0, 570.5, -1000.0, -1000.0],
196+
[1000.0, 1000.0, 571.5, 638.5],
197+
[555.5, 556.0, 580.0, 1000.0],
198+
],
199+
coords={"lon": [-52.5, -51.5, -50.5, -49.5], "lat": [-18.5, -17.5, -16.5]},
200+
dims=["lat", "lon"],
201+
)
202+
xr.testing.assert_allclose(a=clipped_grid, b=expected_clipped_grid)
203+
204+
205+
def test_xarray_accessor_histeq(grid):
206+
"""
207+
Check that the accessor has the histeq method and that it works correctly.
208+
209+
This test is adapted from the `test_equalize_grid_no_outgrid` test.
210+
"""
211+
equalized_grid = grid.gmt.histeq(divisions=2, region=[-52, -48, -22, -18])
212+
213+
expected_equalized_grid = xr.DataArray(
214+
data=[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1], [1, 1, 1, 1]],
215+
coords={
216+
"lon": [-51.5, -50.5, -49.5, -48.5],
217+
"lat": [-21.5, -20.5, -19.5, -18.5],
218+
},
219+
dims=["lat", "lon"],
220+
)
221+
xr.testing.assert_allclose(a=equalized_grid, b=expected_equalized_grid)

pygmt/xarray/accessor.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,25 @@
33
"""
44

55
import contextlib
6+
import functools
67
from pathlib import Path
78

89
import xarray as xr
910
from pygmt.enums import GridRegistration, GridType
1011
from pygmt.exceptions import GMTValueError
11-
from pygmt.src.grdinfo import grdinfo
12+
from pygmt.src import (
13+
dimfilter,
14+
grdclip,
15+
grdcut,
16+
grdfill,
17+
grdfilter,
18+
grdgradient,
19+
grdhisteq,
20+
grdinfo,
21+
grdproject,
22+
grdsample,
23+
grdtrack,
24+
)
1225

1326

1427
@xr.register_dataarray_accessor("gmt")
@@ -23,6 +36,11 @@ class GMTDataArrayAccessor:
2336
- ``registration``: Grid registration type :class:`pygmt.enums.GridRegistration`.
2437
- ``gtype``: Grid coordinate system type :class:`pygmt.enums.GridType`.
2538
39+
The *gmt* accessor also provides a set of grid-operation methods that enables
40+
applying GMT's grid processing functionalities directly to the current
41+
:class:`xarray.DataArray` object. See the summary table below for the list of
42+
available methods.
43+
2644
Notes
2745
-----
2846
When accessed the first time, the *gmt* accessor will first be initialized to the
@@ -150,6 +168,19 @@ class GMTDataArrayAccessor:
150168
>>> zval.gmt.gtype = GridType.GEOGRAPHIC
151169
>>> zval.gmt.registration, zval.gmt.gtype
152170
(<GridRegistration.GRIDLINE: 0>, <GridType.GEOGRAPHIC: 1>)
171+
172+
Instead of calling a grid-processing function and passing the
173+
:class:`xarray.DataArray` object as an input, you can call the corresponding method
174+
directly on the object. For example, the following two are equivalent:
175+
176+
>>> from pygmt.datasets import load_earth_relief
177+
>>> grid = load_earth_relief(resolution="30m", region=[10, 30, 15, 25])
178+
>>> # Create a new grid from an input grid. Set all values below 1,000 to 0 and all
179+
>>> # values above 1,500 to 10,000.
180+
>>> # Option 1:
181+
>>> new_grid = pygmt.grdclip(grid=grid, below=[1000, 0], above=[1500, 10000])
182+
>>> # Option 2:
183+
>>> new_grid = grid.gmt.clip(below=[1000, 0], above=[1500, 10000])
153184
"""
154185

155186
def __init__(self, xarray_obj: xr.DataArray):
@@ -200,3 +231,29 @@ def gtype(self, value: GridType | int):
200231
value, description="grid coordinate system type", choices=GridType
201232
)
202233
self._gtype = GridType(value)
234+
235+
@staticmethod
236+
def _make_method(func):
237+
"""
238+
Create a wrapper method for PyGMT grid-processing methods.
239+
240+
The :class:`xarray.DataArray` object is passed as the first argument.
241+
"""
242+
243+
@functools.wraps(func)
244+
def wrapper(self, *args, **kwargs):
245+
return func(self._obj, *args, **kwargs)
246+
247+
return wrapper
248+
249+
# Accessor methods for grid operations.
250+
clip = _make_method(grdclip)
251+
cut = _make_method(grdcut)
252+
dimfilter = _make_method(dimfilter)
253+
histeq = _make_method(grdhisteq.equalize_grid)
254+
fill = _make_method(grdfill)
255+
filter = _make_method(grdfilter)
256+
gradient = _make_method(grdgradient)
257+
project = _make_method(grdproject)
258+
sample = _make_method(grdsample)
259+
track = _make_method(grdtrack)

0 commit comments

Comments
 (0)