Commit 8943050

feat: first basic implementation using gribjump for timeseries extraction

1 parent 8385bae

File tree: 1 file changed (+43 −20)


hat/extract_timeseries/extractor.py

@@ -8,13 +8,18 @@
 from hat import _LOGGER as logger
 
 
-def process_grid_inputs(grid_config):
+def load_ekd_source(grid_config):
     src_name = list(grid_config["source"].keys())[0]
     logger.info(f"Processing grid inputs from source: {src_name}")
     logger.debug(f"Grid config: {grid_config['source'][src_name]}")
     ds = ekd.from_source(src_name, **grid_config["source"][src_name]).to_xarray(
         **grid_config.get("to_xarray_options", {})
     )
+    return ds
+
+
+def process_grid_inputs(grid_config):
+    ds = load_ekd_source(grid_config)
     var_name = find_main_var(ds, 3)
     da = ds[var_name]
     logger.info(f"Xarray created from source:\n{da}\n")
@@ -61,7 +66,7 @@ def create_mask_from_coords(coords_config, df, gridx, gridy, shape):
     return mask, duplication_indexes
 
 
-def process_inputs(station_config, grid_config):
+def parse_stations(station_config):
     logger.debug(f"Reading station file, {station_config}")
     df = pd.read_csv(station_config["file"])
     filters = station_config.get("filter")
@@ -72,23 +77,44 @@ def process_inputs(station_config, grid_config):
 
     index_config = station_config.get("index", None)
     coords_config = station_config.get("coords", None)
+    index_1d_config = station_config.get("index_1d", None)
+    return index_config, coords_config, index_1d_config, station_names, df
 
+
+def process_inputs(station_config, grid_config):
+    index_config, coords_config, index_1d_config, station_names, df = parse_stations(station_config)
+
+    # TODO: better malformed config handling
     if index_config is not None and coords_config is not None:
         raise ValueError("Use either index or coords, not both.")
 
-    da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config)
+    if list(grid_config["source"].keys())[0] == "gribjump":
+        assert index_1d_config is not None
+        unique_indices, duplication_indexes = np.unique(df[index_1d_config].values, return_inverse=True)
+        grid_config["source"]["gribjump"]["indices"] = unique_indices
+        masked_da = load_ekd_source(grid_config)
+        # TODO: implement
+        da_varname = "placeholder_variable_name"
 
-    if index_config is not None:
-        mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
-    elif coords_config is not None:
-        mask, duplication_indexes = create_mask_from_coords(
-            coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape
-        )
+        var_name = find_main_var(masked_da, 2)
+        masked_da = masked_da[var_name]
     else:
-        # default to index approach
-        mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
+        da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config)
 
-    return da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes
+        if index_config is not None:
+            mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
+        elif coords_config is not None:
+            mask, duplication_indexes = create_mask_from_coords(
+                coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape
+            )
+        else:
+            # default to index approach
+            mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
+
+        logger.info("Extracting timeseries at selected stations")
+        masked_da = apply_mask(da, mask, gridx_colname, gridy_colname)
+
+    return da_varname, station_names, duplication_indexes, masked_da
 
 
 def mask_array_np(arr, mask):
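
Several stations can fall in the same grid cell, so the gribjump branch deduplicates the 1-D cell indices before requesting them from the source; np.unique(..., return_inverse=True) keeps the mapping needed to fan results back out per station. A small self-contained illustration (cell ids made up):

```python
import numpy as np

# Hypothetical 1-D grid-cell indices for four stations; two share cell 7.
cell_ids = np.array([7, 3, 7, 5])

unique_indices, duplication_indexes = np.unique(cell_ids, return_inverse=True)
print(unique_indices)       # [3 5 7]   -> requested from the source once each
print(duplication_indexes)  # [2 0 2 1] -> each station's slot in unique_indices

# Values extracted once per unique cell expand back to one value per station:
values_per_cell = np.array([10.0, 20.0, 30.0])
print(values_per_cell[duplication_indexes])  # [30. 10. 30. 20.]
```
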
@@ -101,12 +127,12 @@ def apply_mask(da, mask, coordx, coordy):
         da,
         mask,
         input_core_dims=[(coordx, coordy), (coordx, coordy)],
-        output_core_dims=[["station"]],
+        output_core_dims=[["index"]],
         output_dtypes=[da.dtype],
         exclude_dims={coordx, coordy},
         dask="parallelized",
         dask_gufunc_kwargs={
-            "output_sizes": {"station": int(mask.sum())},
+            "output_sizes": {"index": int(mask.sum())},
             "allow_rechunk": True,
         },
     )
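
Renaming the flattened output dimension from "station" to "index" lets both extraction paths feed the same reindexing step in extractor below. As a reference for the masking pattern itself, a minimal eager (non-dask) sketch of the same xr.apply_ufunc call, with synthetic data and dimension names chosen for the example:

```python
import numpy as np
import xarray as xr

def mask_array_np(arr, mask):
    # Boolean-mask the trailing (y, x) axes, flattening them into one axis.
    return arr[..., mask]

da = xr.DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("time", "y", "x"))
mask = np.zeros((3, 4), dtype=bool)
mask[0, 0] = mask[1, 2] = True  # select two grid cells

out = xr.apply_ufunc(
    mask_array_np,
    da,
    mask,
    input_core_dims=[("y", "x"), ("y", "x")],
    output_core_dims=[["index"]],  # new flattened dimension
    exclude_dims={"y", "x"},       # their size changes, so exclude them from alignment
)
print(out.sizes)  # {'time': 2, 'index': 2}
```
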
@@ -115,13 +141,10 @@ def apply_mask(da, mask, coordx, coordy):
 
 
 def extractor(config):
-    da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes = process_inputs(
-        config["station"], config["grid"]
-    )
-    logger.info("Extracting timeseries at selected stations")
-    masked_da = apply_mask(da, mask, gridx_colname, gridy_colname)
+    da_varname, station_names, duplication_indexes, masked_da = process_inputs(config["station"], config["grid"])
+    print(masked_da)
     ds = xr.Dataset({da_varname: masked_da})
-    ds = ds.isel(station=duplication_indexes)
+    ds = ds.isel(index=duplication_indexes)
     ds["station"] = station_names
     if config.get("output", None) is not None:
         logger.info(f"Saving output to {config['output']['file']}")
