|
2 | 2 | import pandas as pd |
3 | 3 | import xarray as xr |
4 | 4 | import numpy as np |
5 | | -import earthkit.data as ekd |
6 | | -from earthkit.hydro._readers import find_main_var |
| 5 | +from typing import Any |
| 6 | +from hat.core import load_da |
7 | 7 |
|
8 | 8 | from hat import _LOGGER as logger |
9 | 9 |
|
10 | 10 |
|
def process_grid_inputs(grid_config):
    """Load the gridded source as an xarray DataArray, sorted by its spatial coords.

    Returns a tuple of (DataArray, variable name, x dimension name,
    y dimension name, (len(x), len(y)) grid shape).
    """
    da, var_name = load_da(grid_config, 3)
    logger.info(f"Xarray created from source:\n{da}\n")
    coords = grid_config.get("coords", {})
    x_dim = coords.get("x", "lat")
    y_dim = coords.get("y", "lon")
    # Sorting guarantees monotonic coordinates for the nearest-neighbour lookup.
    sorted_da = da.sortby([x_dim, y_dim])
    grid_shape = (sorted_da[x_dim].shape[0], sorted_da[y_dim].shape[0])
    return sorted_da, var_name, x_dim, y_dim, grid_shape
26 | 20 |
|
27 | 21 |
|
def construct_mask(x_indices, y_indices, shape):
    """Build a boolean grid mask from per-station (x, y) integer indices.

    Returns the mask together with, for each station, its position among the
    unique masked cells (row-major order), so stations sharing a cell can be
    re-expanded after extraction.
    """
    grid_mask = np.zeros(shape, dtype=bool)
    grid_mask[x_indices, y_indices] = True

    # Row-major flat indices match the order in which masked cells come out of
    # a boolean-mask extraction, so np.unique's inverse maps station -> cell.
    flattened = np.ravel_multi_index((x_indices, y_indices), shape)
    _, station_to_cell = np.unique(flattened, return_inverse=True)
    return grid_mask, station_to_cell
35 | 29 |
|
36 | 30 |
|
def create_mask_from_index(df, shape):
    """Build the station mask directly from integer grid indices in *df*.

    Expects canonical "x_index"/"y_index" columns; raises ValueError when any
    index falls outside the grid.
    """
    logger.info(f"Creating mask {shape} from index")
    logger.debug(f"DataFrame columns: {df.columns.tolist()}")
    x_indices = df["x_index"].values
    y_indices = df["y_index"].values
    # Validate before masking: out-of-range indices would otherwise wrap or crash.
    x_ok = not np.any(x_indices < 0) and not np.any(x_indices >= shape[0])
    y_ok = not np.any(y_indices < 0) and not np.any(y_indices >= shape[1])
    if not (x_ok and y_ok):
        raise ValueError(
            f"Station indices out of grid bounds. Grid shape={shape}, "
            f"x_index range=({int(x_indices.min())},{int(x_indices.max())}), "
            f"y_index range=({int(y_indices.min())},{int(y_indices.max())})"
        )
    return construct_mask(x_indices, y_indices, shape)
45 | 44 |
|
46 | 45 |
|
def create_mask_from_coords(df, gridx, gridy, shape):
    """Build the station mask by snapping station coordinates to the nearest grid cell.

    Expects canonical "x_coord"/"y_coord" columns; *gridx*/*gridy* are the 1-D
    grid coordinate arrays.
    """
    logger.info(f"Creating mask {shape} from coordinates")
    logger.debug(f"DataFrame columns: {df.columns.tolist()}")
    station_x = df["x_coord"].values
    station_y = df["y_coord"].values

    # Nearest neighbour per axis: broadcast |station - grid| and take the
    # argmin along the grid axis independently for x and y.
    x_indices = np.abs(station_x[:, np.newaxis] - gridx).argmin(axis=1)
    y_indices = np.abs(station_y[:, np.newaxis] - gridy).argmin(axis=1)

    return construct_mask(x_indices, y_indices, shape)
62 | 59 |
|
63 | 60 |
|
def parse_stations(station_config: dict[str, Any]) -> pd.DataFrame:
    """Read, filter, and normalize station DataFrame to canonical column names.

    Canonical columns: "station_name", plus "x_index"/"y_index" (index mode),
    "x_coord"/"y_coord" (coords mode), or "index_1d" (gribjump mode).

    Raises ValueError on a malformed config, an empty station set, or missing
    source columns.
    """
    logger.debug(f"Reading station file, {station_config}")
    if "name" not in station_config:
        raise ValueError("Station config must include a 'name' key mapping to the station column")
    stations = pd.read_csv(station_config["file"])
    query = station_config.get("filter")
    if query is not None:
        logger.debug(f"Applying filters: {query} to station DataFrame")
        stations = stations.query(query)

    if len(stations) == 0:
        raise ValueError("No stations found. Check station file or filter.")

    # Membership (not .get) keeps an explicit None value distinct from "absent".
    has_index = "index" in station_config
    has_coords = "coords" in station_config
    has_index_1d = "index_1d" in station_config

    # index_1d (gribjump) stands alone; otherwise exactly one of index/coords.
    if not has_index_1d:
        if has_index and has_coords:
            raise ValueError("Station config must use either 'index' or 'coords', not both.")
        if not (has_index or has_coords):
            raise ValueError("Station config must provide either 'index' or 'coords' for station mapping.")

    column_map = {station_config["name"]: "station_name"}

    if has_index:
        idx_cfg = station_config["index"]
        column_map[idx_cfg.get("x", "opt_x_index")] = "x_index"
        column_map[idx_cfg.get("y", "opt_y_index")] = "y_index"

    if has_coords:
        crd_cfg = station_config["coords"]
        column_map[crd_cfg.get("x", "opt_x_coord")] = "x_coord"
        column_map[crd_cfg.get("y", "opt_y_coord")] = "y_coord"

    if has_index_1d:
        column_map[station_config["index_1d"]] = "index_1d"

    normalized = stations.rename(columns=column_map)

    # rename() silently skips missing source columns, so verify the result.
    present = normalized.columns
    if has_index and not ("x_index" in present and "y_index" in present):
        raise ValueError(
            "Station file missing required index columns. Expected columns to map to 'x_index' and 'y_index'."
        )
    if has_coords and not ("x_coord" in present and "y_coord" in present):
        raise ValueError(
            "Station file missing required coordinate columns. Expected columns to map to 'x_coord' and 'y_coord'."
        )
    if has_index_1d and "index_1d" not in present:
        raise ValueError("Station file missing required 'index_1d' column.")

    return normalized
| 119 | + |
| 120 | + |
def _process_gribjump(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
    """Extract station timeseries from a gribjump source using 1-D grid indices.

    Deduplicates the requested indices, fetches only the unique cells, then
    re-expands to one entry per station.
    """
    if "index_1d" not in df.columns:
        raise ValueError("Gribjump source requires 'index_1d' in station config.")

    station_names = df["station_name"].values
    unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True)  # type: ignore[call-overload]

    # Converting indices to ranges is currently faster than using indices
    # directly. This is a problem in the earthkit-data gribjump source and will
    # be fixed there.
    ranges = [(start, start + 1) for start in unique_indices]

    source_options = dict(grid_config["source"]["gribjump"])
    source_options["ranges"] = ranges
    # fetch_coords_from_fdb is currently very slow. Needs fix in
    # earthkit-data gribjump source.
    # source_options["fetch_coords_from_fdb"] = True
    gribjump_config = {
        "source": {"gribjump": source_options},
        "to_xarray_options": grid_config.get("to_xarray_options", {}),
    }

    masked_da, var_name = load_da(gribjump_config, 2)

    dataset = xr.Dataset({var_name: masked_da})
    dataset = dataset.isel(index=duplication_indexes).rename({"index": "station"})
    dataset["station"] = station_names
    return dataset
| 153 | + |
| 154 | + |
def _process_regular(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
    """Extract station timeseries from a regular 2-D grid source.

    Uses precomputed grid indices when the station frame provides them,
    otherwise falls back to nearest-neighbour coordinate matching.
    """
    station_names = df["station_name"].values
    da, var_name, x_dim, y_dim, shape = process_grid_inputs(grid_config)

    if "x_index" in df.columns and "y_index" in df.columns:
        mask, duplication_indexes = create_mask_from_index(df, shape)
    else:
        mask, duplication_indexes = create_mask_from_coords(df, da[x_dim].values, da[y_dim].values, shape)

    logger.info("Extracting timeseries at selected stations")
    masked_da = apply_mask(da, mask, x_dim, y_dim)

    dataset = xr.Dataset({var_name: masked_da})
    # Duplicate stations share a grid cell; isel re-expands them, then the
    # flat extraction index is relabelled with station names.
    dataset = dataset.isel(index=duplication_indexes).rename({"index": "station"})
    dataset["station"] = station_names
    return dataset
90 | 174 |
|
91 | | - return da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes |
92 | 175 |
|
def process_inputs(station_config: dict[str, Any], grid_config: dict[str, Any]) -> xr.Dataset:
    """Parse the station config and dispatch to the gribjump or regular-grid pipeline."""
    stations = parse_stations(station_config)
    source = grid_config.get("source", {})
    if "gribjump" in source:
        return _process_gribjump(grid_config, stations)
    return _process_regular(grid_config, stations)
93 | 181 |
|
def mask_array_np(arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Select the cells flagged in *mask* from the trailing dimensions of *arr*."""
    selected = arr[..., mask]
    return selected
96 | 185 |
|
97 | 186 |
|
def apply_mask(da: xr.DataArray, mask: np.ndarray, coordx: str, coordy: str) -> xr.DataArray:
    """Collapse the (coordx, coordy) dims of *da* into a flat "index" dim
    containing only the cells flagged in *mask*, computing eagerly with a
    progress bar.
    """
    n_selected = int(mask.sum())
    lazy_result = xr.apply_ufunc(
        mask_array_np,
        da,
        mask,
        input_core_dims=[(coordx, coordy), (coordx, coordy)],
        output_core_dims=[["index"]],
        output_dtypes=[da.dtype],
        exclude_dims={coordx, coordy},
        dask="parallelized",
        dask_gufunc_kwargs={
            "output_sizes": {"index": n_selected},
            "allow_rechunk": True,
        },
    )
    # NOTE(review): ProgressBar is not imported in the visible part of this
    # file — presumably dask.diagnostics.ProgressBar on an earlier line; confirm.
    with ProgressBar(dt=15):
        return lazy_result.compute()
115 | 204 |
|
116 | 205 |
|
def extractor(config: dict[str, Any]) -> xr.Dataset:
    """Run the station-timeseries extraction described by *config*.

    Expects "station" and "grid" sections; when an "output" section with a
    "file" path is present, the result is also written to NetCDF.

    Returns the extracted station-indexed dataset.
    """
    ds = process_inputs(config["station"], config["grid"])
    if config.get("output", None) is not None:
        logger.info(f"Saving output to {config['output']['file']}")
        ds.to_netcdf(config["output"]["file"])
    # Fix: the signature promises xr.Dataset but the function previously fell
    # off the end and implicitly returned None.
    return ds
|
0 commit comments