Skip to content

Commit b87024f

Browse files
committed
Change regridding so it does not crash on parametric vertical coordinates
1 parent 8b8093d commit b87024f

5 files changed

Lines changed: 85 additions & 145 deletions

File tree

src/ref_sample_data/data_request/base.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import pandas as pd
55
import xarray as xr
66
from intake_esgf import ESGFCatalog
7+
from loguru import logger
8+
9+
from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear
710

811

912
class DataRequest(Protocol):
@@ -69,9 +72,9 @@ def _deduplicate_group(group: pd.DataFrame) -> pd.DataFrame:
6972
return datasets.groupby("key").apply(_deduplicate_group, include_groups=False).reset_index()
7073

7174

72-
class IntakeESGFDataRequest(DataRequest):
75+
class IntakeESGFMixin:
7376
"""
74-
A data request that fetches datasets from ESGF using intake-esgf.
77+
A mixin that fetches datasets from ESGF using intake-esgf.
7578
"""
7679

7780
facets: dict[str, str | tuple[str, ...]]
@@ -91,3 +94,49 @@ def fetch_datasets(self) -> pd.DataFrame:
9194
merged_df["time_start"] = self.time_span[0]
9295
merged_df["time_end"] = self.time_span[1]
9396
return _deduplicate_datasets(merged_df)
97+
98+
99+
class DecimateMixin:
100+
"""
101+
Mixin for decimating datasets based on their grid type.
102+
"""
103+
104+
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
105+
"""
106+
Downscale the dataset to a smaller size.
107+
108+
Parameters
109+
----------
110+
dataset
111+
The dataset to downscale
112+
113+
Returns
114+
-------
115+
xr.Dataset | None
116+
The downscaled dataset, or None if the dataset contains no data in the requested time span
117+
"""
118+
if "time" in dataset.dims and self.time_span is not None:
119+
result = dataset.sel(time=slice(*self.time_span))
120+
if result.time.size == 0:
121+
# The dataset does not contain data in the requested time range.
122+
return None
123+
else:
124+
result = dataset.copy()
125+
126+
has_latlon = "lat" in result.dims and "lon" in result.dims
127+
has_ij = "i" in result.dims and "j" in result.dims
128+
129+
if has_latlon:
130+
assert len(result.lat.dims) == 1 and len(result.lon.dims) == 1
131+
132+
result = decimate_rectilinear(result)
133+
elif has_ij:
134+
# 2d curvilinear grid (generally ocean variables)
135+
result = decimate_curvilinear(result)
136+
else:
137+
logger.debug(
138+
"No algorithm implemented for this grid type, not spatially decimating dataset:\n{dataset}",
139+
dataset=dataset,
140+
)
141+
142+
return result

src/ref_sample_data/data_request/cmip6.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import pandas as pd
77
import xarray as xr
88

9-
from ref_sample_data.data_request.base import IntakeESGFDataRequest
10-
from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear
9+
from ref_sample_data.data_request.base import DecimateMixin, IntakeESGFMixin
1110

1211

1312
def prefix_to_filename(ds, filename_prefix: str) -> str:
@@ -37,7 +36,7 @@ def prefix_to_filename(ds, filename_prefix: str) -> str:
3736
return filename
3837

3938

40-
class CMIP6Request(IntakeESGFDataRequest):
39+
class CMIP6Request(IntakeESGFMixin, DecimateMixin):
4140
"""
4241
Represents a CMIP6 dataset request
4342
@@ -92,45 +91,6 @@ def __init__(
9291
assert all(key in self.avail_facets for key in self.cmip6_path_items), "Error message"
9392
assert all(key in self.avail_facets for key in self.cmip6_filename_paths), "Error message"
9493

95-
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
96-
"""
97-
Downscale the dataset to a smaller size.
98-
99-
Parameters
100-
----------
101-
dataset
102-
The dataset to downscale
103-
104-
Returns
105-
-------
106-
xr.Dataset
107-
The downscaled dataset
108-
"""
109-
has_latlon = "lat" in dataset.dims and "lon" in dataset.dims
110-
has_ij = "i" in dataset.dims and "j" in dataset.dims
111-
112-
# The AMOC variable `msftmz` has these strange dims and we do not want to decimate
113-
skip_decimate = {"time", "basin", "lev", "lat"}.issubset(dataset.dims)
114-
115-
if has_latlon:
116-
assert len(dataset.lat.dims) == 1 and len(dataset.lon.dims) == 1
117-
118-
result = decimate_rectilinear(dataset)
119-
elif has_ij:
120-
# 2d curvilinear grid (generally ocean variables)
121-
result = decimate_curvilinear(dataset)
122-
elif skip_decimate:
123-
result = dataset
124-
else:
125-
raise ValueError("Cannot decimate this grid: too many dimensions")
126-
127-
if "time" in dataset.dims and self.time_span is not None:
128-
result = result.sel(time=slice(*self.time_span))
129-
if result.time.size == 0:
130-
result = None
131-
132-
return result
133-
13494
def generate_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: pathlib.Path) -> Path:
13595
"""
13696
Create the output filename for the dataset.

src/ref_sample_data/data_request/obs4mips.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
import pandas as pd
77
import xarray as xr
88

9-
from ref_sample_data.data_request.base import IntakeESGFDataRequest
9+
from ref_sample_data.data_request.base import DecimateMixin, IntakeESGFMixin
1010
from ref_sample_data.data_request.cmip6 import prefix_to_filename
11-
from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear
1211

1312

14-
class Obs4MIPsRequest(IntakeESGFDataRequest):
13+
class Obs4MIPsRequest(IntakeESGFMixin, DecimateMixin):
1514
"""
1615
Represents an Obs4MIPs dataset request
1716
"""
@@ -70,40 +69,6 @@ def __init__(
7069
assert all(key in self.avail_facets for key in self.obs4mips_path_items), "Error message"
7170
assert all(key in self.avail_facets for key in self.obs4mips_filename_paths), "Error message"
7271

73-
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
74-
"""
75-
Downscale the dataset to a smaller size.
76-
77-
Parameters
78-
----------
79-
dataset
80-
The dataset to downscale
81-
82-
Returns
83-
-------
84-
xr.Dataset
85-
The downscaled dataset
86-
"""
87-
has_latlon = "lat" in dataset.dims and "lon" in dataset.dims
88-
has_ij = "i" in dataset.dims and "j" in dataset.dims
89-
90-
if has_latlon:
91-
assert len(dataset.lat.dims) == 1 and len(dataset.lon.dims) == 1
92-
93-
result = decimate_rectilinear(dataset)
94-
elif has_ij:
95-
# 2d curvilinear grid (generally ocean variables)
96-
result = decimate_curvilinear(dataset)
97-
else:
98-
raise ValueError("Cannot decimate this grid: too many dimensions")
99-
100-
if "time" in dataset.dims and self.time_span is not None:
101-
result = result.sel(time=slice(*self.time_span))
102-
if result.time.size == 0:
103-
result = None
104-
105-
return result
106-
10772
def generate_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: pathlib.Path) -> Path:
10873
"""
10974
Create the output filename for the dataset.

src/ref_sample_data/data_request/obs4ref.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
import xarray as xr
77
from climate_ref_core.dataset_registry import dataset_registry_manager
88

9-
from ref_sample_data.data_request.base import DataRequest
10-
from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear
9+
from ref_sample_data.data_request.base import DecimateMixin
1110

1211

13-
class Obs4REFRequest(DataRequest):
12+
class Obs4REFRequest(DecimateMixin):
1413
"""
1514
Fetch the unpublished Obs4MIPs datasets from the PMP registry
1615
@@ -42,40 +41,6 @@ def fetch_datasets(self) -> pd.DataFrame:
4241
)
4342
return pd.DataFrame(datasets)
4443

45-
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
46-
"""
47-
Downscale the dataset to a smaller size.
48-
49-
Parameters
50-
----------
51-
dataset
52-
The dataset to downscale
53-
54-
Returns
55-
-------
56-
xr.Dataset
57-
The downscaled dataset
58-
"""
59-
has_latlon = "lat" in dataset.dims and "lon" in dataset.dims
60-
has_ij = "i" in dataset.dims and "j" in dataset.dims
61-
62-
# If less than 10 MB skip decimating
63-
small_file_threshold = 10 * 1024**2
64-
if dataset.nbytes < small_file_threshold:
65-
return dataset
66-
67-
if has_latlon:
68-
assert len(dataset.lat.dims) == 1 and len(dataset.lon.dims) == 1
69-
70-
result = decimate_rectilinear(dataset)
71-
elif has_ij:
72-
# 2d curvilinear grid (generally ocean variables)
73-
result = decimate_curvilinear(dataset)
74-
else:
75-
raise ValueError("Cannot decimate this grid: too many dimensions")
76-
77-
return result
78-
7944
def generate_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: pathlib.Path) -> Path:
8045
"""
8146
Create the output filename for the dataset.

src/ref_sample_data/resample.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
import numpy as np
22
import xarray as xr
33
import xcdat
4+
import xesmf
45

56

67
def _calculate_2d_cell_bounds(
7-
dimension: xr.DataArray,
8+
points: np.ndarray,
89
i: int,
910
j: int,
10-
) -> [float, float, float, float]:
11-
cell_center = dimension[j, i].data
11+
) -> list[float]:
12+
cell_center = points[j, i]
1213
if i == 0:
13-
di = dimension[j, i + 1].data - cell_center
14+
di = points[j, i + 1] - cell_center
1415
else:
15-
di = cell_center - dimension[j, i - 1].data
16+
di = cell_center - points[j, i - 1]
1617
if j == 0:
17-
dj = dimension[j + 1, i].data - cell_center
18+
dj = points[j + 1, i] - cell_center
1819
else:
19-
dj = cell_center - dimension[j - 1, i].data
20+
dj = cell_center - points[j - 1, i]
2021

2122
return np.asarray(
2223
[
@@ -43,22 +44,20 @@ def decimate_rectilinear(dataset: xr.Dataset) -> xr.Dataset:
4344
"""
4445
# Decimate the dataset, but update the bounds
4546
# 10x10 degree grid
46-
regridded_vars = []
47-
48-
for data_var in dataset.data_vars:
49-
# Some datasets don't correctly use data_vars
50-
if "_bnds" in data_var:
51-
continue
52-
output_grid = xcdat.create_uniform_grid(-90, 90, 10, 0, 359, 10)
53-
regridded_vars.append(
54-
dataset.regridder.horizontal(
55-
data_var,
56-
output_grid=output_grid,
57-
tool="xesmf",
58-
method="bilinear",
59-
)
60-
)
61-
return xr.merge(regridded_vars)
47+
output_grid = xcdat.create_uniform_grid(-90, 90, 10, 0, 359, 10)
48+
regrid = xesmf.Regridder(dataset, output_grid, "bilinear", periodic=True)
49+
result = regrid(dataset.copy())
50+
result = result.bounds.add_bounds("Y").bounds.add_bounds("X")
51+
# Restore attributes and add dataarrays that have not been regridded.
52+
for k, v in dataset.data_vars.items():
53+
if k in result:
54+
result[k].attrs = v.attrs
55+
else:
56+
result[k] = v
57+
for k, v in dataset.coords.items():
58+
result[k].attrs = v.attrs
59+
result.attrs = dataset.attrs
60+
return result
6261

6362

6463
def decimate_curvilinear(dataset: xr.Dataset, factor: int = 10) -> xr.Dataset:
@@ -82,13 +81,15 @@ def decimate_curvilinear(dataset: xr.Dataset, factor: int = 10) -> xr.Dataset:
8281
"""
8382
assert factor >= 1
8483
result = dataset.interp(i=dataset.i[::factor]).interp(j=dataset.j[::factor])
85-
result.coords["i"].values[:] = range(len(result.i))
86-
result.coords["j"].values[:] = range(len(result.j))
84+
result.coords["i"].values[:] = np.arange(len(result.i))
85+
result.coords["j"].values[:] = np.arange(len(result.j))
8786

8887
# Update the bounds of the cells
88+
latitude_points = result.latitude.values
89+
longitude_points = result.longitude.values
8990
for j in result.j:
9091
for i in result.i:
91-
result.vertices_latitude[j, i] = _calculate_2d_cell_bounds(result.latitude, i, j)
92-
result.vertices_longitude[j, i] = _calculate_2d_cell_bounds(result.longitude, i, j)
92+
result.vertices_latitude[j, i] = _calculate_2d_cell_bounds(latitude_points, i, j)
93+
result.vertices_longitude[j, i] = _calculate_2d_cell_bounds(longitude_points, i, j)
9394

9495
return result

0 commit comments

Comments
 (0)