Skip to content

Commit 14db74f

Browse files
authored
Merge pull request #589 from ekhunter123/simple_initiator_measmodel
Make measurement model optional in measurement-based initiators
2 parents efb1ba2 + b6a52b1 commit 14db74f

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

stonesoup/initiator/simple.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ class SinglePointInitiator(GaussianInitiator):
2525
"""
2626

2727
prior_state: GaussianState = Property(doc="Prior state information")
28-
measurement_model: MeasurementModel = Property(doc="Measurement model")
28+
measurement_model: MeasurementModel = Property(
29+
default=None,
30+
doc="Measurement model. Can be left as None if all detections have a "
31+
"valid measurement model.")
2932

3033
def initiate(self, detections, timestamp, **kwargs):
3134
"""Initiates tracks given unassociated measurements
@@ -64,6 +67,8 @@ class SimpleMeasurementInitiator(GaussianInitiator):
6467
6568
This initiator utilises the :class:`~.MeasurementModel` matrix to convert
6669
:class:`~.Detection` state vector and model covariance into state space.
70+
It either takes the :class:`~.MeasurementModel` from the given detection
71+
or uses the :attr:`measurement_model`.
6772
6873
Utilises the ReversibleModel inverse function to convert
6974
non-linear spherical co-ordinates into Cartesian x/y co-ordinates
@@ -77,7 +82,10 @@ class SimpleMeasurementInitiator(GaussianInitiator):
7782
decompositions.
7883
"""
7984
prior_state: GaussianState = Property(doc="Prior state information")
80-
measurement_model: MeasurementModel = Property(doc="Measurement model")
85+
measurement_model: MeasurementModel = Property(
86+
default=None,
87+
doc="Measurement model. Can be left as None if all detections have a "
88+
"valid measurement model.")
8189
skip_non_reversible: bool = Property(default=False)
8290
diag_load: float = Property(default=0.0, doc="Positive float value for diagonal loading")
8391

@@ -94,7 +102,10 @@ def initiate(self, detections, timestamp, **kwargs):
94102
if detection.measurement_model is not None:
95103
measurement_model = detection.measurement_model
96104
else:
97-
measurement_model = self.measurement_model
105+
if self.measurement_model is None:
106+
raise ValueError("No measurement model specified")
107+
else:
108+
measurement_model = self.measurement_model
98109

99110
if isinstance(measurement_model, LinearModel):
100111
model_matrix = measurement_model.matrix()
@@ -155,12 +166,15 @@ class MultiMeasurementInitiator(GaussianInitiator):
155166
Does cause slight delay in initiation to tracker."""
156167

157168
prior_state: GaussianState = Property(doc="Prior state information")
158-
measurement_model: MeasurementModel = Property(doc="Measurement model")
159169
deleter: Deleter = Property(doc="Deleter used to delete the track.")
160170
data_associator: DataAssociator = Property(
161171
doc="Association algorithm to pair predictions to detections.")
162172
updater: Updater = Property(
163173
doc="Updater used to update the track object to the new state.")
174+
measurement_model: MeasurementModel = Property(
175+
default=None,
176+
doc="Measurement model. Can be left as None if all detections have a "
177+
"valid measurement model.")
164178
min_points: int = Property(
165179
default=2, doc="Minimum number of track points required to confirm a track.")
166180
updates_only: bool = Property(

stonesoup/initiator/tests/test_simple.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ...hypothesiser.distance import DistanceHypothesiser
1717
from ...dataassociator.neighbour import NearestNeighbour
1818
from ...measures import Mahalanobis
19-
from ...types.detection import Detection
19+
from ...types.detection import Detection, TrueDetection
2020
from ...types.hypothesis import SingleHypothesis
2121
from ...types.prediction import Prediction
2222
from ...types.state import GaussianState
@@ -294,7 +294,8 @@ def test_multi_measurement(updates_only):
294294

295295
measurement_initiator = MultiMeasurementInitiator(
296296
GaussianState([[0], [0], [0], [0]], np.diag([0, 15, 0, 15])),
297-
measurement_model, deleter, data_associator, updater, updates_only=updates_only)
297+
deleter, data_associator, updater,
298+
measurement_model=measurement_model, updates_only=updates_only)
298299

299300
timestamp = datetime.datetime.now()
300301
first_detections = {Detection(np.array([[5], [2]]), timestamp),
@@ -318,6 +319,25 @@ def test_multi_measurement(updates_only):
318319
assert len(measurement_initiator.holding_tracks) == 0
319320

320321

322+
@pytest.mark.parametrize("initiator", [
323+
SinglePointInitiator(
324+
GaussianState(np.array([[0]]), np.array([[100]]))
325+
),
326+
SimpleMeasurementInitiator(
327+
GaussianState(np.array([[0]]), np.array([[100]]))
328+
),
329+
], ids=['SinglePoint', 'LinearMeasurement'])
330+
def test_measurement_model(initiator):
331+
timestamp = datetime.datetime.now()
332+
dummy_detection = TrueDetection(np.array([0, 0]), timestamp)
333+
# The SinglePointInitiator will raise an error when the ExtendedKalmanUpdater
334+
# is called and neither the detection nor the initiator has a measurement
335+
# model. The SimpleMeasurementInitiator will raise an error in the if/else
336+
# blocks.
337+
with pytest.raises(ValueError):
338+
_ = initiator.initiate({dummy_detection}, timestamp)
339+
340+
321341
@pytest.mark.parametrize("gaussian_initiator", [
322342
SinglePointInitiator(
323343
GaussianState(np.array([[0]]), np.array([[100]])),

0 commit comments

Comments
 (0)