diff --git a/pygmt/tests/test_xarray_backend.py b/pygmt/tests/test_xarray_backend.py index a9244118c97..d9f7d72c80a 100644 --- a/pygmt/tests/test_xarray_backend.py +++ b/pygmt/tests/test_xarray_backend.py @@ -9,6 +9,7 @@ import numpy.testing as npt import pytest import xarray as xr +from pygmt.clib import Session from pygmt.enums import GridRegistration, GridType from pygmt.exceptions import GMTValueError from pygmt.helpers import GMTTempFile @@ -38,6 +39,32 @@ def test_xarray_backend_load_dataarray(): dataarray.to_netcdf(tmpfile.name) +def test_xarray_backend_load_dataarray_temp_nc_grid(): + """ + Check that xarray.load_dataarray works to read a temporary netCDF grid, and ensure + that GMTDataArrayAccessor information is retained after original file is deleted. + + This is a regression test for + https://github.com/GenericMappingTools/pygmt/issues/4005 + """ + + with Session() as lib: + with GMTTempFile(suffix=".nc") as tmpfile: + args = [ + "@earth_relief_01d_g", + "-T", # change from gridline to pixel registration + f"-G{tmpfile.name}", + ] + lib.call_module(module="grdedit", args=args) + dataarray = xr.load_dataarray( + tmpfile.name, engine="gmt", raster_kind="grid" + ) + + # Ensure GMTDataArrayAccessor info is preserved after tempfile is deleted + assert dataarray.gmt.registration is GridRegistration.PIXEL + assert dataarray.gmt.gtype is GridType.CARTESIAN + + def test_xarray_backend_gmt_open_nc_grid(): """ Ensure that passing engine='gmt' to xarray.open_dataarray works to open a netCDF diff --git a/pygmt/xarray/accessor.py b/pygmt/xarray/accessor.py index 05e162ed78f..cf1d6d42be3 100644 --- a/pygmt/xarray/accessor.py +++ b/pygmt/xarray/accessor.py @@ -25,6 +25,7 @@ @xr.register_dataarray_accessor("gmt") +@xr.register_dataset_accessor("gmt") class GMTDataArrayAccessor: """ GMT accessor for :class:`xarray.DataArray`. diff --git a/pygmt/xarray/backend.py b/pygmt/xarray/backend.py index 37500932b57..ceeee93ff10 100644 --- a/pygmt/xarray/backend.py +++ b/pygmt/xarray/backend.py @@ -150,4 +150,6 @@ def open_dataset( # type: ignore[override] raster.encoding["source"] = ( sorted(source)[0] if isinstance(source, list) else source ) - return raster.to_dataset() + raster_dataset: xr.Dataset = raster.to_dataset() + _ = raster_dataset.gmt # Load GMTDataArray accessor information + return raster_dataset