Skip to content

Commit 1b564cd

Browse files
authored
Merge pull request #71 from ecmwf/feat/gribjump
Feat/gribjump
2 parents 5e8c034 + 4d171bb commit 1b564cd

File tree

9 files changed

+481
-84
lines changed

9 files changed

+481
-84
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ pip install -e .[dev]
5050
pre-commit install
5151
```
5252

53+
HAT provides **experimental** support for earthkit-data's [gribjump source](https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#gribjump).
54+
To install the gribjump extras for testing and experimentation, run:
55+
```bash
56+
pip install hydro-analysis-toolkit[gribjump]
57+
```
58+
59+
> [!NOTE]
60+
> The gribjump feature is experimental. It is not recommended for production use and may change or break in future releases.
61+
> Instructions for building gribjump are available in the [GribJump repository](https://github.com/ecmwf/gribjump/). Experimental
62+
> wheels of `gribjumplib` can also be found [on PyPI](https://pypi.org/project/gribjumplib/).
63+
64+
5365
## Licence
5466

5567
```

hat/compute_hydrostats/stat_calc.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
1-
import earthkit.data as ekd
2-
from earthkit.hydro._readers import find_main_var
1+
from hat.core import load_da
32
import numpy as np
43
import xarray as xr
54
from hat.compute_hydrostats import stats
65

76

8-
def load_da(ds_config):
9-
ds = ekd.from_source(*ds_config["source"]).to_xarray()
10-
var_name = find_main_var(ds, 2)
11-
da = ds[var_name]
12-
return da
13-
14-
157
def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords):
168
sim_station_colname = sim_coords.get("s", "station")
179
obs_station_colname = obs_coords.get("s", "station")
@@ -35,9 +27,9 @@ def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords):
3527

3628
def stat_calc(config):
3729
sim_config = config["sim"]
38-
sim_da = load_da(config["sim"])
30+
sim_da, _ = load_da(sim_config, 2)
3931
obs_config = config["obs"]
40-
obs_da = load_da(obs_config)
32+
obs_da, _ = load_da(obs_config, 2)
4133
new_coords = config["output"]["coords"]
4234
sim_da, obs_da = find_valid_subset(sim_da, obs_da, sim_config["coords"], obs_config["coords"], new_coords)
4335
stat_dict = {}

hat/core.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import earthkit.data as ekd
2+
from earthkit.hydro._readers import find_main_var
3+
4+
5+
def load_da(ds_config, n_dims):
6+
src_name = list(ds_config["source"].keys())[0]
7+
source = ekd.from_source(src_name, **ds_config["source"][src_name])
8+
ds = source.to_xarray(**ds_config.get("to_xarray_options", {}))
9+
var_name = find_main_var(ds, n_dims)
10+
da = ds[var_name]
11+
return da, var_name

hat/extract_timeseries/extractor.py

Lines changed: 147 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2,127 +2,209 @@
22
import pandas as pd
33
import xarray as xr
44
import numpy as np
5-
import earthkit.data as ekd
6-
from earthkit.hydro._readers import find_main_var
5+
from typing import Any
6+
from hat.core import load_da
77

88
from hat import _LOGGER as logger
99

1010

1111
def process_grid_inputs(grid_config):
12-
src_name = list(grid_config["source"].keys())[0]
13-
logger.info(f"Processing grid inputs from source: {src_name}")
14-
logger.debug(f"Grid config: {grid_config['source'][src_name]}")
15-
ds = ekd.from_source(src_name, **grid_config["source"][src_name]).to_xarray(
16-
**grid_config.get("to_xarray_options", {})
17-
)
18-
var_name = find_main_var(ds, 3)
19-
da = ds[var_name]
12+
da, var_name = load_da(grid_config, 3)
2013
logger.info(f"Xarray created from source:\n{da}\n")
21-
gridx_colname = grid_config.get("coord_x", "lat")
22-
gridy_colname = grid_config.get("coord_y", "lon")
23-
da = da.sortby([gridx_colname, gridy_colname])
24-
shape = da[gridx_colname].shape[0], da[gridy_colname].shape[0]
25-
return da, var_name, gridx_colname, gridy_colname, shape
14+
coord_config = grid_config.get("coords", {})
15+
x_dim = coord_config.get("x", "lat")
16+
y_dim = coord_config.get("y", "lon")
17+
da = da.sortby([x_dim, y_dim])
18+
shape = da[x_dim].shape[0], da[y_dim].shape[0]
19+
return da, var_name, x_dim, y_dim, shape
2620

2721

28-
def construct_mask(indx, indy, shape):
22+
def construct_mask(x_indices, y_indices, shape):
2923
mask = np.zeros(shape, dtype=bool)
30-
mask[indx, indy] = True
24+
mask[x_indices, y_indices] = True
3125

32-
flat_indices = np.ravel_multi_index((indx, indy), shape)
33-
_, inverse = np.unique(flat_indices, return_inverse=True)
34-
return mask, inverse
26+
flat_indices = np.ravel_multi_index((x_indices, y_indices), shape)
27+
_, duplication_indexes = np.unique(flat_indices, return_inverse=True)
28+
return mask, duplication_indexes
3529

3630

37-
def create_mask_from_index(index_config, df, shape):
38-
logger.info(f"Creating mask {shape} from index: {index_config}")
31+
def create_mask_from_index(df, shape):
32+
logger.info(f"Creating mask {shape} from index")
3933
logger.debug(f"DataFrame columns: {df.columns.tolist()}")
40-
indx_colname = index_config.get("x", "opt_x_index")
41-
indy_colname = index_config.get("y", "opt_y_index")
42-
indx, indy = df[indx_colname].values, df[indy_colname].values
43-
mask, duplication_indexes = construct_mask(indx, indy, shape)
34+
x_indices = df["x_index"].values
35+
y_indices = df["y_index"].values
36+
if np.any(x_indices < 0) or np.any(x_indices >= shape[0]) or np.any(y_indices < 0) or np.any(y_indices >= shape[1]):
37+
raise ValueError(
38+
f"Station indices out of grid bounds. Grid shape={shape}, "
39+
f"x_index range=({int(x_indices.min())},{int(x_indices.max())}), "
40+
f"y_index range=({int(y_indices.min())},{int(y_indices.max())})"
41+
)
42+
mask, duplication_indexes = construct_mask(x_indices, y_indices, shape)
4443
return mask, duplication_indexes
4544

4645

47-
def create_mask_from_coords(coords_config, df, gridx, gridy, shape):
48-
logger.info(f"Creating mask {shape} from coordinates: {coords_config}")
46+
def create_mask_from_coords(df, gridx, gridy, shape):
47+
logger.info(f"Creating mask {shape} from coordinates")
4948
logger.debug(f"DataFrame columns: {df.columns.tolist()}")
50-
x_colname = coords_config.get("x", "opt_x_coord")
51-
y_colname = coords_config.get("y", "opt_y_coord")
52-
xs = df[x_colname].values
53-
ys = df[y_colname].values
49+
station_x = df["x_coord"].values
50+
station_y = df["y_coord"].values
5451

55-
diffx = np.abs(xs[:, np.newaxis] - gridx)
56-
indx = np.argmin(diffx, axis=1)
57-
diffy = np.abs(ys[:, np.newaxis] - gridy)
58-
indy = np.argmin(diffy, axis=1)
52+
x_distances = np.abs(station_x[:, np.newaxis] - gridx)
53+
x_indices = np.argmin(x_distances, axis=1)
54+
y_distances = np.abs(station_y[:, np.newaxis] - gridy)
55+
y_indices = np.argmin(y_distances, axis=1)
5956

60-
mask, duplication_indexes = construct_mask(indx, indy, shape)
57+
mask, duplication_indexes = construct_mask(x_indices, y_indices, shape)
6158
return mask, duplication_indexes
6259

6360

64-
def process_inputs(station_config, grid_config):
61+
def parse_stations(station_config: dict[str, Any]) -> pd.DataFrame:
62+
"""Read, filter, and normalize station DataFrame to canonical column names."""
6563
logger.debug(f"Reading station file, {station_config}")
64+
if "name" not in station_config:
65+
raise ValueError("Station config must include a 'name' key mapping to the station column")
6666
df = pd.read_csv(station_config["file"])
6767
filters = station_config.get("filter")
6868
if filters is not None:
6969
logger.debug(f"Applying filters: {filters} to station DataFrame")
7070
df = df.query(filters)
71-
station_names = df[station_config["name"]].values
7271

73-
index_config = station_config.get("index", None)
74-
coords_config = station_config.get("coords", None)
72+
if len(df) == 0:
73+
raise ValueError("No stations found. Check station file or filter.")
74+
75+
has_index = "index" in station_config
76+
has_coords = "coords" in station_config
77+
has_index_1d = "index_1d" in station_config
78+
79+
if not has_index_1d:
80+
if has_index and has_coords:
81+
raise ValueError("Station config must use either 'index' or 'coords', not both.")
82+
if not has_index and not has_coords:
83+
raise ValueError("Station config must provide either 'index' or 'coords' for station mapping.")
7584

76-
if index_config is not None and coords_config is not None:
77-
raise ValueError("Use either index or coords, not both.")
85+
renames = {}
86+
renames[station_config["name"]] = "station_name"
7887

79-
da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config)
88+
if has_index:
89+
index_config = station_config["index"]
90+
x_col = index_config.get("x", "opt_x_index")
91+
y_col = index_config.get("y", "opt_y_index")
92+
renames[x_col] = "x_index"
93+
renames[y_col] = "y_index"
8094

81-
if index_config is not None:
82-
mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
83-
elif coords_config is not None:
84-
mask, duplication_indexes = create_mask_from_coords(
85-
coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape
95+
if has_coords:
96+
coords_config = station_config["coords"]
97+
x_col = coords_config.get("x", "opt_x_coord")
98+
y_col = coords_config.get("y", "opt_y_coord")
99+
renames[x_col] = "x_coord"
100+
renames[y_col] = "y_coord"
101+
102+
if has_index_1d:
103+
renames[station_config["index_1d"]] = "index_1d"
104+
105+
df_renamed = df.rename(columns=renames)
106+
107+
if has_index and ("x_index" not in df_renamed.columns or "y_index" not in df_renamed.columns):
108+
raise ValueError(
109+
"Station file missing required index columns. Expected columns to map to 'x_index' and 'y_index'."
110+
)
111+
if has_coords and ("x_coord" not in df_renamed.columns or "y_coord" not in df_renamed.columns):
112+
raise ValueError(
113+
"Station file missing required coordinate columns. Expected columns to map to 'x_coord' and 'y_coord'."
86114
)
115+
if has_index_1d and "index_1d" not in df_renamed.columns:
116+
raise ValueError("Station file missing required 'index_1d' column.")
117+
118+
return df_renamed
119+
120+
121+
def _process_gribjump(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
122+
if "index_1d" not in df.columns:
123+
raise ValueError("Gribjump source requires 'index_1d' in station config.")
124+
125+
station_names = df["station_name"].values
126+
unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True) # type: ignore[call-overload]
127+
128+
# Converting indices to ranges is currently faster than using indices
129+
# directly. This is a problem in the earthkit-data gribjump source and will
130+
# be fixed there.
131+
ranges = [(i, i + 1) for i in unique_indices]
132+
133+
gribjump_config = {
134+
"source": {
135+
"gribjump": {
136+
**grid_config["source"]["gribjump"],
137+
"ranges": ranges,
138+
# fetch_coords_from_fdb is currently very slow. Needs fix in
139+
# earthkit-data gribjump source.
140+
# "fetch_coords_from_fdb": True,
141+
}
142+
},
143+
"to_xarray_options": grid_config.get("to_xarray_options", {}),
144+
}
145+
146+
masked_da, var_name = load_da(gribjump_config, 2)
147+
148+
ds = xr.Dataset({var_name: masked_da})
149+
ds = ds.isel(index=duplication_indexes)
150+
ds = ds.rename({"index": "station"})
151+
ds["station"] = station_names
152+
return ds
153+
154+
155+
def _process_regular(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
156+
station_names = df["station_name"].values
157+
da, var_name, x_dim, y_dim, shape = process_grid_inputs(grid_config)
158+
159+
use_index = "x_index" in df.columns and "y_index" in df.columns
160+
161+
if use_index:
162+
mask, duplication_indexes = create_mask_from_index(df, shape)
87163
else:
88-
# default to index approach
89-
mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
164+
mask, duplication_indexes = create_mask_from_coords(df, da[x_dim].values, da[y_dim].values, shape)
165+
166+
logger.info("Extracting timeseries at selected stations")
167+
masked_da = apply_mask(da, mask, x_dim, y_dim)
168+
169+
ds = xr.Dataset({var_name: masked_da})
170+
ds = ds.isel(index=duplication_indexes)
171+
ds = ds.rename({"index": "station"})
172+
ds["station"] = station_names
173+
return ds
90174

91-
return da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes
92175

176+
def process_inputs(station_config: dict[str, Any], grid_config: dict[str, Any]) -> xr.Dataset:
177+
df = parse_stations(station_config)
178+
if "gribjump" in grid_config.get("source", {}):
179+
return _process_gribjump(grid_config, df)
180+
return _process_regular(grid_config, df)
93181

94-
def mask_array_np(arr, mask):
182+
183+
def mask_array_np(arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
95184
return arr[..., mask]
96185

97186

98-
def apply_mask(da, mask, coordx, coordy):
187+
def apply_mask(da: xr.DataArray, mask: np.ndarray, coordx: str, coordy: str) -> xr.DataArray:
99188
task = xr.apply_ufunc(
100189
mask_array_np,
101190
da,
102191
mask,
103192
input_core_dims=[(coordx, coordy), (coordx, coordy)],
104-
output_core_dims=[["station"]],
193+
output_core_dims=[["index"]],
105194
output_dtypes=[da.dtype],
106195
exclude_dims={coordx, coordy},
107196
dask="parallelized",
108197
dask_gufunc_kwargs={
109-
"output_sizes": {"station": int(mask.sum())},
198+
"output_sizes": {"index": int(mask.sum())},
110199
"allow_rechunk": True,
111200
},
112201
)
113202
with ProgressBar(dt=15):
114203
return task.compute()
115204

116205

117-
def extractor(config):
118-
da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes = process_inputs(
119-
config["station"], config["grid"]
120-
)
121-
logger.info("Extracting timeseries at selected stations")
122-
masked_da = apply_mask(da, mask, gridx_colname, gridy_colname)
123-
ds = xr.Dataset({da_varname: masked_da})
124-
ds = ds.isel(station=duplication_indexes)
125-
ds["station"] = station_names
206+
def extractor(config: dict[str, Any]) -> xr.Dataset:
207+
ds = process_inputs(config["station"], config["grid"])
126208
if config.get("output", None) is not None:
127209
logger.info(f"Saving output to {config['output']['file']}")
128210
ds.to_netcdf(config["output"]["file"])

hat/station_mapping/mapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,15 @@ def apply_blacklist(blacklist_config, metric_grid, grid_area_coords1, grid_area_
4747
return metric_grid, grid_area_coords1, grid_area_coords2
4848

4949

50-
def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_area_coords2, filename):
50+
def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_area_coords2, shape, filename):
5151
df["opt_x_index"] = indx
5252
df["opt_y_index"] = indy
5353
df["near_x_index"] = cindx
5454
df["near_y_index"] = cindy
5555
df["opt_error"] = errors
5656
df["opt_x_coord"] = grid_area_coords1[indx, 0]
5757
df["opt_y_coord"] = grid_area_coords2[0, indy]
58+
df["opt_1d_index"] = indy + shape[1] * indx
5859
if filename is not None:
5960
df.to_csv(filename, index=False)
6061
return df
@@ -109,6 +110,7 @@ def mapper(config):
109110
*mapping_outputs,
110111
grid_area_coords1,
111112
grid_area_coords2,
113+
shape=grid_area_coords1.shape,
112114
filename=config["output"]["file"] if config.get("output", None) is not None else None,
113115
)
114116
generate_summary_plots(df, config.get("plot", None))

notebooks/workflow/hydrostats_computation.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
"source": [
2020
"config = {\n",
2121
" \"sim\": {\n",
22-
" \"source\": [\"file\", \"extracted_timeseries.nc\"],\n",
22+
" \"source\": {\"file\": \"extracted_timeseries.nc\"},\n",
2323
" \"coords\": {\n",
2424
" \"s\": \"station\",\n",
2525
" \"t\": \"time\"\n",
2626
" }\n",
2727
" },\n",
2828
" \"obs\": {\n",
29-
" \"source\": [\"file\", \"observations.nc\"],\n",
29+
" \"source\": {\"file\": \"observations.nc\"},\n",
3030
" \"coords\": {\n",
3131
" \"s\": \"station\",\n",
3232
" \"t\": \"time\"\n",
@@ -49,7 +49,7 @@
4949
],
5050
"metadata": {
5151
"kernelspec": {
52-
"display_name": "Python 3",
52+
"display_name": "hat",
5353
"language": "python",
5454
"name": "python3"
5555
},

notebooks/workflow/timeseries_extraction.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
" \"name\": \"station_id\"\n",
3636
" },\n",
3737
" \"grid\": {\n",
38-
" \"source\": [\"file\", \"./sim.nc\"],\n",
38+
" \"source\": {\"file\": \"./sim.nc\"},\n",
3939
" \"coords\": {\n",
4040
" \"x\": \"lat\",\n",
4141
" \"y\": \"lon\",\n",

0 commit comments

Comments
 (0)