Skip to content

Commit d648229

Browse files
authored
Merge pull request #415 from dstl/iterable_trackers
Trackers are Iterables instead of Generators
2 parents 071abc8 + 5358490 commit d648229

File tree

9 files changed

+196
-159
lines changed

9 files changed

+196
-159
lines changed

docs/demos/AIS_Solent_Tracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@
161161
# :class:`set` we can simply update this with `current_tracks` at each timestep, not worrying about
162162
# duplicates.
163163
tracks = set()
164-
for step, (time, current_tracks) in enumerate(tracker.tracks_gen(), 1):
164+
for step, (time, current_tracks) in enumerate(tracker, 1):
165165
tracks.update(current_tracks)
166166
if not step % 10:
167167
print("Step: {} Time: {}".format(step, time))

docs/demos/OpenSky_Demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@
238238
)
239239

240240
tracks = set()
241-
for step, (time, current_tracks) in enumerate(kalman_tracker.tracks_gen(), 1):
241+
for step, (time, current_tracks) in enumerate(kalman_tracker, 1):
242242
tracks.update(current_tracks)
243243

244244
# %%

docs/examples/Metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@
146146
# With this basic tracker built and metrics ready, we'll now run the tracker, adding the sets of
147147
# :class:`~.GroundTruthPath`, :class:`~.Detection` and :class:`~.Track` objects: to the metric
148148
# manager.
149-
for time, tracks in tracker.tracks_gen():
149+
for time, tracks in tracker:
150150
metric_manager.add_data(
151151
groundtruth_sim.groundtruth_paths, tracks, detection_sim.detections,
152152
overwrite=False, # Don't overwrite, instead add above as additional data

docs/examples/Sensor_Platform_Simulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def initiate(self, detections, timestamp, **kwargs):
346346
groundtruth_paths = {} # Store for plotting later
347347
detections = [] # Store for plotting later
348348

349-
for time, ctracks in kalman_tracker.tracks_gen():
349+
for time, ctracks in kalman_tracker:
350350
for track in ctracks:
351351
loc = (track.state_vector[0], track.state_vector[2])
352352
if track not in kalman_tracks:

stonesoup/reader/yaml.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,23 @@ class YAMLTrackReader(YAMLReader, Tracker):
7474
def data_gen(self):
7575
yield from super().data_gen()
7676

77-
@BufferedGenerator.generator_method
78-
def tracks_gen(self):
79-
tracks = dict()
80-
for time, document in self.data_gen():
81-
updated_tracks = set()
82-
for track in document.get('tracks', set()):
83-
if track.id in tracks:
84-
tracks[track.id].states = track.states
85-
else:
86-
tracks[track.id] = track
87-
updated_tracks.add(tracks[track.id])
88-
89-
yield time, updated_tracks
77+
def __iter__(self):
78+
self.data_iter = iter(self.data_gen())
79+
self._tracks = dict()
80+
return super().__iter__()
81+
82+
@property
83+
def tracks(self):
84+
return self._tracks
85+
86+
def __next__(self):
87+
time, document = next(self.data_iter)
88+
updated_tracks = set()
89+
for track in document.get('tracks', set()):
90+
if track.id in self.tracks:
91+
self._tracks[track.id].states = track.states
92+
else:
93+
self._tracks[track.id] = track
94+
updated_tracks.add(self.tracks[track.id])
95+
96+
return time, updated_tracks

stonesoup/tracker/base.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@
22
from abc import abstractmethod
33

44
from ..base import Base
5-
from ..buffered_generator import BufferedGenerator
65

76

8-
class Tracker(Base, BufferedGenerator):
7+
class Tracker(Base):
98
"""Tracker base class"""
109

1110
@property
11+
@abstractmethod
1212
def tracks(self):
13-
return self.current[1]
13+
raise NotImplementedError
1414

15-
@abstractmethod
16-
@BufferedGenerator.generator_method
17-
def tracks_gen(self):
18-
"""Returns a generator of tracks for each time step.
15+
def __iter__(self):
16+
return self
1917

20-
Yields
21-
------
18+
@abstractmethod
19+
def __next__(self):
20+
"""
21+
Returns
22+
-------
2223
: :class:`datetime.datetime`
2324
Datetime of current time step
2425
: set of :class:`~.Track`

stonesoup/tracker/pointprocess.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from ..updater import Updater
1010
from ..hypothesiser.gaussianmixture import GaussianMixtureHypothesiser
1111
from ..mixturereducer.gaussianmixture import GaussianMixtureReducer
12-
from ..buffered_generator import BufferedGenerator
1312

1413

1514
class PointProcessMultiTargetTracker(Tracker):
@@ -46,6 +45,10 @@ def tracks(self):
4645
tracks.add(track)
4746
return tracks
4847

48+
def __iter__(self):
49+
self.detector_iter = iter(self.detector)
50+
return super().__iter__()
51+
4952
def update_tracks(self):
5053
"""
5154
Updates the tracks (:class:`Track`) associated with the filter.
@@ -74,27 +77,26 @@ def update_tracks(self):
7477
self.extraction_threshold:
7578
self.target_tracks[tag] = Track([component], id=tag)
7679

77-
@BufferedGenerator.generator_method
78-
def tracks_gen(self):
79-
for time, detections in self.detector:
80-
# Add birth component
81-
self.birth_component.timestamp = time
82-
self.gaussian_mixture.append(self.birth_component)
83-
# Perform GM Prediction and generate hypotheses
84-
hypotheses = self.hypothesiser.hypothesise(
85-
self.gaussian_mixture.components,
86-
detections,
87-
time
88-
)
89-
# Perform GM Update
90-
self.gaussian_mixture = self.updater.update(hypotheses)
91-
# Reduce mixture - Pruning and Merging
92-
self.gaussian_mixture.components = \
93-
self.reducer.reduce(self.gaussian_mixture.components)
94-
# Update the tracks
95-
self.update_tracks()
96-
self.end_tracks()
97-
yield time, self.tracks
80+
def __next__(self):
81+
time, detections = next(self.detector_iter)
82+
# Add birth component
83+
self.birth_component.timestamp = time
84+
self.gaussian_mixture.append(self.birth_component)
85+
# Perform GM Prediction and generate hypotheses
86+
hypotheses = self.hypothesiser.hypothesise(
87+
self.gaussian_mixture.components,
88+
detections,
89+
time
90+
)
91+
# Perform GM Update
92+
self.gaussian_mixture = self.updater.update(hypotheses)
93+
# Reduce mixture - Pruning and Merging
94+
self.gaussian_mixture.components = \
95+
self.reducer.reduce(self.gaussian_mixture.components)
96+
# Update the tracks
97+
self.update_tracks()
98+
self.end_tracks()
99+
return time, self.tracks
98100

99101
def end_tracks(self):
100102
"""

0 commit comments

Comments
 (0)