Commit d64d896

Using a dictionary of arrays for particleset data
1 parent 488e3fb commit d64d896

2 files changed: +20 −27 lines

parcels/particle.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def __init__(self, data: xr.Dataset, index: int):
         self._index = index

     def __getattr__(self, name):
-        return self._data[name].values[self._index]
+        return self._data[name][self._index]

     def __setattr__(self, name, value):
         if name in ["_data", "_index"]:
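For context, a minimal runnable sketch of what the accessor change above amounts to (the ParticleView class and toy data here are hypothetical illustrations, not part of the Parcels codebase): with _data as a plain dict of NumPy arrays, __getattr__ indexes the array directly instead of first unwrapping an xarray DataArray via .values.

import numpy as np

class ParticleView:
    """Hypothetical, simplified stand-in for the accessor changed above."""

    def __init__(self, data: dict, index: int):
        self._data = data    # dict mapping variable name -> 1-D NumPy array
        self._index = index  # position of this particle in every array

    def __getattr__(self, name):
        # dict-of-arrays lookup: plain integer indexing, no .values unwrapping
        return self._data[name][self._index]

data = {"lon": np.array([0.0, 1.5]), "lat": np.array([10.0, 20.0])}
p = ParticleView(data, index=1)
print(p.lon, p.lat)  # 1.5 20.0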

parcels/particleset.py

Lines changed: 19 additions & 26 deletions
@@ -135,36 +135,29 @@ def __init__(
             lon.size == kwargs[kwvar].size
         ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."

-        self._data = xr.Dataset(
-            {
-                "lon": (["trajectory"], lon.astype(lonlatdepth_dtype)),
-                "lat": (["trajectory"], lat.astype(lonlatdepth_dtype)),
-                "depth": (["trajectory"], depth.astype(lonlatdepth_dtype)),
-                "time": (["trajectory"], time),
-                "dt": (["trajectory"], np.timedelta64(1, "ns") * np.ones(len(trajectory_ids))),
-                "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
-                "state": (["trajectory"], np.zeros((len(trajectory_ids)), dtype=np.int32)),
-                "lon_nextloop": (["trajectory"], lon.astype(lonlatdepth_dtype)),
-                "lat_nextloop": (["trajectory"], lat.astype(lonlatdepth_dtype)),
-                "depth_nextloop": (["trajectory"], depth.astype(lonlatdepth_dtype)),
-                "time_nextloop": (["trajectory"], time),
-            },
-            coords={
-                "trajectory": ("trajectory", trajectory_ids),
-            },
-            attrs={
-                "ngrid": len(fieldset.gridset),
-                "ptype": pclass.getPType(),
-            },
-        )
+        self._data = {
+            "lon": lon.astype(lonlatdepth_dtype),
+            "lat": lat.astype(lonlatdepth_dtype),
+            "depth": depth.astype(lonlatdepth_dtype),
+            "time": time,
+            "dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)),
+            # "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
+            "state": np.zeros((len(trajectory_ids)), dtype=np.int32),
+            "lon_nextloop": lon.astype(lonlatdepth_dtype),
+            "lat_nextloop": lat.astype(lonlatdepth_dtype),
+            "depth_nextloop": depth.astype(lonlatdepth_dtype),
+            "time_nextloop": time,
+            "trajectory": trajectory_ids,
+        }
+        self.ptype = pclass.getPType()
         # add extra fields from the custom Particle class
         for v in pclass.__dict__.values():
             if isinstance(v, Variable):
                 if isinstance(v.initial, attrgetter):
                     initial = v.initial(self).values
                 else:
                     initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype)
-                self._data[v.name] = (["trajectory"], initial)
+                self._data[v.name] = initial

         # update initial values provided on ParticleSet creation
         for kwvar, kwval in kwargs.items():
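To make the new layout concrete, a small self-contained sketch of a dict-of-arrays particle store (variable names follow the hunk above; the toy values and the temperature key are assumptions for illustration, not the actual ParticleSet API): every variable is a 1-D NumPy array with one value per particle, a custom variable is just another key of the same length, and per-particle reads and writes are plain integer indexing.

import numpy as np

n = 3  # number of particles in this toy example
data = {
    "lon": np.array([0.0, 1.0, 2.0], dtype=np.float32),
    "lat": np.array([10.0, 11.0, 12.0], dtype=np.float32),
    "state": np.zeros(n, dtype=np.int32),
    "trajectory": np.arange(n, dtype=np.int64),
}

# a custom Variable becomes just another equally sized array
data["temperature"] = 20.0 * np.ones(n, dtype=np.float32)

# per-particle access and update by integer index
i = 1
data["state"][i] = 4
print(data["lon"][i], data["temperature"][i], data["state"][i])  # 1.0 20.0 4

Note that the dataset-level attrs of the old xr.Dataset (ngrid, ptype) have no place in a plain dict: in the hunk above ptype moves to an attribute on the ParticleSet itself (self.ptype) and the two-dimensional ei entry is commented out.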
@@ -591,19 +584,19 @@ def Kernel(self, pyfunc):
         if isinstance(pyfunc, list):
             return Kernel.from_list(
                 self.fieldset,
-                self._data.ptype,
+                self.ptype,
                 pyfunc,
             )
         return Kernel(
             self.fieldset,
-            self._data.ptype,
+            self.ptype,
             pyfunc=pyfunc,
         )

     def InteractionKernel(self, pyfunc_inter):
         if pyfunc_inter is None:
             return None
-        return InteractionKernel(self.fieldset, self._data.ptype, pyfunc=pyfunc_inter)
+        return InteractionKernel(self.fieldset, self.ptype, pyfunc=pyfunc_inter)

     def ParticleFile(self, *args, **kwargs):
         """Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile` object from the ParticleSet."""
