Skip to content

Commit 4b09d36

Browse files
committed
Improve type annotations
- `utilFcns.write_weather_vars_to_ds`: annotate number of dimensions - types: Add generic version of the numpy array helper types so you can annotate an array of booleans as `Array1D[np.bool]` - Add type annotations to `models.hrrr.get_bounds_indices`
1 parent 6a42272 commit 4b09d36

File tree

8 files changed

+131
-114
lines changed

8 files changed

+131
-114
lines changed

test/test_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
transform_bbox,
3434
unproject,
3535
writeArrayToRaster,
36-
writeWeatherVarsXarray,
36+
write_weather_vars_to_ds,
3737
)
3838
from test import TEST_DIR, pushd
3939

@@ -917,7 +917,7 @@ def test_UTM_to_WGS84_empty_input():
917917

918918

919919
# Test writeWeatherVarsXarray
920-
def test_writeWeatherVarsXarray(tmp_path):
920+
def test_write_weather_vars_to_ds(tmp_path):
921921
"""Test writing weather variables to an xarray dataset and NetCDF file."""
922922
# Mock inputs
923923
lat = np.random.rand(91, 144) * 180 - 90 # Random latitudes between -90 and 90
@@ -938,7 +938,7 @@ def test_writeWeatherVarsXarray(tmp_path):
938938
out_path = tmp_path / "test_output.nc"
939939

940940
# Call the function
941-
writeWeatherVarsXarray(lat, lon, h, q, p, t, datetime_value, crs, out_path)
941+
write_weather_vars_to_ds(lat, lon, h, q, p, t, datetime_value, crs, out_path)
942942

943943
# Check that the file was created
944944
assert out_path.exists()

test/test_weather_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def test_ztd(model: MockWeatherModel) -> None:
403403

404404
def test_get_bounds_indices() -> None:
405405
"""Test bounds indices."""
406-
snwe = [-10, 10, -10, 10]
406+
snwe = (-10, 10, -10, 10)
407407
ll = np.arange(-20, 20)
408408
lats, lons = np.meshgrid(ll, ll)
409409
xmin, xmax, ymin, ymax = get_bounds_indices(snwe, lats, lons)
@@ -415,7 +415,7 @@ def test_get_bounds_indices() -> None:
415415

416416
def test_get_bounds_indices_2() -> None:
417417
"""Test bounds indices."""
418-
snwe = [-10, 10, 170, -170]
418+
snwe = (-10, 10, 170, -170)
419419
l = np.arange(-20, 20)
420420
l2 = (((np.arange(160, 200) + 180) % 360) - 180)
421421
lats, lons = np.meshgrid(l, l2)
@@ -425,7 +425,7 @@ def test_get_bounds_indices_2() -> None:
425425

426426
def test_get_bounds_indices_2b() -> None:
427427
"""Test bounds indices."""
428-
snwe = [-10, 10, 170, 190]
428+
snwe = (-10, 10, 170, 190)
429429
l = np.arange(-20, 20)
430430
l2 = np.arange(160, 200)
431431
lats, lons = np.meshgrid(l, l2)
@@ -438,7 +438,7 @@ def test_get_bounds_indices_2b() -> None:
438438

439439
def test_get_bounds_indices_3() -> None:
440440
"""Test bounds indices"""
441-
snwe = [-10, 10, -10, 10]
441+
snwe = (-10, 10, -10, 10)
442442
l = np.arange(-20, 20)
443443
l2 = (((np.arange(160, 200) + 180) % 360) - 180)
444444
lats, lons = np.meshgrid(l, l2)
@@ -448,7 +448,7 @@ def test_get_bounds_indices_3() -> None:
448448

449449
def test_get_bounds_indices_4() -> None:
450450
"""Test bounds_indices."""
451-
snwe = [55, 60, 175, 185]
451+
snwe = (55, 60, 175, 185)
452452
l = np.arange(55, 60, 1)
453453
l2 = np.arange(175, 185, 1)
454454
lats, lons = np.meshgrid(l, l2)

tools/RAiDER/models/gmao.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
LEVELS_137_HEIGHTS,
1313
)
1414
from RAiDER.models.weatherModel import TIME_RES, WeatherModel
15-
from RAiDER.utilFcns import requests_retry_session, round_date, writeWeatherVarsXarray
15+
from RAiDER.utilFcns import requests_retry_session, round_date, write_weather_vars_to_ds
1616

1717

1818
class GMAO(WeatherModel):
@@ -142,7 +142,7 @@ def _fetch(self, out: Path) -> None:
142142

143143
try:
144144
# Note that lat/lon gets written twice for GMAO because they are the same as y/x
145-
writeWeatherVarsXarray(lat, lon, h, q, p, t, self._time, self._proj, out)
145+
write_weather_vars_to_ds(lat, lon, h, q, p, t, self._time, self._proj, out)
146146
except:
147147
logger.exception('Unable to save weathermodel to file:')
148148
raise

tools/RAiDER/models/hrrr.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import datetime as dt
2-
import os
32
from pathlib import Path
3+
from typing import List, Tuple, Union, cast
44

5+
from RAiDER.types import BB, Array2D, FloatArray2D
56
import geopandas as gpd
7+
import herbie
8+
import herbie.accessors
69
import numpy as np
710
import xarray as xr
8-
from herbie import Herbie
911
from pyproj import CRS, Transformer
1012
from shapely.geometry import Polygon, box
11-
from typing import Optional, Union, List, Tuple
1213

1314
from RAiDER.logger import logger
1415
from RAiDER.models.customExceptions import NoWeatherModelData
@@ -29,13 +30,13 @@
2930

3031
def check_hrrr_dataset_availability(datetime: dt.datetime, model='hrrr') -> bool:
3132
"""Note a file could still be missing within the models valid range."""
32-
herbie = Herbie(
33+
h = herbie.Herbie(
3334
datetime,
3435
model=model,
3536
product='nat',
3637
fxx=0,
3738
)
38-
return herbie.grib_source is not None
39+
return h.grib_source is not None
3940

4041

4142
def download_hrrr_file(ll_bounds, DATE, out: Path, model='hrrr', product='nat', fxx=0, verbose=False) -> None:
@@ -54,7 +55,7 @@ def download_hrrr_file(ll_bounds, DATE, out: Path, model='hrrr', product='nat',
5455
Returns:
5556
None, writes data to a netcdf file
5657
"""
57-
herbie = Herbie(
58+
h = herbie.Herbie(
5859
DATE.strftime('%Y-%m-%d %H:%M'),
5960
model=model,
6061
product=product,
@@ -66,22 +67,27 @@ def download_hrrr_file(ll_bounds, DATE, out: Path, model='hrrr', product='nat',
6667

6768
# Iterate through the list of datasets
6869
try:
69-
ds_list = herbie.xarray(':(SPFH|PRES|TMP|HGT):', verbose=verbose)
70+
# cast: Herbie.xarray can return one Dataset or a list
71+
ds_list = cast(
72+
Union[xr.Dataset, List[xr.Dataset]],
73+
h.xarray(':(SPFH|PRES|TMP|HGT):', verbose=verbose),
74+
)
75+
assert isinstance(ds_list, list)
7076
except ValueError as e:
7177
logger.error(e)
7278
raise
7379

7480
# Note order coord names are request for `test_HRRR_ztd` matters
7581
# when both coord names are retreived by Herbie is ds_list possibly in
7682
# Different orders on different machines; `hybrid` is what is expected for the test.
77-
ds_list_filt_0 = [ds for ds in ds_list if 'hybrid' in ds._coord_names]
78-
ds_list_filt_1 = [ds for ds in ds_list if 'isobaricInhPa' in ds._coord_names]
79-
if ds_list_filt_0:
83+
ds_list_filt_0 = [ds for ds in ds_list if 'hybrid' in ds.coords]
84+
ds_list_filt_1 = [ds for ds in ds_list if 'isobaricInhPa' in ds.coords]
85+
if len(ds_list_filt_0) > 0:
8086
ds_out = ds_list_filt_0[0]
8187
coord = 'hybrid'
8288
# I do not think that this coord name will result in successful processing nominally as variables are
8389
# gh,gribfile_projection for test_HRRR_ztd
84-
elif ds_list_filt_1:
90+
elif len(ds_list_filt_1) > 0:
8591
ds_out = ds_list_filt_1[0]
8692
coord = 'isobaricInhPa'
8793
else:
@@ -91,22 +97,25 @@ def download_hrrr_file(ll_bounds, DATE, out: Path, model='hrrr', product='nat',
9197
try:
9298
x_min, x_max, y_min, y_max = get_bounds_indices(
9399
ll_bounds,
94-
ds_out.latitude.to_numpy(),
95-
ds_out.longitude.to_numpy(),
100+
ds_out['latitude'].to_numpy(),
101+
ds_out['longitude'].to_numpy(),
96102
)
97103
except NoWeatherModelData as e:
98104
logger.error(e)
99-
logger.error('lat/lon bounds: %s', ll_bounds)
100-
logger.error('Weather model is {}'.format(model))
105+
logger.error(f'lat/lon bounds: {ll_bounds}')
106+
logger.error(f'Weather model is {model}')
101107
raise
102108

103109
# bookkeepping
104110
ds_out = ds_out.rename({'gh': 'z', coord: 'levels'})
105111

112+
# Herbie injects brainworms into all its xarray Datasets
113+
crs = cast(herbie.accessors.HerbieAccessor, ds_out.herbie).crs
114+
106115
# projection information
107116
ds_out['proj'] = 0
108-
for k, v in CRS.from_user_input(ds_out.herbie.crs).to_cf().items():
109-
ds_out.proj.attrs[k] = v
117+
for k, v in CRS.from_user_input(crs).to_cf().items():
118+
ds_out['proj'].attrs[k] = v
110119
for var in ds_out.data_vars:
111120
ds_out[var].attrs['grid_mapping'] = 'proj'
112121

@@ -134,40 +143,37 @@ def download_hrrr_file(ll_bounds, DATE, out: Path, model='hrrr', product='nat',
134143
ds_sub.to_netcdf(out)
135144

136145

137-
def get_bounds_indices(SNWE, lats, lons):
146+
def get_bounds_indices(snwe: BB.SNWE, lats: FloatArray2D, lons: FloatArray2D) -> BB.WSEN:
138147
"""Convert SNWE lat/lon bounds to index bounds."""
139148
# Unpack the bounds and find the relevent indices
140-
S, N, W, E = SNWE
149+
s, n, w, e = snwe
141150

142151
# Need to account for crossing the international date line
143-
if W < E:
144-
m1 = (S <= lats) & (N >= lats) & (W <= lons) & (E >= lons)
145-
else:
152+
if w >= e:
146153
raise ValueError(
147-
'Longitude is either flipped or you are crossing the international date line;'
148-
+ 'if the latter please give me longitudes from 0-360'
154+
'Longitude is either flipped or you are crossing the international date line; '
155+
'if the latter please give me longitudes from 0-360'
149156
)
150157

151-
if np.sum(m1) == 0:
158+
m1: Array2D[np.bool] = (s <= lats) & (n >= lats) & (w <= lons) & (e >= lons)
159+
if np.sum(m1) == 0: # All False
152160
lons = np.mod(lons, 360)
153-
W, E = np.mod([W, E], 360)
154-
m1 = (S <= lats) & (N >= lats) & (W <= lons) & (E >= lons)
155-
if np.sum(m1) == 0:
161+
w, e = np.mod([w, e], 360)
162+
# try again
163+
m1 = (s <= lats) & (n >= lats) & (w <= lons) & (e >= lons)
164+
if np.sum(m1) == 0: # All False
156165
raise NoWeatherModelData('Area of Interest has no overlap with the HRRR model available extent')
157166

158167
# Y extent
159-
shp = lats.shape
160-
m1_y = np.argwhere(np.sum(m1, axis=1) != 0)
161-
y_min = max(m1_y[0][0], 0)
162-
y_max = min(m1_y[-1][0], shp[0])
163-
m1_y = None
168+
m1_y: Array2D[np.integer] = np.argwhere(np.sum(m1, axis=1) != 0)
169+
y_min: int = max(m1_y[0][0], 0)
170+
y_max: int = min(m1_y[-1][0], lats.shape[0])
171+
del m1_y # big
164172

165173
# X extent
166-
m1_x = np.argwhere(np.sum(m1, axis=0) != 0)
167-
x_min = max(m1_x[0][0], 0)
168-
x_max = min(m1_x[-1][0], shp[1])
169-
m1_x = None
170-
m1 = None
174+
m1_x: Array2D[np.integer] = np.argwhere(np.sum(m1, axis=0) != 0)
175+
x_min: int = max(m1_x[0][0], 0)
176+
x_max: int = min(m1_x[-1][0], lats.shape[1])
171177

172178
return x_min, x_max, y_min, y_max
173179

tools/RAiDER/models/merra2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
LEVELS_137_HEIGHTS,
1212
)
1313
from RAiDER.models.weatherModel import WeatherModel
14-
from RAiDER.utilFcns import writeWeatherVarsXarray
14+
from RAiDER.utilFcns import write_weather_vars_to_ds
1515

1616

1717
# Path to Netrc file, can be controlled by env var
@@ -137,7 +137,7 @@ def _fetch(self, out: Path) -> None:
137137
].data.squeeze()
138138

139139
try:
140-
writeWeatherVarsXarray(lat, lon, h, q, p, t, time, self._proj, out_path=out)
140+
write_weather_vars_to_ds(lat, lon, h, q, p, t, time, self._proj, out_path=out)
141141
except Exception as e:
142142
logger.debug(e)
143143
logger.exception('MERRA-2: Unable to save weather model query to file')

tools/RAiDER/models/ncmr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from RAiDER.utilFcns import (
1919
read_NCMR_loginInfo,
2020
show_progress,
21-
writeWeatherVarsXarray,
21+
write_weather_vars_to_ds,
2222
)
2323

2424

@@ -193,7 +193,7 @@ def _download_ncmr_file(self, out: Path, date_time, bounding_box) -> None:
193193
########################################################################################################################
194194

195195
try:
196-
writeWeatherVarsXarray(lats, lons, hgt, q, p, t, self._time, self._proj, out_path=out)
196+
write_weather_vars_to_ds(lats, lons, hgt, q, p, t, self._time, self._proj, out_path=out)
197197
except:
198198
logger.exception('Unable to save weathermodel to file')
199199

tools/RAiDER/types/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,19 @@
1717
# Helpers for type annotating the dimensions of a numpy array.
1818
# When you use these, your type checker will alert you when an annotated array
1919
# is used in a way inconsistent with its dimensions.
20-
Array1D = np.ndarray[tuple[int], np.dtype[_ScalarT_co]]
21-
Array2D = np.ndarray[tuple[int, int], np.dtype[_ScalarT_co]]
22-
Array3D = np.ndarray[tuple[int, int, int], np.dtype[_ScalarT_co]]
20+
Array1D = np.ndarray[tuple[int], np.dtype[np._ScalarT_co]]
21+
Array2D = np.ndarray[tuple[int, int], np.dtype[np._ScalarT_co]]
22+
Array3D = np.ndarray[tuple[int, int, int], np.dtype[np._ScalarT_co]]
2323
# ... (repeat the pattern as needed for higher dimensions)
2424

2525
# Any number of dimensions -- when ndim is not able to be known ahead of time
26-
ArrayND = np.ndarray[tuple[int, ...], np.dtype[_ScalarT_co]]
26+
ArrayND = np.ndarray[tuple[int, ...], np.dtype[np._ScalarT_co]]
2727

2828

2929
FloatArray1D = Array1D[np.floating]
3030
FloatArray2D = Array2D[np.floating]
3131
FloatArray3D = Array3D[np.floating]
32+
# ... (repeat the pattern as needed for higher dimensions)
33+
34+
# Any number of dimensions -- when ndim is not able to be known ahead of time
3235
FloatArrayND = ArrayND[np.floating]

0 commit comments

Comments
 (0)