Skip to content

Commit 3ee1c37

Browse files
implementing particleset adding and deleting methods for dicts
And fixing some unit tests
1 parent c99e29f commit 3ee1c37

File tree

4 files changed

+32
-17
lines changed

4 files changed

+32
-17
lines changed

parcels/particleset.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,12 @@ def __init__(
149149
"time_nextloop": time,
150150
"trajectory": trajectory_ids,
151151
}
152-
self.ptype = pclass.getPType()
152+
self._ptype = pclass.getPType()
153153
# add extra fields from the custom Particle class
154154
for v in pclass.__dict__.values():
155155
if isinstance(v, Variable):
156156
if isinstance(v.initial, attrgetter):
157-
initial = v.initial(self).values
157+
initial = v.initial(self)
158158
else:
159159
initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype)
160160
self._data[v.name] = initial
@@ -231,13 +231,28 @@ def add(self, particles):
231231
The current ParticleSet
232232
233233
"""
234+
assert (
235+
particles is not None
236+
), f"Trying to add another {type(self)} to this one, but the other one is None - invalid operation."
237+
assert type(particles) is type(self)
238+
239+
if len(particles) == 0:
240+
return
241+
242+
if len(self) == 0:
243+
self._data = particles._data
244+
return
245+
234246
if isinstance(particles, type(self)):
235247
if len(self._data["trajectory"]) > 0:
236-
offset = self._data["trajectory"].values.max() + 1
248+
offset = self._data["trajectory"].max() + 1
237249
else:
238250
offset = 0
239-
particles._data["trajectory"] = particles._data["trajectory"].values + offset
240-
self._data = xr.concat([self._data, particles._data], dim="trajectory")
251+
particles._data["trajectory"] = particles._data["trajectory"] + offset
252+
253+
for d in self._data:
254+
self._data[d] = np.concatenate((self._data[d], particles._data[d]))
255+
241256
# Adding particles invalidates the neighbor search structure.
242257
self._dirty_neighbor = True
243258
return self
@@ -263,7 +278,8 @@ def __iadd__(self, particles):
263278

264279
def remove_indices(self, indices):
265280
"""Method to remove particles from the ParticleSet, based on their `indices`."""
266-
self._data = self._data.drop_sel(trajectory=indices)
281+
for d in self._data:
282+
self._data[d] = np.delete(self._data[d], indices, axis=0)
267283

268284
def _active_particles_mask(self, time, dt):
269285
active_indices = (time - self._data["time"]) / dt >= 0
@@ -584,19 +600,19 @@ def Kernel(self, pyfunc):
584600
if isinstance(pyfunc, list):
585601
return Kernel.from_list(
586602
self.fieldset,
587-
self.ptype,
603+
self._ptype,
588604
pyfunc,
589605
)
590606
return Kernel(
591607
self.fieldset,
592-
self.ptype,
608+
self._ptype,
593609
pyfunc=pyfunc,
594610
)
595611

596612
def InteractionKernel(self, pyfunc_inter):
597613
if pyfunc_inter is None:
598614
return None
599-
return InteractionKernel(self.fieldset, self.ptype, pyfunc=pyfunc_inter)
615+
return InteractionKernel(self.fieldset, self._ptype, pyfunc=pyfunc_inter)
600616

601617
def ParticleFile(self, *args, **kwargs):
602618
"""Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile` object from the ParticleSet."""
@@ -740,9 +756,9 @@ def execute(
740756
else:
741757
if not np.isnat(self._data["time_nextloop"]).any():
742758
if sign_dt > 0:
743-
start_time = self._data["time_nextloop"].min().values
759+
start_time = self._data["time_nextloop"].min()
744760
else:
745-
start_time = self._data["time_nextloop"].max().values
761+
start_time = self._data["time_nextloop"].max()
746762
else:
747763
if sign_dt > 0:
748764
start_time = self.fieldset.time_interval.left

tests/v4/test_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_unknown_var_in_kernel(fieldset):
4343
def ErrorKernel(particle, fieldset, time): # pragma: no cover
4444
particle.unknown_varname += 0.2
4545

46-
with pytest.raises(KeyError, match="No variable named 'unknown_varname'"):
46+
with pytest.raises(KeyError, match="'unknown_varname'"):
4747
pset.execute(ErrorKernel, runtime=np.timedelta64(2, "s"))
4848

4949

tests/v4/test_particleset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,17 +198,17 @@ def test_pset_add_explicit(fieldset):
198198
assert len(pset) == npart
199199
assert np.allclose([p.lon for p in pset], lon, atol=1e-12)
200200
assert np.allclose([p.lat for p in pset], lat, atol=1e-12)
201-
assert np.allclose(np.diff(pset._data.trajectory), np.ones(pset._data.trajectory.size - 1), atol=1e-12)
201+
assert np.allclose(np.diff(pset._data["trajectory"]), np.ones(pset._data["trajectory"].size - 1), atol=1e-12)
202202

203203

204204
def test_pset_add_implicit(fieldset):
205205
pset = ParticleSet(fieldset, lon=np.zeros(3), lat=np.ones(3), pclass=Particle)
206206
pset += ParticleSet(fieldset, lon=np.ones(4), lat=np.zeros(4), pclass=Particle)
207207
assert len(pset) == 7
208-
assert np.allclose(np.diff(pset._data.trajectory), np.ones(6), atol=1e-12)
208+
assert np.allclose(np.diff(pset._data["trajectory"]), np.ones(6), atol=1e-12)
209209

210210

211-
def test_pset_add_implicit(fieldset, npart=10):
211+
def test_pset_add_implicit_in_loop(fieldset, npart=10):
212212
pset = ParticleSet(fieldset, lon=[], lat=[])
213213
for _ in range(npart):
214214
pset += ParticleSet(pclass=Particle, lon=0.1, lat=0.1, fieldset=fieldset)

tests/v4/test_particleset_execute.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ def PythonFail(particle, fieldset, time): # pragma: no cover
112112
with pytest.raises(RuntimeError):
113113
pset.execute(PythonFail, runtime=np.timedelta64(20, "s"), dt=np.timedelta64(2, "s"))
114114
assert len(pset) == npart
115-
assert pset.time[0] == fieldset.time_interval.left + np.timedelta64(10, "s")
116-
assert all([time == fieldset.time_interval.left + np.timedelta64(8, "s") for time in pset.time[1:]])
115+
assert all([time == fieldset.time_interval.left + np.timedelta64(10, "s") for time in pset.time])
117116

118117

119118
@pytest.mark.parametrize("verbose_progress", [True, False])

0 commit comments

Comments
 (0)