 import os
 import warnings
 import weakref
+from pathlib import Path
 from tempfile import mkstemp

 import cdsapi
 import xarray as xr
 from dask import compute, delayed
 from dask.array import arctan2, sqrt
+from dask.utils import SerializableLock
 from numpy import atleast_1d

 from atlite.gis import maybe_swap_spatial_dims
@@ -86,9 +88,7 @@ def _rename_and_clean_coords(ds, add_lon_lat=True):
     Optionally (add_lon_lat, default:True) preserves latitude and
     longitude columns as 'lat' and 'lon'.
     """
-    ds = ds.rename({"longitude": "x", "latitude": "y"})
-    if "valid_time" in ds.sizes:
-        ds = ds.rename({"valid_time": "time"}).unify_chunks()
+    ds = ds.rename({"longitude": "x", "latitude": "y", "valid_time": "time"})
     # round coords since cds coords are float32 which would lead to mismatches
     ds = ds.assign_coords(
         x=np.round(ds.x.astype(float), 5), y=np.round(ds.y.astype(float), 5)
@@ -331,20 +331,161 @@ def noisy_unlink(path):
         logger.error(f"Unable to delete file {path}, as it is still in use.")


-def retrieve_data(product, chunks=None, tmpdir=None, lock=None, **updates):
+def add_finalizer(ds: xr.Dataset, target: str | Path):
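+    """
+    Register a finalizer on the dataset's underlying file object so that the
+    downloaded file at ``target`` is deleted once the dataset is garbage
+    collected.
+    """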
+    logger.debug(f"Adding finalizer for {target}")
+    weakref.finalize(ds._close.__self__.ds, noisy_unlink, target)
+
+
+def sanitize_chunks(chunks, **dim_mapping):
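+    """
+    Translate user-facing chunk names ('time', 'x', 'y') into the dimension
+    names of the raw CDS files ('valid_time', 'longitude', 'latitude').
+
+    For example, ``{"time": 1, "x": 100}`` becomes
+    ``{"valid_time": 1, "longitude": 100}``. Non-dict values such as "auto"
+    or None are passed through unchanged.
+    """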
+    dim_mapping = dict(time="valid_time", x="longitude", y="latitude") | dim_mapping
+    if not isinstance(chunks, dict):
+        # preserve "auto" or None
+        return chunks
+
+    return {
+        extname: chunks[intname]
+        for intname, extname in dim_mapping.items()
+        if intname in chunks
+    }
+
+
+def open_with_grib_conventions(
+    grib_file: str | Path, chunks=None, tmpdir: str | Path | None = None
+) -> xr.Dataset:
355+ """
356+ Convert grib file of ERA5 data from the CDS to netcdf file.
357+
358+ The function does the same thing as the CDS backend does, but locally.
359+ This is needed, as the grib file is the recommended download file type for CDS, with conversion to netcdf locally.
360+ The routine is a reduced version based on the documentation here:
361+ https://confluence.ecmwf.int/display/CKB/GRIB+to+netCDF+conversion+on+new+CDS+and+ADS+systems#GRIBtonetCDFconversiononnewCDSandADSsystems-jupiternotebook
362+
363+ Parameters
364+ ----------
365+ grib_file : str | Path
366+ Path to the grib file to be converted.
367+ chunks
368+ Chunks
369+ tmpdir : Path, optional
370+ If None adds a finalizer to the dataset object
371+
372+ Returns
373+ -------
374+ xr.Dataset
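+
+    Examples
+    --------
+    A minimal usage sketch (the file name is illustrative):
+
+    >>> ds = open_with_grib_conventions("era5.grib", chunks={"time": 1})  # doctest: +SKIP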
375+ """
376+ #
377+ # Open grib file as dataset
378+ # Options to open different datasets into a datasets of consistent hypercubes which are compatible netCDF
379+ # There are options that might be relevant for e.g. for wave model data, that have been removed here
380+ # to keep the code cleaner and shorter
+    ds = xr.open_dataset(
+        grib_file,
+        engine="cfgrib",
+        time_dims=["valid_time"],
+        ignore_keys=["edition"],
+        # extra_coords={"expver": "valid_time"},
+        coords_as_attributes=[
+            "surface",
+            "depthBelowLandLayer",
+            "entireAtmosphere",
+            "heightAboveGround",
+            "meanSea",
+        ],
+        chunks=sanitize_chunks(chunks),
+    )
+    if tmpdir is None:
+        add_finalizer(ds, grib_file)
+
+    def safely_expand_dims(dataset: xr.Dataset, expand_dims: list[str]) -> xr.Dataset:
+        """
+        Expand dimensions in an xarray dataset, skipping dimensions that are
+        already present and preserving the order of dimensions.
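+
+        For example, a variable with a scalar 'valid_time' coordinate gains a
+        length-1 'valid_time' dimension.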
403+ """
404+ dims_required = [
405+ c for c in dataset .coords if c in expand_dims + list (dataset .dims )
406+ ]
407+ dims_missing = [
408+ (c , i ) for i , c in enumerate (dims_required ) if c not in dataset .dims
409+ ]
410+ dataset = dataset .expand_dims (
411+ dim = [x [0 ] for x in dims_missing ], axis = [x [1 ] for x in dims_missing ]
412+ )
413+ return dataset
414+
+    logger.debug("Converting grib file to netcdf format")
+    # Variables and dimensions to rename if they exist in the dataset
+    rename_vars = {
+        "time": "forecast_reference_time",
+        "step": "forecast_period",
+        "isobaricInhPa": "pressure_level",
+        "hybrid": "model_level",
+    }
+    rename_vars = {k: v for k, v in rename_vars.items() if k in ds}
+    ds = ds.rename(rename_vars)
+
+    # Safely expand scalar coordinates into proper dimensions of the dataset
+    ds = safely_expand_dims(ds, ["valid_time", "pressure_level", "model_level"])
+
+    return ds
+
+
+def retrieve_data(
+    product: str,
+    chunks: dict[str, int] | None = None,
+    tmpdir: str | Path | None = None,
+    lock: SerializableLock | None = None,
+    **updates,
+) -> xr.Dataset:
     """
     Download data like ERA5 from the Climate Data Store (CDS).

     If you want to track the state of your request go to
     https://cds-beta.climate.copernicus.eu/requests?tab=all
+
+    Parameters
+    ----------
+    product : str
+        Product name, e.g. 'reanalysis-era5-single-levels'.
+    chunks : dict, optional
+        Chunking for the xarray dataset, e.g. {'time': 1, 'x': 100, 'y': 100}.
+        Default is None.
+    tmpdir : str, optional
+        Directory where the downloaded data is temporarily stored.
+        Default is None, which uses the system's temporary directory.
+    lock : dask.utils.SerializableLock, optional
+        Lock for thread-safe file writing. Default is None.
+    updates : dict
+        Additional parameters for the request.
+        Must include 'year', 'month', and 'variable'.
+        Can include e.g. 'data_format'.
+
+    Returns
+    -------
+    xarray.Dataset
+        Dataset with the retrieved variables.
+
+    Examples
+    --------
+    >>> ds = retrieve_data(
+    ...     product='reanalysis-era5-single-levels',
+    ...     chunks={'time': 1, 'x': 100, 'y': 100},
+    ...     tmpdir='/tmp',
+    ...     lock=None,
+    ...     year='2020',
+    ...     month='01',
+    ...     variable=['10m_u_component_of_wind', '10m_v_component_of_wind'],
+    ...     data_format='netcdf'
+    ... )
     """
-    request = {"product_type": "reanalysis", "format": "netcdf"}
+    request = {"product_type": ["reanalysis"], "download_format": "unarchived"}
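+    # "unarchived" requests the raw file rather than a zipped download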
     request.update(updates)

     assert {"year", "month", "variable"}.issubset(request), (
         "Need to specify at least 'variable', 'year' and 'month'"
     )

+    logger.debug(f"Requesting {product} with API request: {request}")
+
     client = cdsapi.Client(
         info_callback=logger.debug, debug=logging.DEBUG >= logging.root.level
     )
@@ -353,8 +494,9 @@ def retrieve_data(product, chunks=None, tmpdir=None, lock=None, **updates):
     if lock is None:
         lock = nullcontext()

+    suffix = f".{request['data_format']}"  # .netcdf or .grib
     with lock:
-        fd, target = mkstemp(suffix=".nc", dir=tmpdir)
+        fd, target = mkstemp(suffix=suffix, dir=tmpdir)
         os.close(fd)

         # Inform user about data being downloaded as "* variable (year-month)"
@@ -364,10 +506,13 @@ def retrieve_data(product, chunks=None, tmpdir=None, lock=None, **updates):
         logger.info(f"CDS: Downloading variables\n\t{varstr}\n")
         result.download(target)

-    ds = xr.open_dataset(target, chunks=chunks or {})
-    if tmpdir is None:
-        logger.debug(f"Adding finalizer for {target}")
-        weakref.finalize(ds._file_obj._manager, noisy_unlink, target)
+    # Convert from grib to netcdf locally, same conversion as in CDS backend
+    if request["data_format"] == "grib":
+        ds = open_with_grib_conventions(target, chunks=chunks, tmpdir=tmpdir)
+    else:
+        ds = xr.open_dataset(target, chunks=sanitize_chunks(chunks))
+        if tmpdir is None:
+            add_finalizer(ds, target)

     return ds

@@ -377,6 +522,7 @@ def get_data(
     feature,
     tmpdir,
     lock=None,
+    data_format="grib",
     monthly_requests=False,
     concurrent_requests=False,
     **creation_parameters,
@@ -399,6 +545,9 @@ def get_data(
         If True, the data is requested on a monthly basis in ERA5. This is useful for
         large cutouts, where the data is requested in smaller chunks. The
         default is False
+    data_format : str, optional
+        The format of the data to be downloaded. Can be either 'grib' or
+        'netcdf'; 'grib' is highly recommended because the CDS API limits
+        the request size for netcdf.
     concurrent_requests : bool, optional
         If True, the monthly data requests are posted concurrently.
         Only has an effect if `monthly_requests` is True.
@@ -420,9 +569,10 @@ def get_data(
420569 "product" : "reanalysis-era5-single-levels" ,
421570 "area" : _area (coords ),
422571 "chunks" : cutout .chunks ,
423- "grid" : [ cutout .dx , cutout .dy ] ,
572+ "grid" : f" { cutout .dx } / { cutout .dy } " ,
424573 "tmpdir" : tmpdir ,
425574 "lock" : lock ,
575+ "data_format" : data_format ,
426576 }
427577
428578 func = globals ().get (f"get_data_{ feature } " )