Skip to content

Commit d7f84fb

Browse files
committed
fix blend utils and xarray dataformat
1 parent 2330b01 commit d7f84fb

24 files changed

+310
-521
lines changed

pysteps/blending/utils.py

Lines changed: 93 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
"""
2020

2121
import datetime
22+
from typing import Any
2223
import warnings
2324
from pathlib import Path
2425

2526
import numpy as np
27+
import xarray as xr
2628

2729
from pysteps.cascade import get_method as cascade_get_method
2830
from pysteps.cascade.bandpass_filters import filter_gaussian
@@ -241,12 +243,7 @@ def blend_optical_flows(flows, weights):
241243

242244

243245
def decompose_NWP(
244-
R_NWP,
245-
NWP_model,
246-
analysis_time,
247-
timestep,
248-
valid_times,
249-
output_path,
246+
precip_nwp_dataset: xr.Dataset,
250247
num_cascade_levels=8,
251248
num_workers=1,
252249
decomp_method="fft",
@@ -255,7 +252,7 @@ def decompose_NWP(
255252
normalize=True,
256253
compute_stats=True,
257254
compact_output=True,
258-
):
255+
) -> xr.Dataset:
259256
"""Decomposes the NWP forecast data into cascades and saves it in
260257
a netCDF file
261258
@@ -269,11 +266,6 @@ def decompose_NWP(
269266
analysis_time: numpy.datetime64
270267
The analysis time of the NWP forecast. The analysis time is assumed to be a
271268
numpy.datetime64 type as imported by the pysteps importer
272-
timestep: int
273-
Timestep in minutes between subsequent NWP forecast fields
274-
valid_times: array_like
275-
Array containing the valid times of the NWP forecast fields. The times are
276-
assumed to be numpy.datetime64 types as imported by the pysteps importer.
277269
output_path: str
278270
The location where to save the file with the NWP cascade. Defaults to the
279271
path_workdir specified in the rcparams file.
@@ -315,62 +307,41 @@ def decompose_NWP(
315307
316308
Returns
317309
-------
318-
None
310+
xarray.Dataset
311+
The same dataset as was passed in but with the precip data replaced
312+
with decomposed precip data and means and stds added
319313
"""
320314

321-
if not NETCDF4_IMPORTED:
322-
raise MissingOptionalDependency(
323-
"netCDF4 package is required to save the decomposed NWP data, "
324-
"but it is not installed"
325-
)
326-
327-
# Make a NetCDF file
328-
output_date = f"{analysis_time.astype('datetime64[us]').astype(datetime.datetime):%Y%m%d%H%M%S}"
329-
outfn = Path(output_path) / f"cascade_{NWP_model}_{output_date}.nc"
330-
ncf = netCDF4.Dataset(outfn, "w", format="NETCDF4")
331-
332-
# Express times relative to the zero time
333-
zero_time = np.datetime64("1970-01-01T00:00:00", "ns")
334-
valid_times = np.array(valid_times) - zero_time
335-
analysis_time = analysis_time - zero_time
336-
337-
# Set attributes of decomposition method
338-
ncf.domain = domain
339-
ncf.normalized = int(normalize)
340-
ncf.compact_output = int(compact_output)
341-
ncf.analysis_time = int(analysis_time)
342-
ncf.timestep = int(timestep)
343-
344-
# Create dimensions
345-
ncf.createDimension("time", R_NWP.shape[0])
346-
ncf.createDimension("cascade_levels", num_cascade_levels)
347-
ncf.createDimension("x", R_NWP.shape[2])
348-
ncf.createDimension("y", R_NWP.shape[1])
349-
350-
# Create variables (decomposed cascade, means and standard deviations)
351-
R_d = ncf.createVariable(
352-
"pr_decomposed",
353-
np.float32,
354-
("time", "cascade_levels", "y", "x"),
355-
zlib=True,
356-
complevel=4,
357-
)
358-
means = ncf.createVariable("means", np.float64, ("time", "cascade_levels"))
359-
stds = ncf.createVariable("stds", np.float64, ("time", "cascade_levels"))
360-
v_times = ncf.createVariable("valid_times", np.float64, ("time",))
361-
v_times.units = "nanoseconds since 1970-01-01 00:00:00"
362-
363-
# The valid times are saved as an array of floats, because netCDF files can't handle datetime types
364-
v_times[:] = np.array([np.float64(valid_times[i]) for i in range(len(valid_times))])
365-
315+
nwp_precip_var = precip_nwp_dataset.attrs["precip_var"]
316+
precip_nwp = precip_nwp_dataset[nwp_precip_var].values
366317
# Decompose the NWP data
367-
filter_g = filter_gaussian(R_NWP.shape[1:], num_cascade_levels)
368-
fft = utils_get_method(fft_method, shape=R_NWP.shape[1:], n_threads=num_workers)
318+
filter_g = filter_gaussian(precip_nwp.shape[1:], num_cascade_levels)
319+
fft = utils_get_method(
320+
fft_method, shape=precip_nwp.shape[1:], n_threads=num_workers
321+
)
369322
decomp_method, _ = cascade_get_method(decomp_method)
370323

371-
for i in range(R_NWP.shape[0]):
372-
R_ = decomp_method(
373-
field=R_NWP[i, :, :],
324+
pr_decomposed = np.zeros(
325+
(
326+
precip_nwp.shape[0],
327+
num_cascade_levels,
328+
precip_nwp.shape[1],
329+
precip_nwp.shape[2],
330+
),
331+
dtype=np.float32,
332+
)
333+
means = np.zeros(
334+
(precip_nwp.shape[0], num_cascade_levels),
335+
dtype=np.float64,
336+
)
337+
stds = np.zeros(
338+
(precip_nwp.shape[0], num_cascade_levels),
339+
dtype=np.float64,
340+
)
341+
342+
for i in range(precip_nwp.shape[0]):
343+
decomposed_precip_nwp = decomp_method(
344+
field=precip_nwp[i, :, :],
374345
bp_filter=filter_g,
375346
fft_method=fft,
376347
input_domain=domain,
@@ -380,157 +351,96 @@ def decompose_NWP(
380351
compact_output=compact_output,
381352
)
382353

383-
# Save data to netCDF file
384-
# print(R_["cascade_levels"])
385-
R_d[i, :, :, :] = R_["cascade_levels"]
386-
means[i, :] = R_["means"]
387-
stds[i, :] = R_["stds"]
354+
pr_decomposed[i, :, :, :] = decomposed_precip_nwp["cascade_levels"]
355+
means[i, :] = decomposed_precip_nwp["means"]
356+
stds[i, :] = decomposed_precip_nwp["stds"]
388357

389-
# Close the file
390-
ncf.close()
391-
392-
393-
def compute_store_nwp_motion(
394-
precip_nwp,
395-
oflow_method,
396-
analysis_time,
397-
nwp_model,
398-
output_path,
358+
precip_nwp_dataset = precip_nwp_dataset.assign_coords(
359+
cascade_level=(
360+
"cascade_level",
361+
np.arange(num_cascade_levels),
362+
{"long_name": "cascade level", "units": ""},
363+
)
364+
)
365+
precip_nwp_dataset = precip_nwp_dataset.drop_vars(nwp_precip_var)
366+
precip_nwp_dataset[nwp_precip_var] = (
367+
["time", "cascade_level", "y", "x"],
368+
pr_decomposed,
369+
)
370+
precip_nwp_dataset["means"] = (["time", "cascade_level"], means)
371+
precip_nwp_dataset["stds"] = (["time", "cascade_level"], stds)
372+
return precip_nwp_dataset
373+
374+
375+
def preprocess_and_store_nwp_data(
376+
precip_nwp_dataset: xr.Dataset,
377+
oflow_method: str,
378+
nwp_model: str,
379+
output_path: str | None,
380+
decompose_nwp: bool,
381+
decompose_kwargs: dict[str, Any] = {},
399382
):
400383
"""Computes, per forecast lead time, the velocity field of an NWP model field.
401384
402385
Parameters
403386
----------
404-
precip_nwp: array-like
405-
Array of dimension (n_timesteps, x, y) containing the precipitation forecast
387+
precip_nwp_dataset: xarray.Dataset
388+
xarray Dataset containing the precipitation forecast
406389
from some NWP model.
407390
oflow_method: {'constant', 'darts', 'lucaskanade', 'proesmans', 'vet'}, optional
408391
An optical flow method from pysteps.motion.get_method.
409-
analysis_time: numpy.datetime64
410-
The analysis time of the NWP forecast. The analysis time is assumed to be a
411-
numpy.datetime64 type as imported by the pysteps importer.
412392
nwp_model: str
413393
The name of the NWP model.
414394
output_path: str, optional
415-
The location where to save the file with the NWP velocity fields. Defaults
395+
The location where to save the netcdf file with the NWP velocity fields. Defaults
416396
to the path_workdir specified in the rcparams file.
397+
decompose_nwp: bool
398+
Defines wether or not the NWP needs to be decomposed before storing. This can
399+
be beneficial for performance, because then the decomposition does not need
400+
to happen during the blending anymore. It can however also be detrimental because
401+
this increases the amount of storage and RAM required for the blending.
402+
decompose_kwargs: dict
403+
Keyword arguments passed to the decompose_NWP method.
417404
418405
Returns
419406
-------
420407
Nothing
421408
"""
422409

410+
if not NETCDF4_IMPORTED:
411+
raise MissingOptionalDependency(
412+
"netCDF4 package is required to save the NWP data, "
413+
"but it is not installed"
414+
)
415+
423416
# Set the output file
417+
analysis_time = precip_nwp_dataset.time.values[0]
424418
output_date = f"{analysis_time.astype('datetime64[us]').astype(datetime.datetime):%Y%m%d%H%M%S}"
425-
outfn = Path(output_path) / f"motion_{nwp_model}_{output_date}.npy"
419+
outfn = Path(output_path) / f"preprocessed_{nwp_model}_{output_date}.nc"
420+
nwp_precip_var = precip_nwp_dataset.attrs["precip_var"]
421+
precip_nwp = precip_nwp_dataset[nwp_precip_var].values
426422

427423
# Get the velocity field per time step
428-
v_nwp = np.zeros((precip_nwp.shape[0], 2, precip_nwp.shape[1], precip_nwp.shape[2]))
424+
v_nwp_x = np.zeros((precip_nwp.shape[0], precip_nwp.shape[1], precip_nwp.shape[2]))
425+
v_nwp_y = np.zeros((precip_nwp.shape[0], precip_nwp.shape[1], precip_nwp.shape[2]))
429426
# Loop through the timesteps. We need two images to construct a motion
430427
# field, so we can start from timestep 1.
431428
for t in range(1, precip_nwp.shape[0]):
432-
v_nwp[t] = oflow_method(precip_nwp[t - 1 : t + 1, :, :])
429+
v_nwp_dataset = oflow_method(precip_nwp_dataset.isel(time=slice(t - 1, t + 1)))
430+
v_nwp_x[t] = v_nwp_dataset.velocity_x
431+
v_nwp_y[t] = v_nwp_dataset.velocity_y
433432

434433
# Make timestep 0 the same as timestep 1.
435-
v_nwp[0] = v_nwp[1]
434+
v_nwp_x[0] = v_nwp_x[1]
435+
v_nwp_y[0] = v_nwp_y[1]
436+
precip_nwp_dataset["velocity_x"] = (["time", "y", "x"], v_nwp_x)
437+
precip_nwp_dataset["velocity_y"] = (["time", "y", "x"], v_nwp_y)
436438

437-
assert v_nwp.ndim == 4, "v_nwp must be a four-dimensional array"
439+
if decompose_nwp:
440+
precip_nwp_dataset = decompose_NWP(precip_nwp_dataset, **decompose_kwargs)
438441

439442
# Save it as a numpy array
440-
np.save(outfn, v_nwp)
441-
442-
443-
def load_NWP(input_nc_path_decomp, input_path_velocities, start_time, n_timesteps):
444-
"""Loads the decomposed NWP and velocity data from the netCDF files
445-
446-
Parameters
447-
----------
448-
input_nc_path_decomp: str
449-
Path to the saved netCDF file containing the decomposed NWP data.
450-
input_path_velocities: str
451-
Path to the saved numpy binary file containing the estimated velocity
452-
fields from the NWP data.
453-
start_time: numpy.datetime64
454-
The start time of the nowcasting. Assumed to be a numpy.datetime64 type
455-
n_timesteps: int
456-
Number of time steps to forecast
457-
458-
Returns
459-
-------
460-
R_d: list
461-
A list of dictionaries with each element in the list corresponding to
462-
a different time step. Each dictionary has the same structure as the
463-
output of the decomposition function
464-
uv: array-like
465-
Array of shape (timestep,2,m,n) containing the x- and y-components
466-
of the advection field for the (NWP) model field per forecast lead time.
467-
"""
468-
469-
if not NETCDF4_IMPORTED:
470-
raise MissingOptionalDependency(
471-
"netCDF4 package is required to load the decomposed NWP data, "
472-
"but it is not installed"
473-
)
474-
475-
# Open the file
476-
ncf_decomp = netCDF4.Dataset(input_nc_path_decomp, "r", format="NETCDF4")
477-
velocities = np.load(input_path_velocities)
478-
479-
decomp_dict = {
480-
"domain": ncf_decomp.domain,
481-
"normalized": bool(ncf_decomp.normalized),
482-
"compact_output": bool(ncf_decomp.compact_output),
483-
}
484-
485-
# Convert the start time and the timestep to datetime64 and timedelta64 type
486-
zero_time = np.datetime64("1970-01-01T00:00:00", "ns")
487-
analysis_time = np.timedelta64(int(ncf_decomp.analysis_time), "ns") + zero_time
488-
489-
timestep = ncf_decomp.timestep
490-
timestep = np.timedelta64(timestep, "m")
491-
492-
valid_times = ncf_decomp.variables["valid_times"][:]
493-
valid_times = np.array(
494-
[np.timedelta64(int(valid_times[i]), "ns") for i in range(len(valid_times))]
495-
)
496-
valid_times = valid_times + zero_time
497-
498-
# Find the indices corresponding with the required start and end time
499-
start_i = (start_time - analysis_time) // timestep
500-
assert analysis_time + start_i * timestep == start_time
501-
end_i = start_i + n_timesteps + 1
502-
503-
# Check if the requested end time (the forecast horizon) is in the stored data.
504-
# If not, raise an error
505-
if end_i > ncf_decomp.variables["pr_decomposed"].shape[0]:
506-
raise IndexError(
507-
"The requested forecast horizon is outside the stored NWP forecast horizon. Either request a shorter forecast horizon or store a longer NWP forecast horizon"
508-
)
509-
510-
# Add the valid times to the output
511-
decomp_dict["valid_times"] = valid_times[start_i:end_i]
512-
513-
# Slice the velocity fields with the start and end indices
514-
uv = velocities[start_i:end_i, :, :, :]
515-
516-
# Initialise the list of dictionaries which will serve as the output (cf: the STEPS function)
517-
R_d = list()
518-
519-
pr_decomposed = ncf_decomp.variables["pr_decomposed"][start_i:end_i, :, :, :]
520-
means = ncf_decomp.variables["means"][start_i:end_i, :]
521-
stds = ncf_decomp.variables["stds"][start_i:end_i, :]
522-
523-
for i in range(n_timesteps + 1):
524-
decomp_dict["cascade_levels"] = np.ma.filled(
525-
pr_decomposed[i], fill_value=np.nan
526-
)
527-
decomp_dict["means"] = np.ma.filled(means[i], fill_value=np.nan)
528-
decomp_dict["stds"] = np.ma.filled(stds[i], fill_value=np.nan)
529-
530-
R_d.append(decomp_dict.copy())
531-
532-
ncf_decomp.close()
533-
return R_d, uv
443+
precip_nwp_dataset.to_netcdf(outfn)
534444

535445

536446
def check_norain(precip_arr, precip_thr=None, norain_thr=0.0):

pysteps/downscaling/rainfarm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,12 @@ def downscale(
318318
noise_dataset = xr.Dataset(
319319
data_vars={precip_var: (["time", "y", "x"], [noise_field])},
320320
coords={
321-
"time": (["time"], precip_dataset.time.values, precip_dataset.time.attrs),
321+
"time": (
322+
["time"],
323+
precip_dataset.time.values,
324+
precip_dataset.time.attrs,
325+
precip_dataset.time.encoding,
326+
),
322327
"y": (
323328
["y"],
324329
y_new,

0 commit comments

Comments
 (0)