From 8a9b2c0239b9a01844346cfd26e03d069fd8a867 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:39:52 +0100 Subject: [PATCH 01/15] update test to tmp_path --- tests/test_fieldset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 7f54ca4d66..d86dff7382 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -958,8 +958,7 @@ def test_fieldset_defer_loading_with_diff_time_origin(tmpdir, fail): @pytest.mark.parametrize("zdim", [2, 8]) @pytest.mark.parametrize("scale_fac", [0.2, 4, 1]) -def test_fieldset_defer_loading_function(zdim, scale_fac, tmpdir): - filepath = tmpdir.join("test_parcels_defer_loading") +def test_fieldset_defer_loading_function(zdim, scale_fac, tmp_path): data0, dims0 = generate_fieldset_data(3, 3, zdim, 10) data0["U"][:, 0, :, :] = ( np.nan @@ -967,9 +966,9 @@ def test_fieldset_defer_loading_function(zdim, scale_fac, tmpdir): dims0["time"] = np.arange(0, 10, 1) * 3600 dims0["depth"] = np.arange(0, zdim, 1) fieldset_out = FieldSet.from_data(data0, dims0) - fieldset_out.write(filepath) + fieldset_out.write(tmp_path) fieldset = FieldSet.from_parcels( - filepath, chunksize={"time": ("time_counter", 1), "depth": ("depthu", 1), "lat": ("y", 2), "lon": ("x", 2)} + tmp_path, chunksize={"time": ("time_counter", 1), "depth": ("depthu", 1), "lat": ("y", 2), "lon": ("x", 2)} ) # testing for combination of deferred-loaded and numpy Fields From 9b6720d2b989aff89861c17548924c9bba58136c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:49:21 +0100 Subject: [PATCH 02/15] Add filename type test --- tests/test_fieldset.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index d86dff7382..f7c63737c5 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -1092,6 +1092,25 @@ def test_fieldset_frompop(mode): pset.execute(AdvectionRK4, runtime=3, dt=1) +@pytest.mark.parametrize( + "f", + [ + pytest.param(lambda x: x, id="pathlib.Path"), + pytest.param(lambda x: str(x), id="str"), + ], +) +def test_fieldset_from_netcdf_path_type(f): + filenames = { + "lon": f(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "lat": f(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "data": f(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), + } + variable = "U" + dimensions = {"lon": "glamf", "lat": "gphif"} + + Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") + + def test_fieldset_from_data_gridtypes(): """Simple test for fieldset initialisation from data.""" xdim, ydim, zdim = 20, 10, 4 From 4e2710f477ddd3c6a8427f06e4b187bf4805c894 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:50:37 +0100 Subject: [PATCH 03/15] Move _get_dim_filenames --- parcels/field.py | 40 +++++++++++++++++++------------------- tests/test_deprecations.py | 2 +- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 159c7c299f..204ceded53 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -426,22 +426,8 @@ def netcdf_engine(self): @classmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def get_dim_filenames(cls, *args, **kwargs): - return cls._get_dim_filenames(*args, **kwargs) - - @classmethod - def _get_dim_filenames(cls, filenames, dim): - if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): - return [filenames] - elif isinstance(filenames, dict): - assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" - filename = filenames[dim] - if isinstance(filename, str): - return [filename] - else: - return filename - else: - return filenames + def get_dim_filenames(*args, **kwargs): + return _get_dim_filenames(*args, **kwargs) @staticmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 @@ -598,17 +584,17 @@ def from_netcdf( len(variable) == 2 ), "The variable tuple must have length 2. Use FieldSet.from_netcdf() for multiple variables" - data_filenames = cls._get_dim_filenames(filenames, "data") - lonlat_filename = cls._get_dim_filenames(filenames, "lon") + data_filenames = _get_dim_filenames(filenames, "data") + lonlat_filename = _get_dim_filenames(filenames, "lon") if isinstance(filenames, dict): assert len(lonlat_filename) == 1 - if lonlat_filename != cls._get_dim_filenames(filenames, "lat"): + if lonlat_filename != _get_dim_filenames(filenames, "lat"): raise NotImplementedError( "longitude and latitude dimensions are currently processed together from one single file" ) lonlat_filename = lonlat_filename[0] if "depth" in dimensions: - depth_filename = cls._get_dim_filenames(filenames, "depth") + depth_filename = _get_dim_filenames(filenames, "depth") if isinstance(filenames, dict) and len(depth_filename) != 1: raise NotImplementedError("Vertically adaptive meshes not implemented for from_netcdf()") depth_filename = depth_filename[0] @@ -2584,3 +2570,17 @@ def __getitem__(self, key): else: pass return val + + +def _get_dim_filenames(filenames, dim): + if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): + return [filenames] + elif isinstance(filenames, dict): + assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" + filename = filenames[dim] + if isinstance(filename, str): + return [filename] + else: + return filename + else: + raise ValueError("Filenames must be a string, pathlib.Path, a list or a dictionary") diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index f75b85b574..e408b12788 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -122,7 +122,7 @@ def test_testing_action_class(): Action("Field", "c_data_chunks", "make_private" ), Action("Field", "chunk_set", "make_private" ), Action("Field", "cell_edge_sizes", "read_only" ), - Action("Field", "get_dim_filenames()", "make_private" ), + Action("Field", "get_dim_filenames()", "make_private" , skip_reason="Moved underlying function."), Action("Field", "collect_timeslices()", "make_private" ), Action("Field", "reshape()", "make_private" ), Action("Field", "calc_cell_edge_sizes()", "make_private" ), From 4aedbdc85e070c395f0f5f8f201ceffa58dbab78 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:57:02 +0100 Subject: [PATCH 04/15] Add typing for _get_dim_filenames --- parcels/field.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 204ceded53..96f2857552 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -1,10 +1,9 @@ -import collections import math import warnings from collections.abc import Iterable from ctypes import POINTER, Structure, c_float, c_int, pointer from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal import dask.array as da import numpy as np @@ -2572,8 +2571,8 @@ def __getitem__(self, key): return val -def _get_dim_filenames(filenames, dim): - if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): +def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim: str) -> Any: + if isinstance(filenames, str) or not isinstance(filenames, Iterable): return [filenames] elif isinstance(filenames, dict): assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" @@ -2582,5 +2581,5 @@ def _get_dim_filenames(filenames, dim): return [filename] else: return filename - else: - raise ValueError("Filenames must be a string, pathlib.Path, a list or a dictionary") + + raise ValueError("Filenames must be a string, pathlib.Path, or a dictionary") From 7470597eb41f03d3059ab4abb32cc0536407ba0d Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:57:45 +0100 Subject: [PATCH 05/15] Bugfix _get_dim_filenames not working with path objects --- parcels/field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parcels/field.py b/parcels/field.py index 96f2857552..f3e016294f 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -2577,7 +2577,7 @@ def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim: elif isinstance(filenames, dict): assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" filename = filenames[dim] - if isinstance(filename, str): + if not isinstance(filename, Iterable): return [filename] else: return filename From f4adcb50b61fb8cff872e83823ecfeae849b11ad Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:37:11 +0800 Subject: [PATCH 06/15] move field tests to file --- tests/test_advection.py | 1 - tests/test_field.py | 70 +++++++++++++++++++++++++++++++++++++++++ tests/test_fieldset.py | 63 +------------------------------------ 3 files changed, 71 insertions(+), 63 deletions(-) create mode 100644 tests/test_field.py diff --git a/tests/test_advection.py b/tests/test_advection.py index c9e441734e..77705f0576 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -73,7 +73,6 @@ def test_advection_zonal(lon, lat, depth, mode): } dimensions = {"lon": lon, "lat": lat} fieldset2D = FieldSet.from_data(data2D, dimensions, mesh="spherical", transpose=True) - assert fieldset2D.U._creation_log == "from_data" pset2D = ParticleSet(fieldset2D, pclass=ptype[mode], lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) pset2D.execute(AdvectionRK4, runtime=timedelta(hours=2), dt=timedelta(seconds=30)) diff --git a/tests/test_field.py b/tests/test_field.py new file mode 100644 index 0000000000..106622705c --- /dev/null +++ b/tests/test_field.py @@ -0,0 +1,70 @@ +import cftime +import numpy as np +import pytest +import xarray as xr + +from parcels import Field +from parcels.tools.converters import ( + _get_cftime_calendars, + _get_cftime_datetimes, +) +from tests.utils import TEST_DATA + + +def test_field_from_netcdf_variables(): + filename = str(TEST_DATA / "perlinfieldsU.nc") + dims = {"lon": "x", "lat": "y"} + + variable = "vozocrtx" + f1 = Field.from_netcdf(filename, variable, dims) + variable = ("U", "vozocrtx") + f2 = Field.from_netcdf(filename, variable, dims) + variable = {"U": "vozocrtx"} + f3 = Field.from_netcdf(filename, variable, dims) + + assert np.allclose(f1.data, f2.data, atol=1e-12) + assert np.allclose(f1.data, f3.data, atol=1e-12) + + with pytest.raises(AssertionError): + variable = {"U": "vozocrtx", "nav_lat": "nav_lat"} # multiple variables will fail + f3 = Field.from_netcdf(filename, variable, dims) + + +@pytest.mark.parametrize("with_timestamps", [True, False]) +def test_field_from_netcdf(with_timestamps): + filenames = { + "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "data": str(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), + } + variable = "U" + dimensions = {"lon": "glamf", "lat": "gphif"} + if with_timestamps: + timestamp_types = [[[2]], [[np.datetime64("2000-01-01")]]] + for timestamps in timestamp_types: + Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity", timestamps=timestamps) + else: + Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") + + +@pytest.mark.parametrize( + "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True) +) +def test_field_nonstandardtime(calendar, cftime_datetime, tmpdir): + xdim = 4 + ydim = 6 + filepath = tmpdir.join("test_nonstandardtime.nc") + dates = [getattr(cftime, cftime_datetime)(1, m, 1) for m in range(1, 13)] + da = xr.DataArray( + np.random.rand(12, xdim, ydim), coords=[dates, range(xdim), range(ydim)], dims=["time", "lon", "lat"], name="U" + ) + da.to_netcdf(str(filepath)) + + dims = {"lon": "lon", "lat": "lat", "time": "time"} + try: + field = Field.from_netcdf(filepath, "U", dims) + except NotImplementedError: + field = None + + if field is not None: + assert field.grid.time_origin.calendar == calendar diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index f7c63737c5..ef013046fe 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -4,7 +4,6 @@ import sys from datetime import timedelta -import cftime import dask import dask.array as da import numpy as np @@ -29,8 +28,6 @@ GeographicPolar, TimeConverter, UnitConverter, - _get_cftime_calendars, - _get_cftime_datetimes, ) from tests.common_kernels import DoNothing from tests.utils import TEST_DATA @@ -61,6 +58,7 @@ def test_fieldset_from_data(xdim, ydim): """Simple test for fieldset initialisation from data.""" data, dimensions = generate_fieldset_data(xdim, ydim) fieldset = FieldSet.from_data(data, dimensions) + assert fieldset.U._creation_log == "from_data" assert len(fieldset.U.data.shape) == 3 assert len(fieldset.V.data.shape) == 3 assert np.allclose(fieldset.U.data[0, :], data["U"], rtol=1e-12) @@ -139,65 +137,6 @@ def test_fieldset_from_parcels(xdim, ydim, tmpdir): assert np.allclose(fieldset.V.data[0, :], data["V"], rtol=1e-12) -def test_field_from_netcdf_variables(): - filename = str(TEST_DATA / "perlinfieldsU.nc") - dims = {"lon": "x", "lat": "y"} - - variable = "vozocrtx" - f1 = Field.from_netcdf(filename, variable, dims) - variable = ("U", "vozocrtx") - f2 = Field.from_netcdf(filename, variable, dims) - variable = {"U": "vozocrtx"} - f3 = Field.from_netcdf(filename, variable, dims) - - assert np.allclose(f1.data, f2.data, atol=1e-12) - assert np.allclose(f1.data, f3.data, atol=1e-12) - - with pytest.raises(AssertionError): - variable = {"U": "vozocrtx", "nav_lat": "nav_lat"} # multiple variables will fail - f3 = Field.from_netcdf(filename, variable, dims) - - -@pytest.mark.parametrize( - "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True) -) -def test_fieldset_nonstandardtime( - calendar, cftime_datetime, tmpdir, filename="test_nonstandardtime.nc", xdim=4, ydim=6 -): - filepath = tmpdir.join(filename) - dates = [getattr(cftime, cftime_datetime)(1, m, 1) for m in range(1, 13)] - da = xr.DataArray( - np.random.rand(12, xdim, ydim), coords=[dates, range(xdim), range(ydim)], dims=["time", "lon", "lat"], name="U" - ) - da.to_netcdf(str(filepath)) - - dims = {"lon": "lon", "lat": "lat", "time": "time"} - try: - field = Field.from_netcdf(filepath, "U", dims) - except NotImplementedError: - field = None - - if field is not None: - assert field.grid.time_origin.calendar == calendar - - -@pytest.mark.parametrize("with_timestamps", [True, False]) -def test_field_from_netcdf(with_timestamps): - filenames = { - "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "data": str(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), - } - variable = "U" - dimensions = {"lon": "glamf", "lat": "gphif"} - if with_timestamps: - timestamp_types = [[[2]], [[np.datetime64("2000-01-01")]]] - for timestamps in timestamp_types: - Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity", timestamps=timestamps) - else: - Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") - - def test_fieldset_from_modulefile(): nemo_fname = str(TEST_DATA / "fieldset_nemo.py") nemo_error_fname = str(TEST_DATA / "fieldset_nemo_error.py") From c795ba4be04f054259d53c6e2e331d837efbfade Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:30:29 +0100 Subject: [PATCH 07/15] Move _parse_wildcards --- parcels/fieldset.py | 30 +++++++++++++++--------------- tests/test_deprecations.py | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 5fa6b13c53..2cd74e7ff8 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -348,19 +348,7 @@ def check_velocityfields(U, V, W): @classmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def parse_wildcards(cls, *args, **kwargs): - return cls._parse_wildcards(*args, **kwargs) - - @classmethod - def _parse_wildcards(cls, paths, filenames, var): - if not isinstance(paths, list): - paths = sorted(glob(str(paths))) - if len(paths) == 0: - notfound_paths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames - raise OSError(f"FieldSet files not found for variable {var}: {notfound_paths}") - for fp in paths: - if not os.path.exists(fp): - raise OSError(f"FieldSet file not found: {fp}") - return paths + return _parse_wildcards(*args, **kwargs) @classmethod def from_netcdf( @@ -477,10 +465,10 @@ def from_netcdf( # Resolve all matching paths for the current variable paths = filenames[var] if type(filenames) is dict and var in filenames else filenames if type(paths) is not dict: - paths = cls._parse_wildcards(paths, filenames, var) + paths = _parse_wildcards(paths, filenames, var) else: for dim, p in paths.items(): - paths[dim] = cls._parse_wildcards(p, filenames, var) + paths[dim] = _parse_wildcards(p, filenames, var) # Use dimensions[var] and indices[var] if either of them is a dict of dicts dims = dimensions[var] if var in dimensions else dimensions @@ -1689,3 +1677,15 @@ def computeTimeChunk(self, time=0.0, dt=1): return nextTime else: return time + nSteps * dt + + +def _parse_wildcards(paths, filenames, var): + if not isinstance(paths, list): + paths = sorted(glob(str(paths))) + if len(paths) == 0: + notfound_paths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames + raise OSError(f"FieldSet files not found for variable {var}: {notfound_paths}") + for fp in paths: + if not os.path.exists(fp): + raise OSError(f"FieldSet file not found: {fp}") + return paths diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index e408b12788..8666199f80 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -148,7 +148,7 @@ def test_testing_action_class(): Action("FieldSet", "particlefile", "read_only" ), Action("FieldSet", "add_UVfield()", "make_private" ), Action("FieldSet", "check_complete()", "make_private" ), - Action("FieldSet", "parse_wildcards()", "make_private" ), + Action("FieldSet", "parse_wildcards()", "make_private" , skip_reason="Moved underlying function."), # 1713 Action("ParticleSet", "repeat_starttime", "make_private" ), From dc6d4f864cd935292cf66995e0011150b7b1f94c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:32:33 +0100 Subject: [PATCH 08/15] Update error type --- parcels/fieldset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 2cd74e7ff8..257431d434 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -1684,8 +1684,8 @@ def _parse_wildcards(paths, filenames, var): paths = sorted(glob(str(paths))) if len(paths) == 0: notfound_paths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames - raise OSError(f"FieldSet files not found for variable {var}: {notfound_paths}") + raise FileNotFoundError(f"FieldSet files not found for variable {var}: {notfound_paths}") for fp in paths: if not os.path.exists(fp): - raise OSError(f"FieldSet file not found: {fp}") + raise FileNotFoundError(f"FieldSet file not found: {fp}") return paths From a6b46d49acb81ec64d956d91e6a9b37d77f1332d Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:01:10 +0100 Subject: [PATCH 09/15] Add _sanitize_field_filenames and associated tests --- parcels/field.py | 57 +++++++++++++++++++++++++++++++++++++++++ tests/test_field.py | 62 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/parcels/field.py b/parcels/field.py index f3e016294f..28ee1b1911 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -2,6 +2,7 @@ import warnings from collections.abc import Iterable from ctypes import POINTER, Structure, c_float, c_int, pointer +from glob import glob from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -543,6 +544,8 @@ def from_netcdf( * `Timestamps <../examples/tutorial_timestamps.ipynb>`__ """ + filenames = _sanitize_field_filenames(filenames) + if kwargs.get("netcdf_decodewarning") is not None: _deprecated_param_netcdf_decodewarning() kwargs.pop("netcdf_decodewarning") @@ -2572,6 +2575,7 @@ def __getitem__(self, key): def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim: str) -> Any: + """Get's the relevant filenames for a given dimension.""" if isinstance(filenames, str) or not isinstance(filenames, Iterable): return [filenames] elif isinstance(filenames, dict): @@ -2583,3 +2587,56 @@ def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim: return filename raise ValueError("Filenames must be a string, pathlib.Path, or a dictionary") + + +def _sanitize_field_filenames(filenames, *, recursed=False): + """The Field initializer can take `filenames` to be of various formats including: + + 1. a string or Path object. String can be a glob expression. + 2. a list of (a) + 3. a dictionary mapping with keys 'lon', 'lat', 'depth', 'data' and values of (1) or (2) + + This function sanitizes the inputs such that it returns, in the case of: + 1. A sorted list of strings with the expanded glob expression + 2. A sorted list of strings with the expanded glob expressions + 3. A dictionary with same keys but values as in (1) or (2). + + See tests for examples. + """ + allowed_dimension_keys = ("lon", "lat", "depth", "data") + + if isinstance(filenames, str) or not isinstance(filenames, Iterable): + return sorted(_expand_filename(filenames)) + + if isinstance(filenames, list): + files = [] + for f in filenames: + files.extend(_expand_filename(f)) + return sorted(files) + + if isinstance(filenames, dict): + if recursed: + raise ValueError("Invalid filenames format. Nested dictionary not allowed in dimension dictionary") + + for key in filenames: + if key not in allowed_dimension_keys: + raise ValueError( + f"Invalid key in filenames dimension dictionary. Must be one of {allowed_dimension_keys}" + ) + filenames[key] = _sanitize_field_filenames(filenames[key], recursed=True) + + return filenames + + raise ValueError("Filenames must be a string, pathlib.Path, list, or a dictionary") + + +def _expand_filename(filename: str | Path) -> list[str]: + """ + Converts a filename to a list of filenames if it is a glob expression. + + If a file is explicitly provided (i.e., not via glob), existence is only checked later. + """ + filename = str(filename) + if "*" in filename: + return glob(filename) + return [filename] diff --git a/tests/test_field.py b/tests/test_field.py index 106622705c..f171dd3525 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -1,9 +1,13 @@ +import glob +from pathlib import Path + import cftime import numpy as np import pytest import xarray as xr from parcels import Field +from parcels.field import _expand_filename, _sanitize_field_filenames from parcels.tools.converters import ( _get_cftime_calendars, _get_cftime_datetimes, @@ -68,3 +72,61 @@ def test_field_nonstandardtime(calendar, cftime_datetime, tmpdir): if field is not None: assert field.grid.time_origin.calendar == calendar + + +@pytest.mark.parametrize( + "input_,expected", + [ + pytest.param("file1.nc", ["file1.nc"], id="str"), + pytest.param(["file1.nc", "file2.nc"], ["file1.nc", "file2.nc"], id="list"), + pytest.param(["file2.nc", "file1.nc"], ["file1.nc", "file2.nc"], id="list-unsorted"), + pytest.param([Path("file1.nc"), Path("file2.nc")], ["file1.nc", "file2.nc"], id="list-Path"), + pytest.param( + { + "lon": "lon_file.nc", + "lat": ["lat_file1.nc", Path("lat_file2.nc")], + "depth": Path("depth_file.nc"), + "data": ["data_file1.nc", "data_file2.nc"], + }, + { + "lon": ["lon_file.nc"], + "lat": ["lat_file1.nc", "lat_file2.nc"], + "depth": ["depth_file.nc"], + "data": ["data_file1.nc", "data_file2.nc"], + }, + id="dict-mix", + ), + ], +) +def test_sanitize_field_filenames_cases(input_, expected): + assert _sanitize_field_filenames(input_) == expected + + +@pytest.mark.parametrize( + "input_,expected", + [ + ("file*.nc", ["file0.nc", "file1.nc", "file2.nc"]), + ], +) +def test_sanitize_field_filenames_glob(input_, expected, tmp_path, monkeypatch): + def monkey_glob(pattern): + return glob.glob(pattern, root_dir=tmp_path) + + monkeypatch.setattr(_sanitize_field_filenames, "glob", monkey_glob) + + for f in expected: + Path(tmp_path / f).touch() + + assert _sanitize_field_filenames(input_, tmp_path) == expected + + +@pytest.mark.parametrize( + "input_,expected", + [ + pytest.param("test", ["test"], id="str"), + pytest.param(Path("test"), ["test"], id="Path"), + pytest.param("file*.nc", [], id="glob-no-match"), + ], +) +def test_expand_filename(input_, expected): + assert _expand_filename(input_) == expected From 8a7de2a0a3231b22c6f676c050f058f0b49f3693 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:04:28 +0100 Subject: [PATCH 10/15] Update _get_dim_filenames Now that filenames is sanitized --- parcels/field.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 28ee1b1911..d9095ccc4e 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -2576,15 +2576,11 @@ def __getitem__(self, key): def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim: str) -> Any: """Get's the relevant filenames for a given dimension.""" - if isinstance(filenames, str) or not isinstance(filenames, Iterable): - return [filenames] - elif isinstance(filenames, dict): - assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" - filename = filenames[dim] - if not isinstance(filename, Iterable): - return [filename] - else: - return filename + if isinstance(filenames, list): + return filenames + + if isinstance(filenames, dict): + return filenames[dim] raise ValueError("Filenames must be a string, pathlib.Path, or a dictionary") From ed60b4b45c42ba816308b713e7ba49df8f8449b5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:10:57 +0100 Subject: [PATCH 11/15] Remove _parse_wildcards --- parcels/fieldset.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 257431d434..3b582c2a18 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -3,7 +3,6 @@ import sys import warnings from copy import deepcopy -from glob import glob import dask.array as da import numpy as np @@ -348,7 +347,9 @@ def check_velocityfields(U, V, W): @classmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def parse_wildcards(cls, *args, **kwargs): - return _parse_wildcards(*args, **kwargs) + raise NotImplementedError( + "parse_wildcards was removed as a function as the internal implementation was no longer used." + ) @classmethod def from_netcdf( @@ -462,13 +463,11 @@ def from_netcdf( if "creation_log" not in kwargs.keys(): kwargs["creation_log"] = "from_netcdf" for var, name in variables.items(): - # Resolve all matching paths for the current variable - paths = filenames[var] if type(filenames) is dict and var in filenames else filenames - if type(paths) is not dict: - paths = _parse_wildcards(paths, filenames, var) + paths: list[str] + if isinstance(filenames, dict) and var in filenames: + paths = filenames[var] else: - for dim, p in paths.items(): - paths[dim] = _parse_wildcards(p, filenames, var) + paths = filenames # Use dimensions[var] and indices[var] if either of them is a dict of dicts dims = dimensions[var] if var in dimensions else dimensions @@ -1677,15 +1676,3 @@ def computeTimeChunk(self, time=0.0, dt=1): return nextTime else: return time + nSteps * dt - - -def _parse_wildcards(paths, filenames, var): - if not isinstance(paths, list): - paths = sorted(glob(str(paths))) - if len(paths) == 0: - notfound_paths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames - raise FileNotFoundError(f"FieldSet files not found for variable {var}: {notfound_paths}") - for fp in paths: - if not os.path.exists(fp): - raise FileNotFoundError(f"FieldSet file not found: {fp}") - return paths From 132183e9c38f7fd99dc491f90e9e1cbfb14af661 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:16:09 +0100 Subject: [PATCH 12/15] update test --- tests/test_field.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/tests/test_field.py b/tests/test_field.py index f171dd3525..da9dff12ee 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -1,4 +1,3 @@ -import glob from pathlib import Path import cftime @@ -105,19 +104,11 @@ def test_sanitize_field_filenames_cases(input_, expected): @pytest.mark.parametrize( "input_,expected", [ - ("file*.nc", ["file0.nc", "file1.nc", "file2.nc"]), + pytest.param("file*.nc", [], id="glob-no-match"), ], ) -def test_sanitize_field_filenames_glob(input_, expected, tmp_path, monkeypatch): - def monkey_glob(pattern): - return glob.glob(pattern, root_dir=tmp_path) - - monkeypatch.setattr(_sanitize_field_filenames, "glob", monkey_glob) - - for f in expected: - Path(tmp_path / f).touch() - - assert _sanitize_field_filenames(input_, tmp_path) == expected +def test_sanitize_field_filenames_glob(input_, expected): + assert _sanitize_field_filenames(input_) == expected @pytest.mark.parametrize( From d8713c845c2300f2d3c257f1b144e20e264ebddb Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:54:16 +0100 Subject: [PATCH 13/15] typing --- parcels/field.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index d9095ccc4e..5d97ad6357 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import math import warnings from collections.abc import Iterable from ctypes import POINTER, Structure, c_float, c_int, pointer from glob import glob from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Literal import dask.array as da import numpy as np @@ -50,6 +52,8 @@ from parcels.fieldset import FieldSet + T_SanitizedFilenames = list[str] | dict[str, list[str]] + __all__ = ["Field", "NestedField", "VectorField"] @@ -484,7 +488,7 @@ def from_netcdf( time_periodic: TimePeriodic = False, deferred_load: bool = True, **kwargs, - ) -> "Field": + ) -> Field: """Create field from netCDF file. Parameters @@ -587,14 +591,14 @@ def from_netcdf( ), "The variable tuple must have length 2. Use FieldSet.from_netcdf() for multiple variables" data_filenames = _get_dim_filenames(filenames, "data") - lonlat_filename = _get_dim_filenames(filenames, "lon") + lonlat_filename_lst = _get_dim_filenames(filenames, "lon") if isinstance(filenames, dict): - assert len(lonlat_filename) == 1 - if lonlat_filename != _get_dim_filenames(filenames, "lat"): + assert len(lonlat_filename_lst) == 1 + if lonlat_filename_lst != _get_dim_filenames(filenames, "lat"): raise NotImplementedError( "longitude and latitude dimensions are currently processed together from one single file" ) - lonlat_filename = lonlat_filename[0] + lonlat_filename = lonlat_filename_lst[0] if "depth" in dimensions: depth_filename = _get_dim_filenames(filenames, "depth") if isinstance(filenames, dict) and len(depth_filename) != 1: @@ -2574,7 +2578,7 @@ def __getitem__(self, key): return val -def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim: str) -> Any: +def _get_dim_filenames(filenames: T_SanitizedFilenames, dim: str) -> list[str]: """Get's the relevant filenames for a given dimension.""" if isinstance(filenames, list): return filenames @@ -2585,7 +2589,7 @@ def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim: raise ValueError("Filenames must be a string, pathlib.Path, or a dictionary") -def _sanitize_field_filenames(filenames, *, recursed=False): +def _sanitize_field_filenames(filenames, *, recursed=False) -> T_SanitizedFilenames: """The Field initializer can take `filenames` to be of various formats including: 1. a string or Path object. String can be a glob expression. From 9013062da7edc6cf97499207253c8b2ade123092 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 19:22:39 +0100 Subject: [PATCH 14/15] Add from_netcdf path object test Contributes to #1706 --- tests/test_field.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_field.py b/tests/test_field.py index da9dff12ee..2478cea2c6 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -50,6 +50,25 @@ def test_field_from_netcdf(with_timestamps): Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") +@pytest.mark.parametrize( + "f", + [ + pytest.param(lambda x: x, id="Path"), + pytest.param(lambda x: str(x), id="str"), + ], +) +def test_from_netcdf_path_object(f): + filenames = { + "lon": f(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "lat": f(TEST_DATA / "mask_nemo_cross_180lon.nc"), + "data": f(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), + } + variable = "U" + dimensions = {"lon": "glamf", "lat": "gphif"} + + Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") + + @pytest.mark.parametrize( "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True) ) From 508006d8681f59e72f261b6a3835d37c5519d4e5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 4 Dec 2024 19:34:16 +0100 Subject: [PATCH 15/15] typing --- parcels/field.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 5d97ad6357..048de22759 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -52,7 +52,8 @@ from parcels.fieldset import FieldSet - T_SanitizedFilenames = list[str] | dict[str, list[str]] + T_Dimensions = Literal["lon", "lat", "depth", "data"] + T_SanitizedFilenames = list[str] | dict[T_Dimensions, list[str]] __all__ = ["Field", "NestedField", "VectorField"] @@ -600,10 +601,10 @@ def from_netcdf( ) lonlat_filename = lonlat_filename_lst[0] if "depth" in dimensions: - depth_filename = _get_dim_filenames(filenames, "depth") - if isinstance(filenames, dict) and len(depth_filename) != 1: + depth_filename_lst = _get_dim_filenames(filenames, "depth") + if isinstance(filenames, dict) and len(depth_filename_lst) != 1: raise NotImplementedError("Vertically adaptive meshes not implemented for from_netcdf()") - depth_filename = depth_filename[0] + depth_filename = depth_filename_lst[0] netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4") gridindexingtype = kwargs.get("gridindexingtype", "nemo") @@ -2578,7 +2579,7 @@ def __getitem__(self, key): return val -def _get_dim_filenames(filenames: T_SanitizedFilenames, dim: str) -> list[str]: +def _get_dim_filenames(filenames: T_SanitizedFilenames, dim: T_Dimensions) -> list[str]: """Get's the relevant filenames for a given dimension.""" if isinstance(filenames, list): return filenames