1010from .types import detection
1111from .models .base import LinearModel , Model
1212
13+ from enum import Enum
14+
15+
16+ class Dimension (Enum ):
17+ """Dimension Enum class for specifying plotting parameters in the Plotter class.
18+ Used to sanitize inputs for the dimension attribute of Plotter().
19+
20+ Attributes
21+ ----------
22+ TWO: str
23+ Specifies 2D plotting for Plotter object
24+ THREE: str
25+ Specifies 3D plotting for Plotter object
26+ """
27+ TWO = 2 # 2D plotting mode (original plotter.py functionality)
28+ THREE = 3 # 3D plotting mode
29+
1330
1431class Plotter :
1532 """Plotting class for building graphs of Stone Soup simulations
1633
1734 A plotting class which is used to simplify the process of plotting ground truths,
1835 measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or
1936 particles if required. Legends are automatically generated with each plot.
37+ Three dimensional plots can be created using the optional dimension parameter.
38+
39+ Parameters
40+ ----------
41+ dimension: enum \' Dimension\'
42+ Optional parameter to specify 2D or 3D plotting. Default is 2D plotting.
2043
2144 Attributes
2245 ----------
2346 fig: matplotlib.figure.Figure
2447 Generated figure for graphs to be plotted on
2548 ax: matplotlib.axes.Axes
2649 Generated axes for graphs to be plotted on
27- handles_list: list of :class:`matplotlib.legend_handler.HandlerBase`
28- A list of generated legend handles
29- labels_list: list of str
30- A list of generated legend labels
50+ legend_dict: dict
51+ Dictionary of legend handles as :class:`matplotlib.legend_handler.HandlerBase`
52+ and labels as str
3153 """
3254
33- def __init__ (self ):
55+ def __init__ (self , dimension = Dimension .TWO ):
56+ if isinstance (dimension , type (Dimension .TWO )):
57+ self .dimension = dimension
58+ else :
59+ raise TypeError ("""%s is an unsupported type for \' dimension\' ;
60+ expected type %s""" % (type (dimension ), type (Dimension .TWO )))
3461 # Generate plot axes
3562 self .fig = plt .figure (figsize = (10 , 6 ))
36- self .ax = self .fig .add_subplot (1 , 1 , 1 )
63+ if self .dimension is Dimension .TWO : # 2D axes
64+ self .ax = self .fig .add_subplot (1 , 1 , 1 )
65+ self .ax .axis ('equal' )
66+ else : # 3D axes
67+ self .ax = self .fig .add_subplot (111 , projection = '3d' )
68+ self .ax .axis ('auto' )
69+ self .ax .set_zlabel ("$z$" )
3770 self .ax .set_xlabel ("$x$" )
3871 self .ax .set_ylabel ("$y$" )
39- self .ax .axis ('equal' )
4072
41- # Create empty lists for legend handles and labels
42- self .handles_list = []
43- self .labels_list = []
73+ # Create empty dictionary for legend handles and labels - dict used to
74+ # prevent multiple entries with the same label from displaying on legend
75+ # This is new compared to plotter.py
76+ self .legend_dict = {} # create an empty dictionary to hold legend entries
4477
4578 def plot_ground_truths (self , truths , mapping , truths_label = "Ground Truth" , ** kwargs ):
4679 """Plots ground truth(s)
@@ -58,7 +91,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
5891 :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow for
5992 iteration.
6093 mapping: list
61- List of 2 items specifying the mapping of the x and y components of the state space.
94+ List of items specifying the mapping of the position components of the state space.
6295 \\ *\\ *kwargs: dict
6396 Additional arguments to be passed to plot function. Default is ``linestyle="--"``.
6497 """
@@ -69,17 +102,22 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
69102 truths = {truths } # Make a set of length 1
70103
71104 for truth in truths :
72- self .ax .plot ([state .state_vector [mapping [0 ]] for state in truth ],
73- [state .state_vector [mapping [1 ]] for state in truth ],
74- ** truths_kwargs )
75-
105+ if self .dimension is Dimension .TWO : # plots the ground truths in xy
106+ self .ax .plot ([state .state_vector [mapping [0 ]] for state in truth ],
107+ [state .state_vector [mapping [1 ]] for state in truth ],
108+ ** truths_kwargs )
109+ elif self .dimension is Dimension .THREE : # plots the ground truths in xyz
110+ self .ax .plot3D ([state .state_vector [mapping [0 ]] for state in truth ],
111+ [state .state_vector [mapping [1 ]] for state in truth ],
112+ [state .state_vector [mapping [2 ]] for state in truth ],
113+ ** truths_kwargs )
114+ else :
115+ raise NotImplementedError ('Unsupported dimension type for truth plotting' )
76116 # Generate legend items
77117 truths_handle = Line2D ([], [], linestyle = truths_kwargs ['linestyle' ], color = 'black' )
78- self .handles_list .append (truths_handle )
79- self .labels_list .append (truths_label )
80-
118+ self .legend_dict [truths_label ] = truths_handle
81119 # Generate legend
82- self .ax .legend (handles = self .handles_list , labels = self .labels_list )
120+ self .ax .legend (handles = self .legend_dict . values () , labels = self .legend_dict . keys () )
83121
84122 def plot_measurements (self , measurements , mapping , measurement_model = None ,
85123 measurements_label = "Measurements" , ** kwargs ):
@@ -97,7 +135,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
97135 measurements : list of :class:`~.Detection`
98136 Detections which will be plotted. If measurements is a set of lists it is flattened.
99137 mapping: list
100- List of 2 items specifying the mapping of the x and y components of the state space.
138+ List of items specifying the mapping of the position components of the state space.
101139 measurement_model : :class:`~.Model`, optional
102140 User-defined measurement model to be used in finding measurement state inverses if
103141 they cannot be found from the measurements themselves.
@@ -151,36 +189,38 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
151189
152190 if plot_detections :
153191 detection_array = np .array (plot_detections )
154- self .ax .scatter (detection_array [:, 0 ], detection_array [:, 1 ], ** measurement_kwargs )
192+ # *detection_array.T unpacks detection_array by coloumns
193+ # (same as passing in detection_array[:,0], detection_array[:,1], etc...)
194+ self .ax .scatter (* detection_array .T , ** measurement_kwargs )
155195 measurements_handle = Line2D ([], [], linestyle = '' , ** measurement_kwargs )
156196
157197 # Generate legend items for measurements
158- self .handles_list .append (measurements_handle )
159- self .labels_list .append (measurements_label )
198+ self .legend_dict [measurements_label ] = measurements_handle
160199
161200 if plot_clutter :
162201 clutter_array = np .array (plot_clutter )
163- self .ax .scatter (clutter_array [:, 0 ], clutter_array [:, 1 ] , color = 'y' , marker = '2' )
202+ self .ax .scatter (* clutter_array . T , color = 'y' , marker = '2' )
164203 clutter_handle = Line2D ([], [], linestyle = '' , marker = '2' , color = 'y' )
165204 clutter_label = "Clutter"
166205
167206 # Generate legend items for clutter
168- self .handles_list .append (clutter_handle )
169- self .labels_list .append (clutter_label )
207+ self .legend_dict [clutter_label ] = clutter_handle
170208
171209 # Generate legend
172- self .ax .legend (handles = self .handles_list , labels = self .labels_list )
210+ self .ax .legend (handles = self .legend_dict . values () , labels = self .legend_dict . keys () )
173211
174212 def plot_tracks (self , tracks , mapping , uncertainty = False , particle = False , track_label = "Track" ,
175- ** kwargs ):
213+ err_freq = 1 , ** kwargs ):
176214 """Plots track(s)
177215
178- Plots each track generated, generating a legend automatically. If ``uncertainty=True``,
179- uncertainty ellipses are plotted. If ``particle=True``, particles are plotted.
180- Tracks are plotted as solid lines with point markers and default colors.
181- Uncertainty ellipses are plotted with a default color which is the same for all tracks.
216+ Plots each track generated, generating a legend automatically. If ``uncertainty=True``
217+ and is being plotted in 2D, error elipses are plotted. If being plotted in
218+ 3D, uncertainty bars are plotted every :attr:`err_freq` measurement, default
219+ plots unceratinty bars at every track step. Tracks are plotted as solid
220+ lines with point markers and default colors. Uncertainty bars are plotted
221+ with a default color which is the same for all tracks.
182222
183- Users can change linestyle, color and marker using keyword arguments. Uncertainty ellipses
223+ Users can change linestyle, color and marker using keyword arguments. Uncertainty metrics
184224 will also be plotted with the user defined colour and any changes will apply to all tracks.
185225
186226 Parameters
@@ -189,13 +229,17 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
189229 Set of tracks which will be plotted. If not a set, and instead a single
190230 :class:`~.Track` type, the argument is modified to be a set to allow for iteration.
191231 mapping: list
192- List of 2 items specifying the mapping of the x and y components of the state space.
232+ List of items specifying the mapping of the position
233+ components of the state space.
193234 uncertainty : bool
194- If True, function plots uncertainty ellipses.
235+ If True, function plots uncertainty ellipses or bars .
195236 particle : bool
196237 If True, function plots particles.
197238 track_label: str
198239 Label to apply to all tracks for legend.
240+ err_freq: int
241+ Frequency of error bar plotting on tracks. Default value is 1, meaning
242+ error bars are plotted at every track step.
199243 \\ *\\ *kwargs: dict
200244 Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``,
201245 ``marker='.'`` and ``color=None``.
@@ -209,9 +253,15 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
209253 # Plot tracks
210254 track_colors = {}
211255 for track in tracks :
212- line = self .ax .plot ([state .state_vector [mapping [0 ]] for state in track ],
213- [state .state_vector [mapping [1 ]] for state in track ],
214- ** tracks_kwargs )
256+ if self .dimension is Dimension .TWO :
257+ line = self .ax .plot ([state .state_vector [mapping [0 ]] for state in track ],
258+ [state .state_vector [mapping [1 ]] for state in track ],
259+ ** tracks_kwargs )
260+ else :
261+ line = self .ax .plot ([state .state_vector [mapping [0 ]] for state in track ],
262+ [state .state_vector [mapping [1 ]] for state in track ],
263+ [state .state_vector [mapping [2 ]] for state in track ],
264+ ** tracks_kwargs )
215265 track_colors [track ] = plt .getp (line [0 ], 'color' )
216266
217267 # Assuming a single track or all plotted as the same colour then the following will work.
@@ -221,55 +271,81 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
221271 # Generate legend items for track
222272 track_handle = Line2D ([], [], linestyle = tracks_kwargs ['linestyle' ],
223273 marker = tracks_kwargs ['marker' ], color = tracks_kwargs ['color' ])
224- self .handles_list .append (track_handle )
225- self .labels_list .append (track_label )
226-
274+ self .legend_dict [track_label ] = track_handle
227275 if uncertainty :
228- # Plot uncertainty ellipses
229- for track in tracks :
230- HH = np .eye (track .ndim )[mapping , :] # Get position mapping matrix
231- for state in track :
232- w , v = np .linalg .eig (HH @ state .covar @ HH .T )
233- max_ind = np .argmax (w )
234- min_ind = np .argmin (w )
235- orient = np .arctan2 (v [1 , max_ind ], v [0 , max_ind ])
236- ellipse = Ellipse (xy = state .state_vector [mapping [:2 ], 0 ],
237- width = 2 * np .sqrt (w [max_ind ]),
238- height = 2 * np .sqrt (w [min_ind ]),
239- angle = np .rad2deg (orient ), alpha = 0.2 ,
240- color = track_colors [track ])
241- self .ax .add_artist (ellipse )
242-
243- # Generate legend items for uncertainty ellipses
244- ellipse_handle = Ellipse ((0.5 , 0.5 ), 0.5 , 0.5 , alpha = 0.2 , color = tracks_kwargs ['color' ])
245- ellipse_label = "Uncertainty"
246-
247- self .handles_list .append (ellipse_handle )
248- self .labels_list .append (ellipse_label )
249-
250- # Generate legend
251- self .ax .legend (handles = self .handles_list , labels = self .labels_list ,
252- handler_map = {Ellipse : _HandlerEllipse ()})
276+ if self .dimension is Dimension .TWO :
277+ # Plot uncertainty ellipses
278+ for track in tracks :
279+ HH = np .eye (track .ndim )[mapping , :] # Get position mapping matrix
280+ for state in track :
281+ w , v = np .linalg .eig (HH @ state .covar @ HH .T )
282+ max_ind = np .argmax (w )
283+ min_ind = np .argmin (w )
284+ orient = np .arctan2 (v [1 , max_ind ], v [0 , max_ind ])
285+ ellipse = Ellipse (xy = state .state_vector [mapping [:2 ], 0 ],
286+ width = 2 * np .sqrt (w [max_ind ]),
287+ height = 2 * np .sqrt (w [min_ind ]),
288+ angle = np .rad2deg (orient ), alpha = 0.2 ,
289+ color = track_colors [track ])
290+ self .ax .add_artist (ellipse )
291+
292+ # Generate legend items for uncertainty ellipses
293+ ellipse_handle = Ellipse ((0.5 , 0.5 ), 0.5 , 0.5 , alpha = 0.2 ,
294+ color = tracks_kwargs ['color' ])
295+ ellipse_label = "Uncertainty"
296+ self .legend_dict [ellipse_label ] = ellipse_handle
297+ # Generate legend
298+ self .ax .legend (handles = self .legend_dict .values (),
299+ labels = self .legend_dict .keys (),
300+ handler_map = {Ellipse : _HandlerEllipse ()})
301+ else :
302+ # Plot 3D error bars on tracks
303+ for track in tracks :
304+ HH = np .eye (track .ndim )[mapping , :] # Get position mapping matrix
305+ check = err_freq
306+ for state in track :
307+ if not check % err_freq :
308+ w , v = np .linalg .eig (HH @ state .covar @ HH .T )
309+
310+ xl = state .state_vector [mapping [0 ]]
311+ yl = state .state_vector [mapping [1 ]]
312+ zl = state .state_vector [mapping [2 ]]
313+
314+ x_err = w [0 ]
315+ y_err = w [1 ]
316+ z_err = w [2 ]
317+
318+ self .ax .plot3D ([xl + x_err , xl - x_err ], [yl , yl ], [zl , zl ],
319+ marker = "_" , color = tracks_kwargs ['color' ])
320+ self .ax .plot3D ([xl , xl ], [yl + y_err , yl - y_err ], [zl , zl ],
321+ marker = "_" , color = tracks_kwargs ['color' ])
322+ self .ax .plot3D ([xl , xl ], [yl , yl ], [zl + z_err , zl - z_err ],
323+ marker = "_" , color = tracks_kwargs ['color' ])
324+ check += 1
253325
254326 elif particle :
255- # Plot particles
256- for track in tracks :
257- for state in track :
258- data = state .particles .state_vector [mapping [:2 ], :]
259- self .ax .plot (data [0 ], data [1 ], linestyle = '' , marker = "." ,
260- markersize = 1 , alpha = 0.5 )
261-
262- # Generate legend items for particles
263- particle_handle = Line2D ([], [], linestyle = '' , color = "black" , marker = '.' , markersize = 1 )
264- particle_label = "Particles"
265- self .handles_list .append (particle_handle )
266- self .labels_list .append (particle_label )
267-
268- # Generate legend
269- self .ax .legend (handles = self .handles_list , labels = self .labels_list )
327+ if self .dimension is Dimension .TWO :
328+ # Plot particles
329+ for track in tracks :
330+ for state in track :
331+ data = state .particles .state_vector [mapping [:2 ], :]
332+ self .ax .plot (data [0 ], data [1 ], linestyle = '' , marker = "." ,
333+ markersize = 1 , alpha = 0.5 )
334+
335+ # Generate legend items for particles
336+ particle_handle = Line2D ([], [], linestyle = '' , color = "black" , marker = '.' ,
337+ markersize = 1 )
338+ particle_label = "Particles"
339+ self .legend_dict [particle_label ] = particle_handle
340+ # Generate legend
341+ self .ax .legend (handles = self .legend_dict .values (),
342+ labels = self .legend_dict .keys ()) # particle error legend
343+ else :
344+ raise NotImplementedError ("""Particle plotting is not currently supported for
345+ 3D visualization""" )
270346
271347 else :
272- self .ax .legend (handles = self .handles_list , labels = self .labels_list )
348+ self .ax .legend (handles = self .legend_dict . values () , labels = self .legend_dict . keys () )
273349
274350 # Ellipse legend patch (used in Tutorial 3)
275351 @staticmethod
0 commit comments