Skip to content

Commit d213509

Browse files
authored
Merge pull request #289 from NREL/bnb/lr_features_ordering
- invert_uv=True as default for forward pass strategy. - fix ordering of lr_features in dual sampler.
2 parents de05597 + b0c5cc1 commit d213509

File tree

10 files changed

+1065
-408
lines changed

10 files changed

+1065
-408
lines changed

pixi.lock

Lines changed: 1017 additions & 369 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ tensorflow = {version = "~=2.15.0", channel = "conda-forge"}
5858

5959
[project.optional-dependencies]
6060
dev = [
61-
"build>=0.5",
6261
"flake8",
6362
"pre-commit",
6463
"pylint",
@@ -74,9 +73,9 @@ test = [
7473
"pytest-env"
7574
]
7675
build = [
77-
"build>=1.2.2,<2",
76+
"build>=0.6",
7877
"pkginfo>=1.10.0,<2",
79-
"twine>=6.1.0,<7",
78+
"twine>=5.0",
8079
]
8180

8281
[project.scripts]
@@ -298,7 +297,6 @@ numpy = "~=1.7"
298297
pandas = ">=2.0"
299298
scipy = ">=1.0.0"
300299
xarray = ">=2023.0"
301-
ipython = ">=9.7.0,<10"
302300

303301
[tool.pixi.pypi-dependencies]
304302
NREL-sup3r = { path = ".", editable = true }
@@ -309,14 +307,21 @@ NREL-farms = { version = ">=1.0.4" }
309307

310308
[tool.pixi.environments]
311309
default = { solve-group = "default" }
312-
dev = { features = ["dev", "doc", "test"], solve-group = "default" }
310+
dev = { features = ["dev", "doc", "test", "viz", "build"], solve-group = "default" }
313311
doc = { features = ["doc"], solve-group = "default" }
314312
test = { features = ["test"], solve-group = "default" }
315313
viz = { features = ["viz"], solve-group = "default" }
314+
build = { features = ["build"], solve-group = "default" }
316315

317-
[tool.pixi.tasks]
316+
[tool.pixi.feature.test.tasks]
318317
test = "pytest --pdb --durations=10 tests"
319318

319+
[tool.pixi.feature.build.tasks]
320+
clean-readme = { cmd = "python sup3r/utilities/_clean_readme.py README.rst" }
321+
build-wheels = { cmd = "uv build --sdist --wheel --out-dir dist/ .", depends-on = ["clean-readme"] }
322+
check-wheels = { cmd = ["twine", "check", "dist/*"], depends-on = ["build-wheels"] }
323+
upload-wheels = { cmd = ["twine", "upload", "dist/*"], depends-on = ["check-wheels"] }
324+
320325
[tool.pixi.feature.doc.dependencies]
321326
sphinx = ">=8.1.3,<9"
322327
sphinx-book-theme = ">=1.1.3,<2"
@@ -327,8 +332,6 @@ pytest = ">=5.2"
327332
pytest-cov = ">=5.0.0"
328333

329334
[tool.pixi.feature.dev.dependencies]
330-
build = ">=0.6"
331-
twine = ">=5.0"
332335
ruff = ">=0.4"
333336
ipython = ">=8.0"
334337
pytest-xdist = ">=3.0"
@@ -337,6 +340,10 @@ pytest-xdist = ">=3.0"
337340
jupyter = ">=1.0"
338341
hvplot = ">=0.10"
339342

343+
[tool.pixi.feature.build.dependencies]
344+
build = ">=0.6"
345+
twine = ">=5.0"
346+
340347
[tool.pytest_env]
341348
CUDA_VISIBLE_DEVICES = "-1"
342349
TF_ENABLE_ONEDNN_OPTS = "0"

sup3r/pipeline/forward_pass.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ def _run_parallel(cls, strategy, node_index):
535535
model_class=strategy.model_class,
536536
allowed_const=strategy.allowed_const,
537537
output_workers=strategy.output_workers,
538+
invert_uv=strategy.invert_uv,
538539
meta=fwp.meta,
539540
)
540541
futures[fut] = {
@@ -585,7 +586,7 @@ def run_chunk(
585586
model_kwargs,
586587
model_class,
587588
allowed_const,
588-
invert_uv=None,
589+
invert_uv=False,
589590
meta=None,
590591
nn_fill=True,
591592
output_workers=None,
@@ -613,8 +614,8 @@ def run_chunk(
613614
information on this argument.
614615
invert_uv : bool
615616
Whether to convert uv to windspeed and winddirection for writing
616-
output. This defaults to True for H5 output and False for NETCDF
617-
output.
617+
output. When this method is called during a pipeline forward pass
618+
the value is taken from the strategy.invert_uv attribute.
618619
nn_fill : bool
619620
Whether to fill data outside of limits with nearest neighbour or
620621
cap to limits.
@@ -637,7 +638,7 @@ def run_chunk(
637638
model = get_model(model_class, model_kwargs)
638639

639640
mask = np.isnan(chunk.input_data).any(axis=(0, 1, 2))
640-
feats = np.array(model.lr_features[:len(mask)])[mask]
641+
feats = np.array(model.lr_features[: len(mask)])[mask]
641642
if np.any(mask):
642643
msg = f'Input data for {feats} contains NaN values!'
643644
logger.error(msg)

sup3r/pipeline/strategy.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,9 @@ class ForwardPassStrategy:
171171
chunks and overwrite any pre-existing outputs (False).
172172
output_workers : int | None
173173
Max number of workers to use for writing forward pass output.
174-
invert_uv : bool | None
174+
invert_uv : bool
175175
Whether to convert u and v wind components to windspeed and direction
176-
for writing to output. This defaults to True for H5 output and False
177-
for NETCDF output.
176+
for writing to output. This defaults to True.
178177
nn_fill : bool
179178
Whether to fill data outside of accepted limits (e.g. relative
180179
humidity 0-100) with nearest neighbour or cap to limits.
@@ -221,7 +220,7 @@ class ForwardPassStrategy:
221220
allowed_const: Optional[Union[list, bool]] = None
222221
incremental: bool = True
223222
output_workers: int = 1
224-
invert_uv: Optional[bool] = None
223+
invert_uv: Optional[bool] = True
225224
nn_fill: bool = True
226225
pass_workers: int = 1
227226
max_nodes: int = 1
@@ -305,7 +304,7 @@ def meta(self):
305304
return meta_data
306305

307306
def get_time_slices(self):
308-
"""Get the time slice for initializaing the input handler and the
307+
"""Get the time slice for initializing the input handler and the
309308
time slice applied to the data given by the input handler to get the
310309
actual requested time period. These are different because we want the
311310
data stored by the input handler to have extra time steps at the start

sup3r/postprocessing/writers/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def _transform_output(
305305
data,
306306
features,
307307
lat_lon,
308-
invert_uv=None,
308+
invert_uv=False,
309309
nn_fill=False,
310310
max_workers=None,
311311
):
@@ -322,7 +322,7 @@ def _transform_output(
322322
Array of high res lat/lon for output data.
323323
(spatial_1, spatial_2, 2)
324324
Last dimension has ordering (lat, lon)
325-
invert_uv : bool | None
325+
invert_uv : bool
326326
Whether to convert u and v wind components to windspeed and
327327
direction
328328
nn_fill : bool
@@ -554,7 +554,7 @@ def _write_output(
554554
times,
555555
out_file,
556556
meta_data,
557-
invert_uv=True,
557+
invert_uv=False,
558558
nn_fill=False,
559559
max_workers=None,
560560
gids=None,
@@ -570,7 +570,7 @@ def write_output(
570570
low_res_times,
571571
out_file,
572572
meta_data=None,
573-
invert_uv=None,
573+
invert_uv=False,
574574
nn_fill=False,
575575
max_workers=None,
576576
gids=None,
@@ -593,7 +593,7 @@ def write_output(
593593
Output file path
594594
meta_data : dict | None
595595
Dictionary of meta data from model
596-
invert_uv : bool | None
596+
invert_uv : bool
597597
Whether to convert u and v wind components to windspeed and
598598
direction
599599
nn_fill : bool

sup3r/postprocessing/writers/h5.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _write_output(
2525
times,
2626
out_file,
2727
meta_data=None,
28-
invert_uv=None,
28+
invert_uv=False,
2929
nn_fill=False,
3030
max_workers=None,
3131
gids=None,
@@ -49,7 +49,7 @@ def _write_output(
4949
Output file path
5050
meta_data : dict | None
5151
Dictionary of meta data from model
52-
invert_uv : bool | None
52+
invert_uv : bool
5353
Whether to convert u and v wind components to windspeed and
5454
direction
5555
nn_fill : bool
@@ -71,7 +71,6 @@ def _write_output(
7171
f'({len(times)}) conflict.'
7272
)
7373
assert data.shape[-2] == len(times), msg
74-
invert_uv = True if invert_uv is None else invert_uv
7574
data, features = cls._transform_output(
7675
data.copy(),
7776
features,
@@ -85,13 +84,11 @@ def _write_output(
8584
if gids is not None
8685
else np.arange(np.prod(lat_lon.shape[:-1]))
8786
)
88-
meta = pd.DataFrame(
89-
{
90-
'gid': gids.flatten(),
91-
'latitude': lat_lon[..., 0].flatten(),
92-
'longitude': lat_lon[..., 1].flatten(),
93-
}
94-
)
87+
meta = pd.DataFrame({
88+
'gid': gids.flatten(),
89+
'latitude': lat_lon[..., 0].flatten(),
90+
'longitude': lat_lon[..., 1].flatten(),
91+
})
9592
data_list = []
9693
for i, _ in enumerate(features):
9794
flat_data = data[..., i].reshape((-1, len(times)))

sup3r/postprocessing/writers/nc.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Output handling"""
22

3+
import datetime
34
import logging
45
from datetime import datetime as dt
56

@@ -27,7 +28,7 @@ def _write_output(
2728
out_file,
2829
meta_data=None,
2930
max_workers=None,
30-
invert_uv=None,
31+
invert_uv=False,
3132
nn_fill=False,
3233
gids=None,
3334
):
@@ -52,7 +53,7 @@ def _write_output(
5253
Dictionary of meta data from model
5354
max_workers : int | None
5455
Max workers to use for inverse transform.
55-
invert_uv : bool | None
56+
invert_uv : bool
5657
Whether to convert u and v wind components to windspeed and
5758
direction
5859
nn_fill : bool
@@ -62,8 +63,6 @@ def _write_output(
6263
List of coordinate indices used to label each lat lon pair and to
6364
help with spatial chunk data collection
6465
"""
65-
66-
invert_uv = False if invert_uv is None else invert_uv
6766
data, features = cls._transform_output(
6867
data=data,
6968
features=features,
@@ -88,7 +87,7 @@ def _write_output(
8887
)
8988

9089
attrs = meta_data or {}
91-
now = dt.utcnow().isoformat()
90+
now = dt.now(datetime.timezone.utc).isoformat()
9291
attrs['date_modified'] = now
9392
attrs['date_created'] = attrs.get('date_created', now)
9493

sup3r/preprocessing/samplers/dual.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,9 @@ def __init__(
8383
self._lr_only_features = feature_sets.get('lr_only_features', [])
8484
self._hr_exo_features = feature_sets.get('hr_exo_features', [])
8585
self.features = self.get_features(feature_sets)
86-
lr_feats = self.lr_data.features
87-
self.lr_features = [f for f in lr_feats if f in self.features]
86+
self.lr_features = [
87+
f for f in self.features if f in self.lr_data.features
88+
]
8889
self.s_enhance = s_enhance
8990
self.t_enhance = t_enhance
9091
self.check_for_consistent_shapes()

tests/forward_pass/test_forward_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def test_fwp_nc_cc():
7373
out_pattern=out_files,
7474
input_handler_name='DataHandlerNCforCC',
7575
pass_workers=None,
76+
invert_uv=False
7677
)
7778
forward_pass = ForwardPass(strat)
7879
forward_pass.run(strat, node_index=0)
@@ -172,6 +173,7 @@ def test_fwp_spatial_only(input_files):
172173
out_pattern=out_files,
173174
pass_workers=1,
174175
output_workers=1,
176+
invert_uv=False
175177
)
176178
forward_pass = ForwardPass(strat)
177179
assert strat.output_workers == 1
@@ -225,6 +227,7 @@ def test_fwp_nc(input_files):
225227
},
226228
out_pattern=out_files,
227229
pass_workers=1,
230+
invert_uv=False
228231
)
229232
forward_pass = ForwardPass(strat)
230233
assert forward_pass.strategy.pass_workers == 1
@@ -614,6 +617,7 @@ def test_fwp_multi_step_model(input_files):
614617
model_class=strat.model_class,
615618
allowed_const=strat.allowed_const,
616619
output_workers=strat.output_workers,
620+
invert_uv=strat.invert_uv,
617621
meta=fwp.meta,
618622
)
619623

tests/output/test_qa.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_qa(input_files, ext):
7979
input_handler_kwargs=input_handler_kwargs.copy(),
8080
out_pattern=out_files,
8181
max_nodes=1,
82+
invert_uv=False
8283
)
8384

8485
forward_pass = ForwardPass(strategy)

0 commit comments

Comments
 (0)