Skip to content

Commit 9ea1802

Browse files
committed
chore: Parallelise the decimation process while skipping files smaller than 10MB
1 parent 04759d1 commit 9ea1802

File tree

2 files changed

+77
-57
lines changed

2 files changed

+77
-57
lines changed

scripts/fetch_test_data.py

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pathlib
22
from pathlib import Path
33

4+
import joblib
45
import pandas as pd
56
import pooch
67
import typer
@@ -39,6 +40,70 @@ def _get_match(dataset: pd.DataFrame, source_type: str, key: str) -> pd.Series |
3940
return matches.iloc[0]
4041

4142

43+
def _process_dataset(
    processed_datasets: pd.DataFrame,
    dataset: pd.Series,
    request: DataRequest,
    decimate: bool,
    output_directory: Path,
) -> list[dict[str, str]]:
    """
    Process a single dataset, optionally decimating each file, and write the results.

    Designed to run as an independent joblib task: it returns a list of zero or
    one registry items instead of appending to shared state.

    Parameters
    ----------
    processed_datasets
        Registry of datasets that have already been processed.
    dataset
        Row describing the dataset to process (expects ``files``, ``time_start``
        and ``time_end`` entries, plus a ``key`` attribute).
    request
        The data request; its ``time_span`` may be widened in place to the
        superset of the existing and requested spans before regeneration.
    decimate
        If True, each file is decimated via ``request.decimate_dataset``.
    output_directory
        Directory under which the output NetCDF files are written.

    Returns
    -------
        ``[]`` if the dataset already covers the requested time span and was
        skipped, otherwise a single-item list containing the registry entry
        for the files that were written.
    """
    match = _get_match(processed_datasets, request.source_type, dataset.key)

    # Check if the dataset has already been processed and can be skipped
    if match is not None and request.time_span is not None:
        # Dataset has already been processed and a time span was specified.
        # Check if the dataset already covers the requested time span.
        if int(match.time_start) <= int(dataset["time_start"]) and int(match.time_end) >= int(
            dataset["time_end"]
        ):
            # Already have a dataset that covers the requested time span
            logger.info(f"Skipping regenerating {dataset.key} as it already covers the requested time span")
            return []

        # Update the request to match the superset of the time spans
        time_start = dataset["time_start"] if dataset["time_start"] < match.time_start else match.time_start
        time_end = dataset["time_end"] if dataset["time_end"] > match.time_end else match.time_end
        request.time_span = (str(time_start), str(time_end))

        logger.info(f"Regenerating dataset with new time span: {dataset.key} {request.time_span}")
        # Remove stale outputs so they are regenerated with the widened span
        for file in match.files:
            file_path = pathlib.Path(file)
            if file_path.exists():
                logger.info(f"Removing existing file: {file}")
                file_path.unlink()

    output_filenames = []
    for ds_filename in dataset["files"]:
        try:
            # Context manager ensures the source file handle is closed whether
            # decimation/writing succeeds or fails (previously the handle leaked).
            with xr.open_dataset(ds_filename) as ds_orig:
                ds_decimated = request.decimate_dataset(ds_orig) if decimate else ds_orig
                if ds_decimated is None:
                    continue

                output_filename = output_directory / request.generate_filename(dataset, ds_decimated, ds_filename)
                output_filename.parent.mkdir(parents=True, exist_ok=True)
                ds_decimated.to_netcdf(output_filename)
                output_filenames.append(output_filename)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
            # not intercepted; the failure is still logged and re-raised.
            logger.exception(f"Failed to process dataset {ds_filename}")
            raise

    item = {
        "source_type": request.source_type,
        "key": dataset.key,
        "files": output_filenames,
    }
    if request.time_span is not None:
        item["time_start"] = request.time_span[0]
        item["time_end"] = request.time_span[1]

    return [item]
105+
106+
42107
def process_sample_data_request(
43108
processed_datasets: pd.DataFrame,
44109
request: DataRequest,
@@ -67,64 +132,14 @@ def process_sample_data_request(
67132
"""
68133
logger.info(f"Resolving request: {request.id}")
69134
datasets = request.fetch_datasets()
70-
items = []
71-
72-
for _, dataset in datasets.iterrows():
73-
match = _get_match(processed_datasets, request.source_type, dataset.key)
74-
75-
# Check if the dataset has already been processed and can be skipped
76-
if match is not None and request.time_span is not None:
77-
# Dataset has already been processed and a time span was specified
78-
# Check if the dataset already covers the requested time span
79-
if int(match.time_start) <= int(dataset["time_start"]) and int(match.time_end) >= int(
80-
dataset["time_end"]
81-
):
82-
# Already have a dataset that covers the requested time span
83-
logger.info(
84-
f"Skipping regenerating {dataset.key} as it already covers the requested time span"
85-
)
86-
continue
87-
88-
# Update the request to match the superset of the time spans
89-
time_start = (
90-
dataset["time_start"] if dataset["time_start"] < match.time_start else match.time_start
91-
)
92-
time_end = dataset["time_end"] if dataset["time_end"] > match.time_end else match.time_end
93-
request.time_span = (str(time_start), str(time_end))
94-
95-
logger.info(f"Regenerating dataset with new time span: {dataset.key} {request.time_span}")
96-
for file in match.files:
97-
file_path = pathlib.Path(file)
98-
if file_path.exists():
99-
logger.info(f"Removing existing file: {file}")
100-
file_path.unlink()
101-
102-
output_filenames = []
103-
for ds_filename in dataset["files"]:
104-
ds_orig = xr.open_dataset(ds_filename)
105-
106-
if decimate:
107-
ds_decimated = request.decimate_dataset(ds_orig)
108-
else:
109-
ds_decimated = ds_orig
110-
if ds_decimated is None:
111-
continue
112-
113-
output_filename = output_directory / request.generate_filename(dataset, ds_decimated, ds_filename)
114-
output_filename.parent.mkdir(parents=True, exist_ok=True)
115-
ds_decimated.to_netcdf(output_filename)
116-
output_filenames.append(output_filename)
117-
118-
item = {
119-
"source_type": request.source_type,
120-
"key": dataset.key,
121-
"files": output_filenames,
122-
}
123-
if request.time_span is not None:
124-
item["time_start"] = request.time_span[0]
125-
item["time_end"] = request.time_span[1]
126135

127-
items.append(item)
136+
# Process all the datasets in parallel
137+
items = joblib.Parallel(n_jobs=-1)(
138+
joblib.delayed(_process_dataset)(processed_datasets, dataset, request, decimate, output_directory)
139+
for _, dataset in datasets.iterrows()
140+
)
141+
# Flatten the list of lists
142+
items = [item for sublist in items for item in sublist]
128143

129144
# Regenerate the registry.txt file
130145
pooch.make_registry(str(OUTPUT_PATH), "registry.txt")

src/ref_sample_data/data_request/obs4ref.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
5959
has_latlon = "lat" in dataset.dims and "lon" in dataset.dims
6060
has_ij = "i" in dataset.dims and "j" in dataset.dims
6161

62+
# If less than 10 MB skip decimating
63+
small_file_threshold = 10 * 1024**2
64+
if dataset.nbytes < small_file_threshold:
65+
return dataset
66+
6267
if has_latlon:
6368
assert len(dataset.lat.dims) == 1 and len(dataset.lon.dims) == 1
6469

0 commit comments

Comments
 (0)