88from hat import _LOGGER as logger
99
1010
def load_ekd_source(grid_config):
    """Open the configured earthkit-data source and convert it to an xarray object.

    ``grid_config["source"]`` is expected to contain a single entry mapping the
    source name to the keyword arguments for ``ekd.from_source``; any
    ``to_xarray_options`` in ``grid_config`` are forwarded to ``to_xarray``.
    Returns the resulting xarray dataset.
    """
    source_cfg = grid_config["source"]
    src_name = next(iter(source_cfg))
    logger.info(f"Processing grid inputs from source: {src_name}")
    logger.debug(f"Grid config: {grid_config['source'][src_name]}")
    xarray_opts = grid_config.get("to_xarray_options", {})
    ds = ekd.from_source(src_name, **source_cfg[src_name]).to_xarray(**xarray_opts)
    return ds
19+
20+
21+ def process_grid_inputs (grid_config ):
22+ ds = load_ekd_source (grid_config )
1823 var_name = find_main_var (ds , 3 )
1924 da = ds [var_name ]
2025 logger .info (f"Xarray created from source:\n { da } \n " )
@@ -61,7 +66,7 @@ def create_mask_from_coords(coords_config, df, gridx, gridy, shape):
6166 return mask , duplication_indexes
6267
6368
64- def process_inputs (station_config , grid_config ):
69+ def parse_stations (station_config ):
6570 logger .debug (f"Reading station file, { station_config } " )
6671 df = pd .read_csv (station_config ["file" ])
6772 filters = station_config .get ("filter" )
@@ -72,23 +77,44 @@ def process_inputs(station_config, grid_config):
7277
7378 index_config = station_config .get ("index" , None )
7479 coords_config = station_config .get ("coords" , None )
80+ index_1d_config = station_config .get ("index_1d" , None )
81+ return index_config , coords_config , index_1d_config , station_names , df
7582
83+
def process_inputs(station_config, grid_config):
    """Read station metadata and extract the gridded values at the stations.

    Parameters
    ----------
    station_config : dict
        Station section of the run config (file, filters, index/coords/index_1d).
    grid_config : dict
        Grid section of the run config; its ``source`` entry selects the backend.
        NOTE(review): for the gribjump path this dict is mutated in place
        (``indices`` is injected into the source kwargs) — confirm callers do
        not reuse the config afterwards.

    Returns
    -------
    tuple
        ``(da_varname, station_names, duplication_indexes, masked_da)`` where
        ``duplication_indexes`` re-expands de-duplicated stations and
        ``masked_da`` is the extracted per-station data array.

    Raises
    ------
    ValueError
        If both ``index`` and ``coords`` are configured, or if the gribjump
        source is used without an ``index_1d`` station config entry.
    """
    index_config, coords_config, index_1d_config, station_names, df = parse_stations(station_config)

    # TODO: better malformed config handling
    if index_config is not None and coords_config is not None:
        raise ValueError("Use either index or coords, not both.")

    if list(grid_config["source"].keys())[0] == "gribjump":
        # gribjump extracts point values at the source: pass the unique 1-D
        # grid indices to it instead of masking a full grid locally.
        # Was an `assert`; raise explicitly so validation survives `python -O`.
        if index_1d_config is None:
            raise ValueError("The gribjump source requires an 'index_1d' entry in the station config.")
        unique_indices, duplication_indexes = np.unique(df[index_1d_config].values, return_inverse=True)
        grid_config["source"]["gribjump"]["indices"] = unique_indices
        masked_da = load_ekd_source(grid_config)
        # TODO: implement proper variable-name discovery for gribjump output
        da_varname = "placeholder_variable_name"

        var_name = find_main_var(masked_da, 2)
        masked_da = masked_da[var_name]
    else:
        da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config)

        if index_config is not None:
            mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
        elif coords_config is not None:
            mask, duplication_indexes = create_mask_from_coords(
                coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape
            )
        else:
            # default to index approach (index_config is None here — the helper
            # is expected to handle that; preserved from original behavior)
            mask, duplication_indexes = create_mask_from_index(index_config, df, shape)

        logger.info("Extracting timeseries at selected stations")
        masked_da = apply_mask(da, mask, gridx_colname, gridy_colname)

    return da_varname, station_names, duplication_indexes, masked_da
92118
93119
94120def mask_array_np (arr , mask ):
@@ -101,12 +127,12 @@ def apply_mask(da, mask, coordx, coordy):
101127 da ,
102128 mask ,
103129 input_core_dims = [(coordx , coordy ), (coordx , coordy )],
104- output_core_dims = [["station " ]],
130+ output_core_dims = [["index " ]],
105131 output_dtypes = [da .dtype ],
106132 exclude_dims = {coordx , coordy },
107133 dask = "parallelized" ,
108134 dask_gufunc_kwargs = {
109- "output_sizes" : {"station " : int (mask .sum ())},
135+ "output_sizes" : {"index " : int (mask .sum ())},
110136 "allow_rechunk" : True ,
111137 },
112138 )
@@ -115,13 +141,10 @@ def apply_mask(da, mask, coordx, coordy):
115141
116142
117143def extractor (config ):
118- da , da_varname , gridx_colname , gridy_colname , mask , station_names , duplication_indexes = process_inputs (
119- config ["station" ], config ["grid" ]
120- )
121- logger .info ("Extracting timeseries at selected stations" )
122- masked_da = apply_mask (da , mask , gridx_colname , gridy_colname )
144+ da_varname , station_names , duplication_indexes , masked_da = process_inputs (config ["station" ], config ["grid" ])
145+ print (masked_da )
123146 ds = xr .Dataset ({da_varname : masked_da })
124- ds = ds .isel (station = duplication_indexes )
147+ ds = ds .isel (index = duplication_indexes )
125148 ds ["station" ] = station_names
126149 if config .get ("output" , None ) is not None :
127150 logger .info (f"Saving output to { config ['output' ]['file' ]} " )
0 commit comments