Skip to content

Commit 2e35624

Browse files
committed
fix steps blending tests
1 parent 8184755 commit 2e35624

File tree

8 files changed

+307
-223
lines changed

8 files changed

+307
-223
lines changed

pysteps/blending/steps.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@
4242
calculate_weights_spn
4343
blend_means_sigmas
4444
"""
45+
from datetime import datetime
4546
import math
4647
import time
4748
from copy import copy, deepcopy
4849
from functools import partial
4950
from multiprocessing.pool import ThreadPool
5051

5152
import numpy as np
53+
import xarray as xr
5254
from scipy.linalg import inv
5355
from scipy.ndimage import binary_dilation, generate_binary_structure, iterate_structure
5456

@@ -57,6 +59,7 @@
5759
from pysteps.postprocessing import probmatching
5860
from pysteps.timeseries import autoregression, correlation
5961
from pysteps.utils.check_norain import check_norain
62+
from pysteps.xarray_helpers import convert_output_to_xarray_dataset
6063

6164
try:
6265
import dask
@@ -412,20 +415,76 @@ class StepsBlendingState:
412415
class StepsBlendingNowcaster:
413416
def __init__(
414417
self,
415-
precip,
416-
precip_models,
417-
velocity,
418-
velocity_models,
418+
radar_dataset: xr.Dataset,
419+
model_dataset: xr.Dataset,
419420
time_steps,
420-
issue_time,
421+
issue_time: datetime,
421422
steps_blending_config: StepsBlendingConfig,
422423
):
423424
"""Initializes the StepsBlendingNowcaster with inputs and configurations."""
424425
# Store inputs
425-
self.__precip = precip
426-
self.__precip_models = precip_models
427-
self.__velocity = velocity
428-
self.__velocity_models = velocity_models
426+
radar_precip_var = radar_dataset.attrs["precip_var"]
427+
model_precip_var = model_dataset.attrs["precip_var"]
428+
if issue_time != radar_dataset.time.isel(time=-1).values.astype(
429+
"datetime64[us]"
430+
).astype(datetime):
431+
raise ValueError(
432+
"Issue time should be equal to last timestep in radar dataset"
433+
)
434+
time_stepsize_seconds = radar_dataset.time.attrs["stepsize"]
435+
if isinstance(time_steps, list):
436+
# XR: validates this works or just remove the subtimestep stuff
437+
timesteps_seconds = (
438+
np.array(list(range(time_steps[-1] + 1))) * time_stepsize_seconds
439+
)
440+
else:
441+
timesteps_seconds = (
442+
np.array(list(range(time_steps + 1))) * time_stepsize_seconds
443+
)
444+
model_times = radar_dataset.time.isel(
445+
time=-1
446+
).values + timesteps_seconds.astype("timedelta64[s]")
447+
model_dataset = model_dataset.sel(time=model_times)
448+
449+
self.__precip = radar_dataset[radar_precip_var].values
450+
# XR: don't extract to dict but pass dataset
451+
if model_dataset[model_precip_var].ndim == 5:
452+
self.__precip_models = np.array(
453+
[
454+
[
455+
{
456+
"cascade_levels": model_dataset[model_precip_var]
457+
.sel(time=time, ens_number=ens_number)
458+
.values,
459+
"means": model_dataset["means"]
460+
.sel(time=time, ens_number=ens_number)
461+
.values,
462+
"stds": model_dataset["stds"]
463+
.sel(time=time, ens_number=ens_number)
464+
.values,
465+
"domain": model_dataset[model_precip_var].attrs["domain"],
466+
"normalized": model_dataset[model_precip_var].attrs[
467+
"normalized"
468+
],
469+
"compact_output": model_dataset[model_precip_var].attrs[
470+
"compact_output"
471+
],
472+
}
473+
for time in model_dataset.time
474+
]
475+
for ens_number in model_dataset.ens_number
476+
]
477+
)
478+
else:
479+
self.__precip_models = model_dataset[model_precip_var].values
480+
self.__velocity = np.array(
481+
[radar_dataset["velocity_x"].values, radar_dataset["velocity_y"].values]
482+
)
483+
self.__velocity_models = np.array(
484+
[model_dataset["velocity_x"].values, model_dataset["velocity_y"].values]
485+
).transpose(1, 2, 0, 3, 4)
486+
self.__original_timesteps = time_steps
487+
self.__input_radar_dataset = radar_dataset
429488
self.__timesteps = time_steps
430489
self.__issuetime = issue_time
431490

@@ -447,6 +506,7 @@ def compute_forecast(self):
447506
448507
Parameters
449508
----------
509+
# XR: fix docstring
450510
precip: array-like
451511
Array of shape (ar_order+1,m,n) containing the input precipitation fields
452512
ordered by timestamp from oldest to newest. The time steps between the
@@ -545,7 +605,7 @@ def compute_forecast(self):
545605

546606
# Determine if rain is present in both radar and NWP fields
547607
if self.__params.zero_precip_radar and self.__params.zero_precip_model_fields:
548-
return self.__zero_precipitation_forecast()
608+
result = self.__zero_precipitation_forecast()
549609
else:
550610
# Prepare the data for the zero precipitation radar case and initialize the noise correctly
551611
if self.__params.zero_precip_radar:
@@ -572,16 +632,20 @@ def compute_forecast(self):
572632
for j in range(self.__config.n_ens_members)
573633
]
574634
)
575-
if self.__config.measure_time:
576-
return (
577-
self.__state.final_blended_forecast,
578-
self.__init_time,
579-
self.__mainloop_time,
580-
)
581-
else:
582-
return self.__state.final_blended_forecast
635+
result = self.__state.final_blended_forecast
583636
else:
584637
return None
638+
result_dataset = convert_output_to_xarray_dataset(
639+
self.__input_radar_dataset, self.__original_timesteps, result
640+
)
641+
if self.__config.measure_time:
642+
return (
643+
result_dataset,
644+
self.__init_time,
645+
self.__mainloop_time,
646+
)
647+
else:
648+
return result_dataset
585649

586650
def __blended_nowcast_main_loop(self):
587651
"""
@@ -2817,10 +2881,8 @@ def __measure_time(self, label, start_time):
28172881

28182882

28192883
def forecast(
2820-
precip,
2821-
precip_models,
2822-
velocity,
2823-
velocity_models,
2884+
radar_dataset,
2885+
model_dataset,
28242886
timesteps,
28252887
timestep,
28262888
issuetime,
@@ -2864,6 +2926,7 @@ def forecast(
28642926
28652927
Parameters
28662928
----------
2929+
# XR: fix docstring
28672930
precip: array-like
28682931
Array of shape (ar_order+1,m,n) containing the input precipitation fields
28692932
ordered by timestamp from oldest to newest. The time steps between the
@@ -3182,13 +3245,7 @@ def forecast(
31823245
"""
31833246
# Create an instance of the new class with all the provided arguments
31843247
blended_nowcaster = StepsBlendingNowcaster(
3185-
precip,
3186-
precip_models,
3187-
velocity,
3188-
velocity_models,
3189-
timesteps,
3190-
issuetime,
3191-
blending_config,
3248+
radar_dataset, model_dataset, timesteps, issuetime, blending_config
31923249
)
31933250

31943251
forecast_steps_nowcast = blended_nowcaster.compute_forecast()

0 commit comments

Comments
 (0)