Skip to content

Commit eeedce8

Browse files
Merge pull request #2453 from Parcels-code/use_sgrid_for_simple_UV_dataset
Using from_sgrid_conventions for simple_UV_dataset throughout
2 parents 4ac7ebe + 3627866 commit eeedce8

15 files changed

+74
-157
lines changed

docs/user_guide/examples/tutorial_diffusion.ipynb

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@
192192
"source": [
193193
"from parcels._datasets.structured.generated import simple_UV_dataset\n",
194194
"\n",
195-
"ds = simple_UV_dataset(dims=(1, 1, Ny, 1), mesh=\"flat\").isel(time=0, depth=0)\n",
195+
"ds = simple_UV_dataset(dims=(1, 1, Ny, 1), mesh=\"flat\")\n",
196196
"ds[\"lat\"][:] = np.linspace(-0.01, 1.01, Ny)\n",
197197
"ds[\"lon\"][:] = np.ones(len(ds.XG))\n",
198198
"ds[\"Kh_meridional\"] = ([\"YG\", \"XG\"], Kh_meridional[:, None])\n",
@@ -205,20 +205,8 @@
205205
"metadata": {},
206206
"outputs": [],
207207
"source": [
208-
"grid = parcels.XGrid.from_dataset(ds, mesh=\"flat\")\n",
209-
"U = parcels.Field(\"U\", ds[\"U\"], grid, interp_method=parcels.interpolators.XLinear)\n",
210-
"V = parcels.Field(\"V\", ds[\"V\"], grid, interp_method=parcels.interpolators.XLinear)\n",
211-
"UV = parcels.VectorField(\"UV\", U, V)\n",
212-
"\n",
213-
"Kh_meridional_field = parcels.Field(\n",
214-
" \"Kh_meridional\",\n",
215-
" ds[\"Kh_meridional\"],\n",
216-
" grid,\n",
217-
" interp_method=parcels.interpolators.XLinear,\n",
218-
")\n",
219-
"fieldset = parcels.FieldSet([U, V, UV, Kh_meridional_field])\n",
208+
"fieldset = parcels.FieldSet.from_sgrid_conventions(ds, mesh=\"flat\")\n",
220209
"fieldset.add_constant_field(\"Kh_zonal\", 1, mesh=\"flat\")\n",
221-
"\n",
222210
"fieldset.add_constant(\"dres\", 0.00005)"
223211
]
224212
},

docs/user_guide/examples/tutorial_interpolation.ipynb

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,20 @@
4646
"source": [
4747
"from parcels._datasets.structured.generated import simple_UV_dataset\n",
4848
"\n",
49-
"ds = simple_UV_dataset(dims=(1, 1, 5, 4), mesh=\"flat\").isel(time=0, depth=0)\n",
49+
"ds = simple_UV_dataset(dims=(1, 1, 5, 4), mesh=\"flat\")\n",
5050
"ds[\"lat\"][:] = np.linspace(0.0, 1.0, len(ds.YG))\n",
5151
"ds[\"lon\"][:] = np.linspace(0.0, 1.0, len(ds.XG))\n",
5252
"dx, dy = 1.0 / len(ds.XG), 1.0 / len(ds.YG)\n",
5353
"ds[\"P\"] = ds[\"U\"] + np.random.rand(5, 4) + 0.1\n",
54-
"ds[\"P\"][1, 1] = 0\n",
54+
"ds[\"P\"][:, :, 1, 1] = 0\n",
5555
"ds"
5656
]
5757
},
5858
{
5959
"cell_type": "markdown",
6060
"metadata": {},
6161
"source": [
62-
"From this dataset we create a {py:obj}`parcels.FieldSet`. Parcels requires an interpolation method to be set for each {py:obj}`parcels.Field`, which we will later adapt to see the effects of the different interpolators. A common interpolator for fields on structured grids is (tri)linear, implemented in {py:obj}`parcels.interpolators.XLinear`."
62+
"From this dataset we create a {py:obj}`parcels.FieldSet` using the {py:meth}`parcels.FieldSet.from_sgrid_conventions` constructor, which automatically sets up the grid and fields according to Parcels' s-grid conventions."
6363
]
6464
},
6565
{
@@ -68,12 +68,7 @@
6868
"metadata": {},
6969
"outputs": [],
7070
"source": [
71-
"grid = parcels.XGrid.from_dataset(ds, mesh=\"flat\")\n",
72-
"U = parcels.Field(\"U\", ds[\"U\"], grid, interp_method=parcels.interpolators.XLinear)\n",
73-
"V = parcels.Field(\"V\", ds[\"V\"], grid, interp_method=parcels.interpolators.XLinear)\n",
74-
"UV = parcels.VectorField(\"UV\", U, V)\n",
75-
"P = parcels.Field(\"P\", ds[\"P\"], grid, interp_method=parcels.interpolators.XLinear)\n",
76-
"fieldset = parcels.FieldSet([U, V, UV, P])"
71+
"fieldset = parcels.FieldSet.from_sgrid_conventions(ds, mesh=\"flat\")"
7772
]
7873
},
7974
{

docs/user_guide/examples/tutorial_nestedgrids.ipynb

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -440,12 +440,7 @@
440440
"source": [
441441
"fields = [GridID]\n",
442442
"for i, ds in enumerate(ds_in):\n",
443-
" # TODO : use FieldSet.from_sgrid_convetion here once #2437 is merged\n",
444-
" grid = parcels.XGrid.from_dataset(ds, mesh=\"spherical\")\n",
445-
" U = parcels.Field(\"U\", ds[\"U\"], grid, interp_method=parcels.interpolators.XLinear)\n",
446-
" V = parcels.Field(\"V\", ds[\"V\"], grid, interp_method=parcels.interpolators.XLinear)\n",
447-
" UV = parcels.VectorField(\"UV\", U, V)\n",
448-
" fset = parcels.FieldSet([U, V, UV])\n",
443+
" fset = parcels.FieldSet.from_sgrid_conventions(ds, mesh=\"spherical\")\n",
449444
"\n",
450445
" for fld in fset.fields.values():\n",
451446
" fld.name = f\"{fld.name}{i}\"\n",
@@ -469,32 +464,19 @@
469464
"outputs": [],
470465
"source": [
471466
"def AdvectEE_NestedGrids(particles, fieldset):\n",
472-
" particles.gridID = fieldset.GridID[particles]\n",
473-
"\n",
474-
" # TODO because of KernelParticle bug (GH #2143), we need to copy lon/lat/time to local variables\n",
475-
" time = particles.time\n",
476-
" z = particles.z\n",
477-
" lat = particles.lat\n",
478-
" lon = particles.lon\n",
479467
" u = np.zeros_like(particles.lon)\n",
480468
" v = np.zeros_like(particles.lat)\n",
481469
"\n",
470+
" particles.gridID = fieldset.GridID[particles]\n",
482471
" unique_ids = np.unique(particles.gridID)\n",
483472
" for gid in unique_ids:\n",
484473
" mask = particles.gridID == gid\n",
485474
" UVField = getattr(fieldset, f\"UV{gid}\")\n",
486-
" (u[mask], v[mask]) = UVField[time[mask], z[mask], lat[mask], lon[mask]]\n",
475+
" (u[mask], v[mask]) = UVField[particles[mask]]\n",
487476
"\n",
488477
" particles.dlon += u * particles.dt\n",
489478
" particles.dlat += v * particles.dt\n",
490479
"\n",
491-
" # TODO particle states have to be updated manually because UVField is not called with `particles` argument (becaise of GH #2143)\n",
492-
" particles.state = np.where(\n",
493-
" np.isnan(u) | np.isnan(v),\n",
494-
" parcels.StatusCode.ErrorInterpolation,\n",
495-
" particles.state,\n",
496-
" )\n",
497-
"\n",
498480
"\n",
499481
"lat = np.linspace(-17, 35, 10)\n",
500482
"lon = np.full(len(lat), -5)\n",

docs/user_guide/examples/tutorial_statuscodes.md

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,14 @@ Let's add the `KeepInOcean` Kernel to an particle simulation where particles mov
6363
import numpy as np
6464
from parcels._datasets.structured.generated import simple_UV_dataset
6565
66-
ds = simple_UV_dataset(dims=(1, 2, 5, 4), mesh="flat").isel(time=0)
66+
ds = simple_UV_dataset(dims=(1, 2, 5, 4), mesh="flat")
6767
6868
dx, dy = 1.0 / len(ds.XG), 1.0 / len(ds.YG)
6969
7070
# Add W velocity that pushes through surface
7171
ds["W"] = ds["U"] - 0.1 # 0.1 m/s towards the surface
7272
73-
grid = parcels.XGrid.from_dataset(ds, mesh="flat")
74-
U = parcels.Field("U", ds["U"], grid, interp_method=parcels.interpolators.XLinear)
75-
V = parcels.Field("V", ds["V"], grid, interp_method=parcels.interpolators.XLinear)
76-
W = parcels.Field("W", ds["W"], grid, interp_method=parcels.interpolators.XLinear)
77-
UVW = parcels.VectorField("UVW", U, V, W)
78-
fieldset = parcels.FieldSet([U, V, W, UVW])
73+
fieldset = parcels.FieldSet.from_sgrid_conventions(ds, mesh="flat")
7974
```
8075

8176
If we advect particles with the `AdvectionRK2_3D` kernel, Parcels will raise a `FieldOutOfBoundSurfaceError`:

docs/user_guide/examples/tutorial_unitconverters.ipynb

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"\n",
5050
"nlat = 10\n",
5151
"nlon = 18\n",
52-
"ds = simple_UV_dataset(dims=(1, 1, nlat, nlon), mesh=\"spherical\").isel(time=0, depth=0)\n",
52+
"ds = simple_UV_dataset(dims=(1, 1, nlat, nlon), mesh=\"spherical\")\n",
5353
"ds[\"temperature\"] = ds[\"U\"] + 20 # add temperature field of 20 deg\n",
5454
"ds[\"U\"].data[:] = 1.0 # set U to 1 m/s\n",
5555
"ds[\"V\"].data[:] = 1.0 # set V to 1 m/s\n",
@@ -61,7 +61,7 @@
6161
"cell_type": "markdown",
6262
"metadata": {},
6363
"source": [
64-
"To create a `parcels.FieldSet` object, we define the `parcels.Field`s and the structured grid (`parcels.XGrid`) the fields are defined on. We add the argument `mesh='spherical'` to the `parcels.XGrid` to signal that all longitudes and latitudes are in degrees.\n",
64+
"To create a `parcels.FieldSet` object, we use the `parcels.FieldSet.from_sgrid_conventions` constructor. We add the argument `mesh='spherical'` to signal that all longitudes and latitudes are in degrees.\n",
6565
"\n",
6666
"```{note}\n",
6767
"When using a `FieldSet` method for a specific dataset, such as `from_copernicusmarine()`, the grid information is known and parsed by Parcels, so we do not have to add the `mesh` argument.\n",
@@ -76,14 +76,7 @@
7676
"metadata": {},
7777
"outputs": [],
7878
"source": [
79-
"grid = parcels.XGrid.from_dataset(ds, mesh=\"spherical\")\n",
80-
"U = parcels.Field(\"U\", ds[\"U\"], grid, interp_method=parcels.interpolators.XLinear)\n",
81-
"V = parcels.Field(\"V\", ds[\"V\"], grid, interp_method=parcels.interpolators.XLinear)\n",
82-
"UV = parcels.VectorField(\"UV\", U, V)\n",
83-
"temperature = parcels.Field(\n",
84-
" \"temperature\", ds[\"temperature\"], grid, interp_method=parcels.interpolators.XLinear\n",
85-
")\n",
86-
"fieldset = parcels.FieldSet([U, V, UV, temperature])\n",
79+
"fieldset = parcels.FieldSet.from_sgrid_conventions(ds, mesh=\"spherical\")\n",
8780
"\n",
8881
"plt.pcolormesh(\n",
8982
" fieldset.U.grid.lon,\n",
@@ -199,7 +192,7 @@
199192
"cell_type": "markdown",
200193
"metadata": {},
201194
"source": [
202-
"If longitudes and latitudes are given in meters, rather than degrees, simply add `mesh='flat'` when creating the XGrid object.\n"
195+
"If longitudes and latitudes are given in meters, rather than degrees, simply add `mesh='flat'` when creating the `FieldSet` object.\n"
203196
]
204197
},
205198
{
@@ -208,21 +201,11 @@
208201
"metadata": {},
209202
"outputs": [],
210203
"source": [
211-
"ds_flat = simple_UV_dataset(dims=(1, 1, nlat, nlon), mesh=\"flat\").isel(time=0, depth=0)\n",
204+
"ds_flat = simple_UV_dataset(dims=(1, 1, nlat, nlon), mesh=\"flat\")\n",
212205
"ds_flat[\"temperature\"] = ds_flat[\"U\"] + 20 # add temperature field of 20 deg\n",
213206
"ds_flat[\"U\"].data[:] = 1.0 # set U to 1 m/s\n",
214207
"ds_flat[\"V\"].data[:] = 1.0 # set V to 1 m/s\n",
215-
"grid = parcels.XGrid.from_dataset(ds_flat, mesh=\"flat\")\n",
216-
"U = parcels.Field(\"U\", ds_flat[\"U\"], grid, interp_method=parcels.interpolators.XLinear)\n",
217-
"V = parcels.Field(\"V\", ds_flat[\"V\"], grid, interp_method=parcels.interpolators.XLinear)\n",
218-
"UV = parcels.VectorField(\"UV\", U, V)\n",
219-
"temperature = parcels.Field(\n",
220-
" \"temperature\",\n",
221-
" ds_flat[\"temperature\"],\n",
222-
" grid,\n",
223-
" interp_method=parcels.interpolators.XLinear,\n",
224-
")\n",
225-
"fieldset_flat = parcels.FieldSet([U, V, UV, temperature])\n",
208+
"fieldset_flat = parcels.FieldSet.from_sgrid_conventions(ds_flat, mesh=\"flat\")\n",
226209
"\n",
227210
"plt.pcolormesh(\n",
228211
" fieldset_flat.U.grid.lon,\n",

src/parcels/_core/fieldset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,12 +351,11 @@ def from_sgrid_conventions(
351351
if "U" in ds.data_vars and "V" in ds.data_vars:
352352
fields["U"] = Field("U", ds["U"], grid, XLinear)
353353
fields["V"] = Field("V", ds["V"], grid, XLinear)
354+
fields["UV"] = VectorField("UV", fields["U"], fields["V"])
354355

355356
if "W" in ds.data_vars:
356357
fields["W"] = Field("W", ds["W"], grid, XLinear)
357358
fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"])
358-
else:
359-
fields["UV"] = VectorField("UV", fields["U"], fields["V"])
360359

361360
for varname in set(ds.data_vars) - set(fields.keys()) - skip_vars:
362361
fields[varname] = Field(varname, ds[varname], grid, XLinear)

src/parcels/_datasets/structured/generated.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import numpy as np
44
import xarray as xr
55

6+
from parcels._core.utils.sgrid import (
7+
DimDimPadding,
8+
Grid2DMetadata,
9+
Padding,
10+
)
611
from parcels._core.utils.time import timedelta_to_float
12+
from parcels._datasets.utils import _attach_sgrid_metadata
713

814

915
def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"):
@@ -21,6 +27,18 @@ def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"):
2127
"lat": (["YG"], np.linspace(-90, 90, dims[2]), {"axis": "Y", "c_grid_axis_shift": 0.5}),
2228
"lon": (["XG"], np.linspace(-max_lon, max_lon, dims[3]), {"axis": "X", "c_grid_axis_shift": -0.5}),
2329
},
30+
).pipe(
31+
_attach_sgrid_metadata,
32+
Grid2DMetadata(
33+
cf_role="grid_topology",
34+
topology_dimension=2,
35+
node_dimensions=("XG", "YG"),
36+
face_dimensions=(
37+
DimDimPadding("XC", "XG", Padding.LOW),
38+
DimDimPadding("YC", "YG", Padding.LOW),
39+
),
40+
vertical_dimensions=(DimDimPadding("ZC", "depth", Padding.BOTH),),
41+
),
2442
)
2543

2644

src/parcels/_datasets/structured/generic.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from parcels._core.utils.sgrid import (
55
DimDimPadding,
66
Grid2DMetadata,
7-
Grid3DMetadata,
87
Padding,
98
)
109
from parcels._core.utils.sgrid import (
1110
rename_dims as sgrid_rename_dims,
1211
)
12+
from parcels._datasets.utils import _attach_sgrid_metadata
1313

1414
from . import T, X, Y, Z
1515

@@ -18,18 +18,6 @@
1818
TIME = xr.date_range("2000", "2001", T)
1919

2020

21-
def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata):
22-
"""Copies the dataset and attaches the SGRID metadata in 'grid' variable. Modifies 'conventions' attribute."""
23-
ds = ds.copy()
24-
ds["grid"] = (
25-
[],
26-
0,
27-
grid.to_attrs(),
28-
)
29-
ds.attrs["Conventions"] = "SGRID"
30-
return ds
31-
32-
3321
def _rotated_curvilinear_grid():
3422
XG = np.arange(X)
3523
YG = np.arange(Y)

src/parcels/_datasets/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,26 @@
44
import numpy as np
55
import xarray as xr
66

7+
from parcels._core.utils.sgrid import (
8+
Grid2DMetadata,
9+
Grid3DMetadata,
10+
)
11+
712
_SUPPORTED_ATTR_TYPES = int | float | str | np.ndarray
813

914

15+
def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata):
16+
"""Copies the dataset and attaches the SGRID metadata in 'grid' variable. Modifies 'conventions' attribute."""
17+
ds = ds.copy()
18+
ds["grid"] = (
19+
[],
20+
0,
21+
grid.to_attrs(),
22+
)
23+
ds.attrs["Conventions"] = "SGRID"
24+
return ds
25+
26+
1027
def _print_mismatched_keys(d1: dict[Any, Any], d2: dict[Any, Any]) -> None:
1128
k1 = set(d1.keys())
1229
k2 = set(d2.keys())

tests/test_advection.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,7 @@ def test_advection_zonal(mesh, npart=10):
3333
"""Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`."""
3434
ds = simple_UV_dataset(mesh=mesh)
3535
ds["U"].data[:] = 1.0
36-
grid = XGrid.from_dataset(ds, mesh=mesh)
37-
U = Field("U", ds["U"], grid, interp_method=XLinear)
38-
V = Field("V", ds["V"], grid, interp_method=XLinear)
39-
UV = VectorField("UV", U, V)
40-
fieldset = FieldSet([U, V, UV])
36+
fieldset = FieldSet.from_sgrid_conventions(ds, mesh=mesh)
4137

4238
pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart))
4339
pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"))
@@ -53,11 +49,7 @@ def test_advection_zonal_with_particlefile(tmp_store):
5349
npart = 10
5450
ds = simple_UV_dataset(mesh="flat")
5551
ds["U"].data[:] = 1.0
56-
grid = XGrid.from_dataset(ds, mesh="flat")
57-
U = Field("U", ds["U"], grid, interp_method=XLinear)
58-
V = Field("V", ds["V"], grid, interp_method=XLinear)
59-
UV = VectorField("UV", U, V)
60-
fieldset = FieldSet([U, V, UV])
52+
fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat")
6153

6254
pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart))
6355
pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(30, "m"))
@@ -85,11 +77,7 @@ def test_advection_zonal_periodic():
8577
halo.XG.values = ds.XG.values[1] + 2
8678
ds = xr.concat([ds, halo], dim="XG")
8779

88-
grid = XGrid.from_dataset(ds, mesh="flat")
89-
U = Field("U", ds["U"], grid, interp_method=XLinear)
90-
V = Field("V", ds["V"], grid, interp_method=XLinear)
91-
UV = VectorField("UV", U, V)
92-
fieldset = FieldSet([U, V, UV])
80+
fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat")
9381

9482
PeriodicParticle = Particle.add_variable(Variable("total_dlon", initial=0))
9583
startlon = np.array([0.5, 0.4])
@@ -104,12 +92,8 @@ def test_horizontal_advection_in_3D_flow(npart=10):
10492
"""Flat 2D zonal flow that increases linearly with z from 0 m/s to 1 m/s."""
10593
ds = simple_UV_dataset(mesh="flat")
10694
ds["U"].data[:] = 1.0
107-
grid = XGrid.from_dataset(ds, mesh="flat")
108-
U = Field("U", ds["U"], grid, interp_method=XLinear)
109-
U.data[:, 0, :, :] = 0.0 # Set U to 0 at the surface
110-
V = Field("V", ds["V"], grid, interp_method=XLinear)
111-
UV = VectorField("UV", U, V)
112-
fieldset = FieldSet([U, V, UV])
95+
ds["U"].data[:, 0, :, :] = 0.0 # Set U to 0 at the surface
96+
fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat")
11397

11498
pset = ParticleSet(fieldset, lon=np.zeros(npart), lat=np.zeros(npart), z=np.linspace(0.1, 0.9, npart))
11599
pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"))
@@ -122,15 +106,10 @@ def test_horizontal_advection_in_3D_flow(npart=10):
122106
@pytest.mark.parametrize("wErrorThroughSurface", [True, False])
123107
def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
124108
ds = simple_UV_dataset(mesh="flat")
125-
grid = XGrid.from_dataset(ds, mesh="flat")
126-
U = Field("U", ds["U"], grid, interp_method=XLinear)
127-
U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
128-
V = Field("V", ds["V"], grid, interp_method=XLinear)
129-
W = Field("W", ds["V"], grid, interp_method=XLinear) # Use V as W for testing
130-
W.data[:] = -1.0 if direction == "up" else 1.0
131-
UVW = VectorField("UVW", U, V, W)
132-
UV = VectorField("UV", U, V)
133-
fieldset = FieldSet([U, V, W, UVW, UV])
109+
ds["W"] = ds["V"].copy() # Just to have W field present
110+
ds["U"].data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
111+
ds["W"].data[:] = -1.0 if direction == "up" else 1.0
112+
fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat")
134113

135114
def DeleteParticle(particles, fieldset): # pragma: no cover
136115
particles.state = np.where(particles.state == StatusCode.ErrorOutOfBounds, StatusCode.Delete, particles.state)

0 commit comments

Comments
 (0)