diff --git a/parcels/field.py b/parcels/field.py index 159c7c299f..048de22759 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -1,8 +1,10 @@ -import collections +from __future__ import annotations + import math import warnings from collections.abc import Iterable from ctypes import POINTER, Structure, c_float, c_int, pointer +from glob import glob from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -50,6 +52,9 @@ from parcels.fieldset import FieldSet + T_Dimensions = Literal["lon", "lat", "depth", "data"] + T_SanitizedFilenames = list[str] | dict[T_Dimensions, list[str]] + __all__ = ["Field", "NestedField", "VectorField"] @@ -426,22 +431,8 @@ def netcdf_engine(self): @classmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def get_dim_filenames(cls, *args, **kwargs): - return cls._get_dim_filenames(*args, **kwargs) - - @classmethod - def _get_dim_filenames(cls, filenames, dim): - if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): - return [filenames] - elif isinstance(filenames, dict): - assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" - filename = filenames[dim] - if isinstance(filename, str): - return [filename] - else: - return filename - else: - return filenames + def get_dim_filenames(*args, **kwargs): + return _get_dim_filenames(*args, **kwargs) @staticmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 @@ -498,7 +489,7 @@ def from_netcdf( time_periodic: TimePeriodic = False, deferred_load: bool = True, **kwargs, - ) -> "Field": + ) -> Field: """Create field from netCDF file. 
Parameters @@ -558,6 +549,8 @@ def from_netcdf( * `Timestamps <../examples/tutorial_timestamps.ipynb>`__ """ + filenames = _sanitize_field_filenames(filenames) + if kwargs.get("netcdf_decodewarning") is not None: _deprecated_param_netcdf_decodewarning() kwargs.pop("netcdf_decodewarning") @@ -598,20 +591,20 @@ def from_netcdf( len(variable) == 2 ), "The variable tuple must have length 2. Use FieldSet.from_netcdf() for multiple variables" - data_filenames = cls._get_dim_filenames(filenames, "data") - lonlat_filename = cls._get_dim_filenames(filenames, "lon") + data_filenames = _get_dim_filenames(filenames, "data") + lonlat_filename_lst = _get_dim_filenames(filenames, "lon") if isinstance(filenames, dict): - assert len(lonlat_filename) == 1 - if lonlat_filename != cls._get_dim_filenames(filenames, "lat"): + assert len(lonlat_filename_lst) == 1 + if lonlat_filename_lst != _get_dim_filenames(filenames, "lat"): raise NotImplementedError( "longitude and latitude dimensions are currently processed together from one single file" ) - lonlat_filename = lonlat_filename[0] + lonlat_filename = lonlat_filename_lst[0] if "depth" in dimensions: - depth_filename = cls._get_dim_filenames(filenames, "depth") - if isinstance(filenames, dict) and len(depth_filename) != 1: + depth_filename_lst = _get_dim_filenames(filenames, "depth") + if isinstance(filenames, dict) and len(depth_filename_lst) != 1: raise NotImplementedError("Vertically adaptive meshes not implemented for from_netcdf()") - depth_filename = depth_filename[0] + depth_filename = depth_filename_lst[0] netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4") gridindexingtype = kwargs.get("gridindexingtype", "nemo") @@ -2584,3 +2577,67 @@ def __getitem__(self, key): else: pass return val + + +def _get_dim_filenames(filenames: T_SanitizedFilenames, dim: T_Dimensions) -> list[str]: + """Gets the relevant filenames for a given dimension.""" + if isinstance(filenames, list): + return filenames + + if isinstance(filenames, dict): + 
return filenames[dim] + + raise ValueError("Filenames must be a string, pathlib.Path, or a dictionary") + + +def _sanitize_field_filenames(filenames, *, recursed=False) -> T_SanitizedFilenames: + """The Field initializer can take `filenames` to be of various formats including: + + 1. a string or Path object. String can be a glob expression. + 2. a list of (1) + 3. a dictionary mapping with keys 'lon', 'lat', 'depth', 'data' and values of (1) or (2) + + This function sanitizes the inputs such that it returns, in the case of: + 1. A sorted list of strings with the expanded glob expression + 2. A sorted list of strings with the expanded glob expressions + 3. A dictionary with the same keys but values as in (1) or (2). + + See tests for examples. + """ + allowed_dimension_keys = ("lon", "lat", "depth", "data") + + if isinstance(filenames, str) or not isinstance(filenames, Iterable): + return sorted(_expand_filename(filenames)) + + if isinstance(filenames, list): + files = [] + for f in filenames: + files.extend(_expand_filename(f)) + return sorted(files) + + if isinstance(filenames, dict): + if recursed: + raise ValueError("Invalid filenames format. Nested dictionary not allowed in dimension dictionary") + + for key in filenames: + if key not in allowed_dimension_keys: + raise ValueError( + f"Invalid key in filenames dimension dictionary. Must be one of {allowed_dimension_keys}" + ) + filenames[key] = _sanitize_field_filenames(filenames[key], recursed=True) + + return filenames + + raise ValueError("Filenames must be a string, pathlib.Path, list, or a dictionary") + + +def _expand_filename(filename: str | Path) -> list[str]: + """ + Converts a filename to a list of filenames if it is a glob expression. + + If a file is explicitly provided (i.e., not via glob), existence is only checked later. 
+ """ + filename = str(filename) + if "*" in filename: + return glob(filename) + return [filename] diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 5fa6b13c53..3b582c2a18 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -3,7 +3,6 @@ import sys import warnings from copy import deepcopy -from glob import glob import dask.array as da import numpy as np @@ -348,19 +347,9 @@ def check_velocityfields(U, V, W): @classmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def parse_wildcards(cls, *args, **kwargs): - return cls._parse_wildcards(*args, **kwargs) - - @classmethod - def _parse_wildcards(cls, paths, filenames, var): - if not isinstance(paths, list): - paths = sorted(glob(str(paths))) - if len(paths) == 0: - notfound_paths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames - raise OSError(f"FieldSet files not found for variable {var}: {notfound_paths}") - for fp in paths: - if not os.path.exists(fp): - raise OSError(f"FieldSet file not found: {fp}") - return paths + raise NotImplementedError( + "parse_wildcards was removed as a function as the internal implementation was no longer used." 
+ ) @classmethod def from_netcdf( @@ -474,13 +463,11 @@ def from_netcdf( if "creation_log" not in kwargs.keys(): kwargs["creation_log"] = "from_netcdf" for var, name in variables.items(): - # Resolve all matching paths for the current variable - paths = filenames[var] if type(filenames) is dict and var in filenames else filenames - if type(paths) is not dict: - paths = cls._parse_wildcards(paths, filenames, var) + paths: list[str] + if isinstance(filenames, dict) and var in filenames: + paths = filenames[var] else: - for dim, p in paths.items(): - paths[dim] = cls._parse_wildcards(p, filenames, var) + paths = filenames # Use dimensions[var] and indices[var] if either of them is a dict of dicts dims = dimensions[var] if var in dimensions else dimensions diff --git a/tests/test_advection.py b/tests/test_advection.py index c9e441734e..77705f0576 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -73,7 +73,6 @@ def test_advection_zonal(lon, lat, depth, mode): } dimensions = {"lon": lon, "lat": lat} fieldset2D = FieldSet.from_data(data2D, dimensions, mesh="spherical", transpose=True) - assert fieldset2D.U._creation_log == "from_data" pset2D = ParticleSet(fieldset2D, pclass=ptype[mode], lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) pset2D.execute(AdvectionRK4, runtime=timedelta(hours=2), dt=timedelta(seconds=30)) diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index f75b85b574..8666199f80 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -122,7 +122,7 @@ def test_testing_action_class(): Action("Field", "c_data_chunks", "make_private" ), Action("Field", "chunk_set", "make_private" ), Action("Field", "cell_edge_sizes", "read_only" ), - Action("Field", "get_dim_filenames()", "make_private" ), + Action("Field", "get_dim_filenames()", "make_private" , skip_reason="Moved underlying function."), Action("Field", "collect_timeslices()", "make_private" ), Action("Field", "reshape()", "make_private" ), 
Action("Field", "calc_cell_edge_sizes()", "make_private" ), @@ -148,7 +148,7 @@ def test_testing_action_class(): Action("FieldSet", "particlefile", "read_only" ), Action("FieldSet", "add_UVfield()", "make_private" ), Action("FieldSet", "check_complete()", "make_private" ), - Action("FieldSet", "parse_wildcards()", "make_private" ), + Action("FieldSet", "parse_wildcards()", "make_private" , skip_reason="Moved underlying function."), # 1713 Action("ParticleSet", "repeat_starttime", "make_private" ), diff --git a/tests/test_field.py b/tests/test_field.py new file mode 100644 index 0000000000..2478cea2c6 --- /dev/null +++ b/tests/test_field.py @@ -0,0 +1,142 @@ +from pathlib import Path + +import cftime +import numpy as np +import pytest +import xarray as xr + +from parcels import Field +from parcels.field import _expand_filename, _sanitize_field_filenames +from parcels.tools.converters import ( + _get_cftime_calendars, + _get_cftime_datetimes, +) +from tests.utils import TEST_DATA + + +def test_field_from_netcdf_variables(): + filename = str(TEST_DATA / "perlinfieldsU.nc") + dims = {"lon": "x", "lat": "y"} + + variable = "vozocrtx" + f1 = Field.from_netcdf(filename, variable, dims) + variable = ("U", "vozocrtx") + f2 = Field.from_netcdf(filename, variable, dims) + variable = {"U": "vozocrtx"} + f3 = Field.from_netcdf(filename, variable, dims) + + assert np.allclose(f1.data, f2.data, atol=1e-12) + assert np.allclose(f1.data, f3.data, atol=1e-12) + + with pytest.raises(AssertionError): + variable = {"U": "vozocrtx", "nav_lat": "nav_lat"} # multiple variables will fail + f3 = Field.from_netcdf(filename, variable, dims) + + +@pytest.mark.parametrize("with_timestamps", [True, False]) +def test_field_from_netcdf(with_timestamps): + filenames = { + "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "data": str(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), + } + variable = "U" + dimensions = {"lon": "glamf", "lat": 
"gphif"} + if with_timestamps: + timestamp_types = [[[2]], [[np.datetime64("2000-01-01")]]] + for timestamps in timestamp_types: + Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity", timestamps=timestamps) + else: + Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") + + +@pytest.mark.parametrize( + "f", + [ + pytest.param(lambda x: x, id="Path"), + pytest.param(lambda x: str(x), id="str"), + ], +) +def test_from_netcdf_path_object(f): + filenames = { + "lon": f(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "lat": f(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "data": f(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), + } + variable = "U" + dimensions = {"lon": "glamf", "lat": "gphif"} + + Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") + + +@pytest.mark.parametrize( + "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True) +) +def test_field_nonstandardtime(calendar, cftime_datetime, tmpdir): + xdim = 4 + ydim = 6 + filepath = tmpdir.join("test_nonstandardtime.nc") + dates = [getattr(cftime, cftime_datetime)(1, m, 1) for m in range(1, 13)] + da = xr.DataArray( + np.random.rand(12, xdim, ydim), coords=[dates, range(xdim), range(ydim)], dims=["time", "lon", "lat"], name="U" + ) + da.to_netcdf(str(filepath)) + + dims = {"lon": "lon", "lat": "lat", "time": "time"} + try: + field = Field.from_netcdf(filepath, "U", dims) + except NotImplementedError: + field = None + + if field is not None: + assert field.grid.time_origin.calendar == calendar + + +@pytest.mark.parametrize( + "input_,expected", + [ + pytest.param("file1.nc", ["file1.nc"], id="str"), + pytest.param(["file1.nc", "file2.nc"], ["file1.nc", "file2.nc"], id="list"), + pytest.param(["file2.nc", "file1.nc"], ["file1.nc", "file2.nc"], id="list-unsorted"), + pytest.param([Path("file1.nc"), Path("file2.nc")], ["file1.nc", "file2.nc"], id="list-Path"), + pytest.param( + { + "lon": 
"lon_file.nc", + "lat": ["lat_file1.nc", Path("lat_file2.nc")], + "depth": Path("depth_file.nc"), + "data": ["data_file1.nc", "data_file2.nc"], + }, + { + "lon": ["lon_file.nc"], + "lat": ["lat_file1.nc", "lat_file2.nc"], + "depth": ["depth_file.nc"], + "data": ["data_file1.nc", "data_file2.nc"], + }, + id="dict-mix", + ), + ], +) +def test_sanitize_field_filenames_cases(input_, expected): + assert _sanitize_field_filenames(input_) == expected + + +@pytest.mark.parametrize( + "input_,expected", + [ + pytest.param("file*.nc", [], id="glob-no-match"), + ], +) +def test_sanitize_field_filenames_glob(input_, expected): + assert _sanitize_field_filenames(input_) == expected + + +@pytest.mark.parametrize( + "input_,expected", + [ + pytest.param("test", ["test"], id="str"), + pytest.param(Path("test"), ["test"], id="Path"), + pytest.param("file*.nc", [], id="glob-no-match"), + ], +) +def test_expand_filename(input_, expected): + assert _expand_filename(input_) == expected diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 7f54ca4d66..ef013046fe 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -4,7 +4,6 @@ import sys from datetime import timedelta -import cftime import dask import dask.array as da import numpy as np @@ -29,8 +28,6 @@ GeographicPolar, TimeConverter, UnitConverter, - _get_cftime_calendars, - _get_cftime_datetimes, ) from tests.common_kernels import DoNothing from tests.utils import TEST_DATA @@ -61,6 +58,7 @@ def test_fieldset_from_data(xdim, ydim): """Simple test for fieldset initialisation from data.""" data, dimensions = generate_fieldset_data(xdim, ydim) fieldset = FieldSet.from_data(data, dimensions) + assert fieldset.U._creation_log == "from_data" assert len(fieldset.U.data.shape) == 3 assert len(fieldset.V.data.shape) == 3 assert np.allclose(fieldset.U.data[0, :], data["U"], rtol=1e-12) @@ -139,65 +137,6 @@ def test_fieldset_from_parcels(xdim, ydim, tmpdir): assert np.allclose(fieldset.V.data[0, :], data["V"], 
rtol=1e-12) -def test_field_from_netcdf_variables(): - filename = str(TEST_DATA / "perlinfieldsU.nc") - dims = {"lon": "x", "lat": "y"} - - variable = "vozocrtx" - f1 = Field.from_netcdf(filename, variable, dims) - variable = ("U", "vozocrtx") - f2 = Field.from_netcdf(filename, variable, dims) - variable = {"U": "vozocrtx"} - f3 = Field.from_netcdf(filename, variable, dims) - - assert np.allclose(f1.data, f2.data, atol=1e-12) - assert np.allclose(f1.data, f3.data, atol=1e-12) - - with pytest.raises(AssertionError): - variable = {"U": "vozocrtx", "nav_lat": "nav_lat"} # multiple variables will fail - f3 = Field.from_netcdf(filename, variable, dims) - - -@pytest.mark.parametrize( - "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True) -) -def test_fieldset_nonstandardtime( - calendar, cftime_datetime, tmpdir, filename="test_nonstandardtime.nc", xdim=4, ydim=6 -): - filepath = tmpdir.join(filename) - dates = [getattr(cftime, cftime_datetime)(1, m, 1) for m in range(1, 13)] - da = xr.DataArray( - np.random.rand(12, xdim, ydim), coords=[dates, range(xdim), range(ydim)], dims=["time", "lon", "lat"], name="U" - ) - da.to_netcdf(str(filepath)) - - dims = {"lon": "lon", "lat": "lat", "time": "time"} - try: - field = Field.from_netcdf(filepath, "U", dims) - except NotImplementedError: - field = None - - if field is not None: - assert field.grid.time_origin.calendar == calendar - - -@pytest.mark.parametrize("with_timestamps", [True, False]) -def test_field_from_netcdf(with_timestamps): - filenames = { - "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "data": str(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), - } - variable = "U" - dimensions = {"lon": "glamf", "lat": "gphif"} - if with_timestamps: - timestamp_types = [[[2]], [[np.datetime64("2000-01-01")]]] - for timestamps in timestamp_types: - Field.from_netcdf(filenames, variable, dimensions, 
interp_method="cgrid_velocity", timestamps=timestamps) - else: - Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") - - def test_fieldset_from_modulefile(): nemo_fname = str(TEST_DATA / "fieldset_nemo.py") nemo_error_fname = str(TEST_DATA / "fieldset_nemo_error.py") @@ -958,8 +897,7 @@ def test_fieldset_defer_loading_with_diff_time_origin(tmpdir, fail): @pytest.mark.parametrize("zdim", [2, 8]) @pytest.mark.parametrize("scale_fac", [0.2, 4, 1]) -def test_fieldset_defer_loading_function(zdim, scale_fac, tmpdir): - filepath = tmpdir.join("test_parcels_defer_loading") +def test_fieldset_defer_loading_function(zdim, scale_fac, tmp_path): data0, dims0 = generate_fieldset_data(3, 3, zdim, 10) data0["U"][:, 0, :, :] = ( np.nan @@ -967,9 +905,9 @@ def test_fieldset_defer_loading_function(zdim, scale_fac, tmpdir): dims0["time"] = np.arange(0, 10, 1) * 3600 dims0["depth"] = np.arange(0, zdim, 1) fieldset_out = FieldSet.from_data(data0, dims0) - fieldset_out.write(filepath) + fieldset_out.write(tmp_path) fieldset = FieldSet.from_parcels( - filepath, chunksize={"time": ("time_counter", 1), "depth": ("depthu", 1), "lat": ("y", 2), "lon": ("x", 2)} + tmp_path, chunksize={"time": ("time_counter", 1), "depth": ("depthu", 1), "lat": ("y", 2), "lon": ("x", 2)} ) # testing for combination of deferred-loaded and numpy Fields @@ -1093,6 +1031,25 @@ def test_fieldset_frompop(mode): pset.execute(AdvectionRK4, runtime=3, dt=1) +@pytest.mark.parametrize( + "f", + [ + pytest.param(lambda x: x, id="pathlib.Path"), + pytest.param(lambda x: str(x), id="str"), + ], +) +def test_fieldset_from_netcdf_path_type(f): + filenames = { + "lon": f(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "lat": f(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "data": f(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), + } + variable = "U" + dimensions = {"lon": "glamf", "lat": "gphif"} + + Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") + + def 
test_fieldset_from_data_gridtypes(): """Simple test for fieldset initialisation from data.""" xdim, ydim, zdim = 20, 10, 4