|
14 | 14 |
|
15 | 15 | from teehr.fetching.utils import load_gdf |
16 | 16 | from teehr.fetching.const import LOCATION_ID |
| 17 | +import teehr.models.pandera_dataframe_schemas as schemas |
17 | 18 |
|
18 | 19 |
|
19 | 20 | @dask.delayed |
@@ -236,11 +237,19 @@ def generate_weights_file( |
236 | 237 | grid_transform = src_da.rio.transform() |
237 | 238 | nodata_val = src_da.rio.nodata |
238 | 239 |
|
| 240 | + if not all([dim in src_da.dims for dim in ["x", "y"]]): |
| 241 | + raise ValueError("Template dataset must have x and y dimensions.") |
| 242 | + |
239 | 243 | # Get the subset of the grid that intersects the total zone bounds |
240 | 244 | bbox = tuple(zone_gdf.total_bounds) |
241 | | - src_da = src_da.sel(x=slice(bbox[0], bbox[2]), y=slice(bbox[1], bbox[3]))[ |
242 | | - 0 |
243 | | - ] |
| 245 | + if len(ds.dims) == 2: |
| 246 | + src_da = src_da.sel( |
| 247 | + x=slice(bbox[0], bbox[2]), y=slice(bbox[1], bbox[3]) |
| 248 | + ) |
| 249 | + else: |
| 250 | + src_da = src_da.sel( |
| 251 | + x=slice(bbox[0], bbox[2]), y=slice(bbox[1], bbox[3]) |
| 252 | + )[0] |
244 | 253 | src_da = src_da.astype("float32") |
245 | 254 | src_da["x"] = np.float32(src_da.x.values) |
246 | 255 | src_da["y"] = np.float32(src_da.y.values) |
@@ -275,28 +284,11 @@ def generate_weights_file( |
275 | 284 | if location_id_prefix: |
276 | 285 | df.loc[:, LOCATION_ID] = location_id_prefix + "-" + df[LOCATION_ID] |
277 | 286 |
|
| 287 | + schema = schemas.weights_file_schema() |
| 288 | + validated_df = schema.validate(df) |
| 289 | + |
278 | 290 | if output_weights_filepath: |
279 | | - df.to_parquet(output_weights_filepath) |
280 | | - df = None |
281 | | - |
282 | | - return df |
283 | | - |
284 | | - |
285 | | -# if __name__ == "__main__": |
286 | | -# # Local testing |
287 | | -# zone_polygon_filepath = "/mnt/data/wbd/one_alaska_huc10.parquet" |
288 | | -# template_dataset = "/mnt/data/ciroh/nwm_temp/nwm.20231101_forcing_analysis_assim_alaska_nwm.t00z.analysis_assim.forcing.tm01.alaska.nc" # noqa |
289 | | -# variable_name = "RAINRATE" |
290 | | -# unique_zone_id = "huc10" |
291 | | -# output_weights_filepath = ( |
292 | | -# "/mnt/sf_shared/data/ciroh/one_huc10_alaska_weights.parquet" |
293 | | -# ) |
294 | | - |
295 | | -# generate_weights_file( |
296 | | -# zone_polygon_filepath=zone_polygon_filepath, |
297 | | -# template_dataset=template_dataset, |
298 | | -# variable_name=variable_name, |
299 | | -# output_weights_filepath=output_weights_filepath, |
300 | | -# crs_wkt=AL_NWM_WKT, |
301 | | -# unique_zone_id=unique_zone_id |
302 | | -# ) |
| 291 | + validated_df.to_parquet(output_weights_filepath) |
| 292 | + validated_df = None |
| 293 | + |
| 294 | + return validated_df |
0 commit comments