Skip to content

Commit 47b0159

Browse files
authored
Merge pull request #23 from Climate-REF/refactor-intake-based
Introduce a base class for handling intake-esgf related queries
2 parents 2f7ee7a + 7bdfbd7 commit 47b0159

File tree

7 files changed

+52
-47
lines changed

7 files changed

+52
-47
lines changed

.github/workflows/ci.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ on:
66
branches: [main]
77
tags: ['v*']
88

9+
concurrency:
10+
group: ${{ github.workflow }}-${{ github.ref }}
11+
cancel-in-progress: true
12+
913
jobs:
1014
pre-commit:
1115
if: ${{ !github.event.pull_request.draft }}

changelog/23.improvement.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Refactored intake-esgf-based data requests to have a common base class (`ref_sample_data.data_request.base.IntakeESGFDataRequest`)

scripts/fetch_test_data.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,13 @@
66
import pooch
77
import typer
88
import xarray as xr
9-
from intake_esgf import ESGFCatalog
109

1110
from ref_sample_data import CMIP6Request, DataRequest, Obs4MIPsRequest
1211

1312
OUTPUT_PATH = Path("data")
1413
app = typer.Typer()
1514

1615

17-
def fetch_datasets(request: DataRequest, quiet: bool) -> pd.DataFrame:
18-
"""
19-
Fetch the datasets from ESGF.
20-
21-
Parameters
22-
----------
23-
request
24-
The request object
25-
quiet
26-
Whether to suppress progress messages from intake-esgf
27-
28-
Returns
29-
-------
30-
Dataframe that contains metadata and paths to the fetched datasets
31-
"""
32-
cat = ESGFCatalog()
33-
34-
cat.search(**request.facets)
35-
if request.remove_ensembles:
36-
cat.remove_ensembles()
37-
38-
path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=quiet)
39-
merged_df = cat.df.merge(pd.Series(path_dict, name="files"), left_on="key", right_index=True)
40-
if request.time_span:
41-
merged_df["time_start"] = request.time_span[0]
42-
merged_df["time_end"] = request.time_span[1]
43-
return merged_df
44-
45-
4616
def deduplicate_datasets(datasets: pd.DataFrame) -> pd.DataFrame:
4717
"""
4818
Deduplicate a dataset collection.
@@ -90,15 +60,15 @@ def process_sample_data_request(
9060
quiet
9161
Whether to suppress progress messages
9262
"""
93-
datasets = fetch_datasets(request, quiet)
63+
datasets = request.fetch_datasets()
9464
datasets = deduplicate_datasets(datasets)
9565

9666
for _, dataset in datasets.iterrows():
9767
for ds_filename in dataset["files"]:
9868
ds_orig = xr.open_dataset(ds_filename)
9969

10070
if decimate:
101-
ds_decimated = request.decimate_dataset(ds_orig, request.time_span)
71+
ds_decimated = request.decimate_dataset(ds_orig)
10272
else:
10373
ds_decimated = ds_orig
10474
if ds_decimated is None:

src/ref_sample_data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
from .data_request.cmip6 import CMIP6Request
1212
from .data_request.obs4mips import Obs4MIPsRequest
1313

14-
__all__ = ["DataRequest", "CMIP6Request", "Obs4MIPsRequest"]
14+
__all__ = ["CMIP6Request", "DataRequest", "Obs4MIPsRequest"]

src/ref_sample_data/data_request/base.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pandas as pd
55
import xarray as xr
6+
from intake_esgf import ESGFCatalog
67

78

89
class DataRequest(Protocol):
@@ -14,11 +15,15 @@ class DataRequest(Protocol):
1415
differently to generate the sample data.
1516
"""
1617

17-
facets: dict[str, str | tuple[str, ...]]
18-
remove_ensembles: bool
19-
time_span: tuple[str, str]
18+
def fetch_datasets(self) -> pd.DataFrame:
19+
"""
20+
Fetch the datasets from the source
21+
22+
Returns a dataframe of the metadata and paths to the fetched datasets.
23+
"""
24+
...
2025

21-
def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None:
26+
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
2227
"""Downscale the dataset to a smaller size."""
2328
...
2429

@@ -27,3 +32,28 @@ def generate_filename(
2732
) -> pathlib.Path:
2833
"""Create the output filename for the dataset."""
2934
...
35+
36+
37+
class IntakeESGFDataRequest(DataRequest):
38+
"""
39+
A data request that fetches datasets from ESGF using intake-esgf.
40+
"""
41+
42+
facets: dict[str, str | tuple[str, ...]]
43+
remove_ensembles: bool
44+
time_span: tuple[str, str]
45+
46+
def fetch_datasets(self) -> pd.DataFrame:
47+
"""Fetch the datasets from the ESGF."""
48+
cat = ESGFCatalog()
49+
50+
cat.search(**self.facets)
51+
if self.remove_ensembles:
52+
cat.remove_ensembles()
53+
54+
path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=True)
55+
merged_df = cat.df.merge(pd.Series(path_dict, name="files"), left_on="key", right_index=True)
56+
if self.time_span:
57+
merged_df["time_start"] = self.time_span[0]
58+
merged_df["time_end"] = self.time_span[1]
59+
return merged_df

src/ref_sample_data/data_request/cmip6.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77
import xarray as xr
88

9-
from ref_sample_data.data_request.base import DataRequest
9+
from ref_sample_data.data_request.base import IntakeESGFDataRequest
1010
from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear
1111

1212

@@ -37,7 +37,7 @@ def prefix_to_filename(ds, filename_prefix: str) -> str:
3737
return filename
3838

3939

40-
class CMIP6Request(DataRequest):
40+
class CMIP6Request(IntakeESGFDataRequest):
4141
"""
4242
Represents a CMIP6 dataset request
4343
@@ -86,7 +86,7 @@ def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tu
8686
assert all(key in self.avail_facets for key in self.cmip6_path_items), "Error message"
8787
assert all(key in self.avail_facets for key in self.cmip6_filename_paths), "Error message"
8888

89-
def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None:
89+
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
9090
"""
9191
Downscale the dataset to a smaller size.
9292
@@ -115,8 +115,8 @@ def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | Non
115115
else:
116116
raise ValueError("Cannot decimate this grid: too many dimensions")
117117

118-
if "time" in dataset.dims and time_span is not None:
119-
result = result.sel(time=slice(*time_span))
118+
if "time" in dataset.dims and self.time_span is not None:
119+
result = result.sel(time=slice(*self.time_span))
120120
if result.time.size == 0:
121121
result = None
122122

src/ref_sample_data/data_request/obs4mips.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import pandas as pd
77
import xarray as xr
88

9-
from ref_sample_data.data_request.base import DataRequest
9+
from ref_sample_data.data_request.base import IntakeESGFDataRequest
1010
from ref_sample_data.data_request.cmip6 import prefix_to_filename
1111
from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear
1212

1313

14-
class Obs4MIPsRequest(DataRequest):
14+
class Obs4MIPsRequest(IntakeESGFDataRequest):
1515
"""
1616
Represents a Obs4MIPs dataset request
1717
"""
@@ -65,7 +65,7 @@ def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tu
6565
assert all(key in self.avail_facets for key in self.obs4mips_path_items), "Error message"
6666
assert all(key in self.avail_facets for key in self.obs4mips_filename_paths), "Error message"
6767

68-
def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None:
68+
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
6969
"""
7070
Downscale the dataset to a smaller size.
7171
@@ -94,8 +94,8 @@ def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | Non
9494
else:
9595
raise ValueError("Cannot decimate this grid: too many dimensions")
9696

97-
if "time" in dataset.dims and time_span is not None:
98-
result = result.sel(time=slice(*time_span))
97+
if "time" in dataset.dims and self.time_span is not None:
98+
result = result.sel(time=slice(*self.time_span))
9999
if result.time.size == 0:
100100
result = None
101101

0 commit comments

Comments (0)