Commit 7222ff9

Merge pull request #414 from RTIInternational/410-grid-weights-file-being-generated-with-rowcolumns-as-float-instead-of-integer
Add schema validation to weights file for NWM grid fetching
2 parents: 34bff66 + abd7138

File tree

10 files changed: +112 -71 lines changed

src/teehr/evaluation/fetch.py

Lines changed: 5 additions & 4 deletions

@@ -478,11 +478,12 @@ def nwm_retrospective_grids(
 >>> ev = teehr.Evaluation()

 >>> ev.fetch.nwm_retrospective_grids(
->>>     nwm_configuration="forcing_short_range",
+>>>     nwm_version="nwm30",
 >>>     variable_name="RAINRATE",
 >>>     zonal_weights_filepath = Path(Path.home(), "nextgen_03S_weights.parquet"),
 >>>     start_date=datetime(2000, 1, 1),
->>>     end_date=datetime(2001, 1, 1)
+>>>     end_date=datetime(2001, 1, 1),
+>>>     location_id_prefix="huc10"
 >>> )

 .. note::

@@ -496,12 +497,12 @@ def nwm_retrospective_grids(

 >>> nwm_retro_grids_to_parquet(
 >>>     nwm_version="nwm30",
->>>     nwm_configuration="forcing_short_range",
 >>>     variable_name="RAINRATE",
 >>>     zonal_weights_filepath=Path(Path.home(), "nextgen_03S_weights.parquet"),
 >>>     start_date=2020-12-18,
 >>>     end_date=2022-12-18,
->>>     output_parquet_dir=Path(Path.home(), "temp/parquet")
+>>>     output_parquet_dir=Path(Path.home(), "temp/parquet"),
+>>>     location_id_prefix="huc10",
 >>> )

 See Also
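The docstring examples above pick up the new location_id_prefix argument, which namespaces the zone IDs in the fetched output. A minimal sketch of the "<prefix>-<id>" convention, in plain pandas with made-up data; the rule mirrors the location_id_prefix + "-" + df[LOCATION_ID] line visible in generate_weights.py below:

    import pandas as pd

    # Toy frame standing in for fetched output; only the prefixing rule is real.
    df = pd.DataFrame({"location_id": ["1016000606"], "value": [0.25]})
    df["location_id"] = "huc10" + "-" + df["location_id"]
    print(df["location_id"].tolist())  # ['huc10-1016000606']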

src/teehr/fetching/nwm/grid_utils.py

Lines changed: 13 additions & 3 deletions

@@ -8,6 +8,7 @@
 import pandas as pd
 import xarray as xr

+import teehr.models.pandera_dataframe_schemas as schemas
 from teehr.fetching.utils import (
     get_dataset,
     write_timeseries_parquet_file,

@@ -94,6 +95,17 @@ def compute_weighted_average(
     return df[[LOCATION_ID, VALUE]].copy()


+def read_and_validate_weights_file(
+    weights_filepath: str
+) -> pd.DataFrame:
+    """Read weights file from parquet, validating data types."""
+    schema = schemas.weights_file_schema()
+    weights_df = pd.read_parquet(
+        weights_filepath, columns=list(schema.columns.keys())
+    )
+    return schema.validate(weights_df)
+
+
 @dask.delayed
 def process_single_nwm_grid_file(
     row: Tuple,

@@ -121,9 +133,7 @@ def process_single_nwm_grid_file(
     value_time = ds.time.values[0]
     da = ds[variable_name][0]

-    weights_df = pd.read_parquet(
-        weights_filepath, columns=["row", "col", "weight", LOCATION_ID]
-    )
+    weights_df = read_and_validate_weights_file(weights_filepath)

     weights_bounds = get_weights_row_col_stats(weights_df)
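This helper is the crux of the fix for issue #410: even when a weights parquet was written with row/col as floats, the schema's coerce=True casts them back to integers on read, and pandera raises a validation error if a column cannot be coerced. A quick sketch exercising it (assumes write access for a local "weights.parquet"; the float dtypes simulate the bug):

    import pandas as pd
    from teehr.fetching.nwm.grid_utils import read_and_validate_weights_file

    # Simulate the issue #410 symptom: row/col stored as float in the file.
    pd.DataFrame({
        "row": [10.0, 11.0],
        "col": [20.0, 21.0],
        "weight": [0.25, 0.75],
        "location_id": ["1016000606", "1016000606"],
    }).to_parquet("weights.parquet")

    # Reads only the schema's columns, then validates and coerces dtypes.
    weights_df = read_and_validate_weights_file("weights.parquet")
    print(weights_df.dtypes)  # row/col come back as int32, weight as float32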

src/teehr/fetching/nwm/nwm_grids.py

Lines changed: 0 additions & 28 deletions

@@ -295,31 +295,3 @@ def nwm_grids_to_parquet(
         variable_mapper=variable_mapper,
         timeseries_type=timeseries_type
     )
-
-
-# if __name__ == "__main__":
-#     # Local testing
-#     weights_parquet = "/mnt/data/ciroh/onehuc10_weights.parquet"
-
-#     import time
-#     t1 = time.time()
-
-#     nwm_grids_to_parquet(
-#         configuration="forcing_analysis_assim",
-#         output_type="forcing",
-#         variable_name="RAINRATE",
-#         start_date="2023-11-28",
-#         ingest_days=1,
-#         zonal_weights_filepath=weights_parquet,
-#         json_dir="/mnt/data/ciroh/jsons",
-#         output_parquet_dir="/mnt/data/ciroh/parquet",
-#         nwm_version="nwm30",
-#         data_source="GCS",
-#         kerchunk_method="auto",
-#         t_minus_hours=[0],
-#         ignore_missing_file=False,
-#         overwrite_output=True,
-#         location_id_prefix="wbd10"
-#     )
-
-#     print(f"elapsed: {time.time() - t1:.2f} s")

src/teehr/fetching/nwm/retrospective_grids.py

Lines changed: 4 additions & 9 deletions

@@ -45,7 +45,6 @@
 from teehr.fetching.const import (
     VALUE_TIME,
     REFERENCE_TIME,
-    LOCATION_ID,
     UNIT_NAME,
     VARIABLE_NAME,
     CONFIGURATION_NAME

@@ -61,7 +60,8 @@
     update_location_id_prefix,
     compute_weighted_average,
     get_nwm_grid_data,
-    get_weights_row_col_stats
+    get_weights_row_col_stats,
+    read_and_validate_weights_file
 )
 from teehr.fetching.utils import (
     write_timeseries_parquet_file,

@@ -107,9 +107,7 @@ def process_nwm30_retro_group(
     and the output is saved to parquet files.
     """
     logger.debug("Processing NWM v3.0 retro grid data chunk.")
-    weights_df = pd.read_parquet(
-        weights_filepath, columns=["row", "col", "weight", LOCATION_ID]
-    )
+    weights_df = read_and_validate_weights_file(weights_filepath)

     weights_bounds = get_weights_row_col_stats(weights_df)

@@ -149,7 +147,6 @@ def process_nwm30_retro_group(
     if location_id_prefix:
         chunk_df = update_location_id_prefix(chunk_df, location_id_prefix)

-
     return chunk_df


@@ -203,9 +200,7 @@ def process_single_nwm21_retro_grid_file(
     value_time = row.datetime
     da = ds[variable_name].isel(Time=0)

-    weights_df = pd.read_parquet(
-        weights_filepath, columns=["row", "col", "weight", LOCATION_ID]
-    )
+    weights_df = read_and_validate_weights_file(weights_filepath)

     weights_bounds = get_weights_row_col_stats(weights_df)
src/teehr/models/pandera_dataframe_schemas.py

Lines changed: 30 additions & 0 deletions

@@ -313,6 +313,36 @@ def location_crosswalks_schema(
         coerce=True
     )

+
+def weights_file_schema() -> pa.DataFrameSchema:
+    """Return the schema for a weights file."""
+    return pa.DataFrameSchema(
+        columns={
+            "row": pa.Column(
+                pa.Int32,
+                nullable=False,
+                coerce=True
+            ),
+            "col": pa.Column(
+                pa.Int32,
+                nullable=False,
+                coerce=True
+            ),
+            "weight": pa.Column(
+                pa.Float32,
+                nullable=False,
+                coerce=True
+            ),
+            "location_id": pa.Column(
+                pa.String,
+                nullable=False,
+                coerce=True
+            )
+        },
+        strict="filter"
+    )
+
+
 # Timeseries
 pandas_value_type = pa.Float32()
 pyspark_value_type = T.FloatType()
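Two details of this schema do the work here: coerce=True means validation casts rather than rejects (so float row/col values become Int32), and strict="filter" silently drops any columns not declared in the schema instead of erroring. A small standalone sketch mirroring the schema above, with toy data:

    import pandas as pd
    import pandera as pa

    schema = pa.DataFrameSchema(
        columns={
            "row": pa.Column(pa.Int32, nullable=False, coerce=True),
            "col": pa.Column(pa.Int32, nullable=False, coerce=True),
            "weight": pa.Column(pa.Float32, nullable=False, coerce=True),
            "location_id": pa.Column(pa.String, nullable=False, coerce=True),
        },
        strict="filter",
    )

    df = pd.DataFrame({
        "row": [1.0],                 # float -> Int32 via coerce
        "col": [2.0],
        "weight": [0.5],
        "location_id": [1016000606],  # int -> string via coerce
        "extra": ["dropped"],         # removed by strict="filter"
    })
    validated = schema.validate(df)
    print(list(validated.columns))    # ['row', 'col', 'weight', 'location_id']
    print(validated.dtypes)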

src/teehr/utilities/generate_weights.py

Lines changed: 19 additions & 27 deletions

@@ -14,6 +14,7 @@

 from teehr.fetching.utils import load_gdf
 from teehr.fetching.const import LOCATION_ID
+import teehr.models.pandera_dataframe_schemas as schemas


 @dask.delayed

@@ -236,11 +237,19 @@ def generate_weights_file(
     grid_transform = src_da.rio.transform()
     nodata_val = src_da.rio.nodata

+    if not all([dim in src_da.dims for dim in ["x", "y"]]):
+        raise ValueError("Template dataset must have x and y dimensions.")
+
     # Get the subset of the grid that intersects the total zone bounds
     bbox = tuple(zone_gdf.total_bounds)
-    src_da = src_da.sel(x=slice(bbox[0], bbox[2]), y=slice(bbox[1], bbox[3]))[
-        0
-    ]
+    if len(ds.dims) == 2:
+        src_da = src_da.sel(
+            x=slice(bbox[0], bbox[2]), y=slice(bbox[1], bbox[3])
+        )
+    else:
+        src_da = src_da.sel(
+            x=slice(bbox[0], bbox[2]), y=slice(bbox[1], bbox[3])
+        )[0]
     src_da = src_da.astype("float32")
     src_da["x"] = np.float32(src_da.x.values)
     src_da["y"] = np.float32(src_da.y.values)

@@ -275,28 +284,11 @@ def generate_weights_file(
     if location_id_prefix:
         df.loc[:, LOCATION_ID] = location_id_prefix + "-" + df[LOCATION_ID]

+    schema = schemas.weights_file_schema()
+    validated_df = schema.validate(df)
+
     if output_weights_filepath:
-        df.to_parquet(output_weights_filepath)
-        df = None
-
-    return df
-
-
-# if __name__ == "__main__":
-#     # Local testing
-#     zone_polygon_filepath = "/mnt/data/wbd/one_alaska_huc10.parquet"
-#     template_dataset = "/mnt/data/ciroh/nwm_temp/nwm.20231101_forcing_analysis_assim_alaska_nwm.t00z.analysis_assim.forcing.tm01.alaska.nc"  # noqa
-#     variable_name = "RAINRATE"
-#     unique_zone_id = "huc10"
-#     output_weights_filepath = (
-#         "/mnt/sf_shared/data/ciroh/one_huc10_alaska_weights.parquet"
-#     )
-
-#     generate_weights_file(
-#         zone_polygon_filepath=zone_polygon_filepath,
-#         template_dataset=template_dataset,
-#         variable_name=variable_name,
-#         output_weights_filepath=output_weights_filepath,
-#         crs_wkt=AL_NWM_WKT,
-#         unique_zone_id=unique_zone_id
-#     )
+        validated_df.to_parquet(output_weights_filepath)
+        validated_df = None
+
+    return validated_df
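With the validate-before-write change, the weights schema is now enforced at generation time as well as read time, so new weights files can no longer be written with float row/col. The dimension guard and branch also make the template handling explicit: the grid must expose x and y dims, and a 2-D array is sliced directly while a higher-dimensional one (e.g. with a leading time dim) keeps the original [0] indexing. A toy xarray illustration of that shape handling (hypothetical arrays, not NWM data):

    import numpy as np
    import xarray as xr

    # 3-D template (time, y, x): slice to a bbox, then drop the leading dim.
    da = xr.DataArray(
        np.zeros((1, 4, 5), dtype="float32"),
        dims=("time", "y", "x"),
        coords={"y": np.arange(4.0), "x": np.arange(5.0)},
    )
    assert all(dim in da.dims for dim in ["x", "y"])  # the new guard's check

    subset = da.sel(x=slice(1.0, 3.0), y=slice(0.0, 2.0))[0]
    print(subset.dims)  # ('y', 'x')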
3 new test-data files (Git LFS pointers; file paths not captured in this view, but they appear to be the tests/data/nwm30 files referenced by tests/test_generate_weights.py below)

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:160df6bd1d64618fe23aca51903693b8aa5ce6e7fca53b56501282bfbda66436
+size 70864318

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0cc0cebb8ffde0f474417d0d063c9dbb5125cdc853ca91c2e9b6d6fe4dfd4e87
+size 6014

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04cb0fb0df5f9c628394749cda31fc0bb78c5a56fe94dffde4a77bf3b6a08759
+size 3047

tests/test_generate_weights.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+"""Test the generation of weights."""
+import pandas as pd
+import numpy as np
+from pathlib import Path
+from teehr.utilities.generate_weights import generate_weights_file
+from teehr.fetching.const import CONUS_NWM_WKT
+
+
+TEST_DIR = Path("tests", "data", "nwm30")
+TEMPLATE_FILEPATH = Path(TEST_DIR, "nwm_retro_v3_template_grid.nc")
+ZONES_FILEPATH = Path(TEST_DIR, "one_huc10_conus_1016000606.parquet")
+WEIGHTS_FILEPATH = Path(TEST_DIR, "one_huc10_1016000606_teehr_weights.parquet")
+
+
+def test_weights():
+    """Test the generation of weights."""
+    df = generate_weights_file(
+        zone_polygon_filepath=ZONES_FILEPATH,
+        template_dataset=TEMPLATE_FILEPATH,
+        variable_name="RAINRATE",
+        crs_wkt=CONUS_NWM_WKT,
+        output_weights_filepath=None,
+        unique_zone_id="id",
+    )
+
+    df_test = pd.read_parquet(WEIGHTS_FILEPATH).astype({"weight": np.float32})
+
+    assert df.equals(df_test)
+
+
+if __name__ == "__main__":
+    test_weights()
