Skip to content

Commit e3da986

Browse files
authored
Merge pull request #584 from PACarniglia/main
Added 3D Plotting Functionality to plotter.py
2 parents d648229 + f190be0 commit e3da986

File tree

2 files changed

+266
-82
lines changed

2 files changed

+266
-82
lines changed

stonesoup/plotter.py

Lines changed: 158 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,70 @@
1010
from .types import detection
1111
from .models.base import LinearModel, Model
1212

13+
from enum import Enum
14+
15+
16+
class Dimension(Enum):
17+
"""Dimension Enum class for specifying plotting parameters in the Plotter class.
18+
Used to sanitize inputs for the dimension attribute of Plotter().
19+
20+
Attributes
21+
----------
22+
TWO: str
23+
Specifies 2D plotting for Plotter object
24+
THREE: str
25+
Specifies 3D plotting for Plotter object
26+
"""
27+
TWO = 2 # 2D plotting mode (original plotter.py functionality)
28+
THREE = 3 # 3D plotting mode
29+
1330

1431
class Plotter:
1532
"""Plotting class for building graphs of Stone Soup simulations
1633
1734
A plotting class which is used to simplify the process of plotting ground truths,
1835
measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or
1936
particles if required. Legends are automatically generated with each plot.
37+
Three dimensional plots can be created using the optional dimension parameter.
38+
39+
Parameters
40+
----------
41+
dimension: enum \'Dimension\'
42+
Optional parameter to specify 2D or 3D plotting. Default is 2D plotting.
2043
2144
Attributes
2245
----------
2346
fig: matplotlib.figure.Figure
2447
Generated figure for graphs to be plotted on
2548
ax: matplotlib.axes.Axes
2649
Generated axes for graphs to be plotted on
27-
handles_list: list of :class:`matplotlib.legend_handler.HandlerBase`
28-
A list of generated legend handles
29-
labels_list: list of str
30-
A list of generated legend labels
50+
legend_dict: dict
51+
Dictionary of legend handles as :class:`matplotlib.legend_handler.HandlerBase`
52+
and labels as str
3153
"""
3254

33-
def __init__(self):
55+
def __init__(self, dimension=Dimension.TWO):
56+
if isinstance(dimension, type(Dimension.TWO)):
57+
self.dimension = dimension
58+
else:
59+
raise TypeError("""%s is an unsupported type for \'dimension\';
60+
expected type %s""" % (type(dimension), type(Dimension.TWO)))
3461
# Generate plot axes
3562
self.fig = plt.figure(figsize=(10, 6))
36-
self.ax = self.fig.add_subplot(1, 1, 1)
63+
if self.dimension is Dimension.TWO: # 2D axes
64+
self.ax = self.fig.add_subplot(1, 1, 1)
65+
self.ax.axis('equal')
66+
else: # 3D axes
67+
self.ax = self.fig.add_subplot(111, projection='3d')
68+
self.ax.axis('auto')
69+
self.ax.set_zlabel("$z$")
3770
self.ax.set_xlabel("$x$")
3871
self.ax.set_ylabel("$y$")
39-
self.ax.axis('equal')
4072

41-
# Create empty lists for legend handles and labels
42-
self.handles_list = []
43-
self.labels_list = []
73+
# Create empty dictionary for legend handles and labels - dict used to
74+
# prevent multiple entries with the same label from displaying on legend
75+
# This is new compared to plotter.py
76+
self.legend_dict = {} # create an empty dictionary to hold legend entries
4477

4578
def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs):
4679
"""Plots ground truth(s)
@@ -58,7 +91,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
5891
:class:`~.GroundTruthPath` type, the argument is modified to be a set to allow for
5992
iteration.
6093
mapping: list
61-
List of 2 items specifying the mapping of the x and y components of the state space.
94+
List of items specifying the mapping of the position components of the state space.
6295
\\*\\*kwargs: dict
6396
Additional arguments to be passed to plot function. Default is ``linestyle="--"``.
6497
"""
@@ -69,17 +102,22 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
69102
truths = {truths} # Make a set of length 1
70103

71104
for truth in truths:
72-
self.ax.plot([state.state_vector[mapping[0]] for state in truth],
73-
[state.state_vector[mapping[1]] for state in truth],
74-
**truths_kwargs)
75-
105+
if self.dimension is Dimension.TWO: # plots the ground truths in xy
106+
self.ax.plot([state.state_vector[mapping[0]] for state in truth],
107+
[state.state_vector[mapping[1]] for state in truth],
108+
**truths_kwargs)
109+
elif self.dimension is Dimension.THREE: # plots the ground truths in xyz
110+
self.ax.plot3D([state.state_vector[mapping[0]] for state in truth],
111+
[state.state_vector[mapping[1]] for state in truth],
112+
[state.state_vector[mapping[2]] for state in truth],
113+
**truths_kwargs)
114+
else:
115+
raise NotImplementedError('Unsupported dimension type for truth plotting')
76116
# Generate legend items
77117
truths_handle = Line2D([], [], linestyle=truths_kwargs['linestyle'], color='black')
78-
self.handles_list.append(truths_handle)
79-
self.labels_list.append(truths_label)
80-
118+
self.legend_dict[truths_label] = truths_handle
81119
# Generate legend
82-
self.ax.legend(handles=self.handles_list, labels=self.labels_list)
120+
self.ax.legend(handles=self.legend_dict.values(), labels=self.legend_dict.keys())
83121

84122
def plot_measurements(self, measurements, mapping, measurement_model=None,
85123
measurements_label="Measurements", **kwargs):
@@ -97,7 +135,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
97135
measurements : list of :class:`~.Detection`
98136
Detections which will be plotted. If measurements is a set of lists it is flattened.
99137
mapping: list
100-
List of 2 items specifying the mapping of the x and y components of the state space.
138+
List of items specifying the mapping of the position components of the state space.
101139
measurement_model : :class:`~.Model`, optional
102140
User-defined measurement model to be used in finding measurement state inverses if
103141
they cannot be found from the measurements themselves.
@@ -151,36 +189,38 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
151189

152190
if plot_detections:
153191
detection_array = np.array(plot_detections)
154-
self.ax.scatter(detection_array[:, 0], detection_array[:, 1], **measurement_kwargs)
192+
# *detection_array.T unpacks detection_array by coloumns
193+
# (same as passing in detection_array[:,0], detection_array[:,1], etc...)
194+
self.ax.scatter(*detection_array.T, **measurement_kwargs)
155195
measurements_handle = Line2D([], [], linestyle='', **measurement_kwargs)
156196

157197
# Generate legend items for measurements
158-
self.handles_list.append(measurements_handle)
159-
self.labels_list.append(measurements_label)
198+
self.legend_dict[measurements_label] = measurements_handle
160199

161200
if plot_clutter:
162201
clutter_array = np.array(plot_clutter)
163-
self.ax.scatter(clutter_array[:, 0], clutter_array[:, 1], color='y', marker='2')
202+
self.ax.scatter(*clutter_array.T, color='y', marker='2')
164203
clutter_handle = Line2D([], [], linestyle='', marker='2', color='y')
165204
clutter_label = "Clutter"
166205

167206
# Generate legend items for clutter
168-
self.handles_list.append(clutter_handle)
169-
self.labels_list.append(clutter_label)
207+
self.legend_dict[clutter_label] = clutter_handle
170208

171209
# Generate legend
172-
self.ax.legend(handles=self.handles_list, labels=self.labels_list)
210+
self.ax.legend(handles=self.legend_dict.values(), labels=self.legend_dict.keys())
173211

174212
def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Track",
175-
**kwargs):
213+
err_freq=1, **kwargs):
176214
"""Plots track(s)
177215
178-
Plots each track generated, generating a legend automatically. If ``uncertainty=True``,
179-
uncertainty ellipses are plotted. If ``particle=True``, particles are plotted.
180-
Tracks are plotted as solid lines with point markers and default colors.
181-
Uncertainty ellipses are plotted with a default color which is the same for all tracks.
216+
Plots each track generated, generating a legend automatically. If ``uncertainty=True``
217+
and is being plotted in 2D, error elipses are plotted. If being plotted in
218+
3D, uncertainty bars are plotted every :attr:`err_freq` measurement, default
219+
plots unceratinty bars at every track step. Tracks are plotted as solid
220+
lines with point markers and default colors. Uncertainty bars are plotted
221+
with a default color which is the same for all tracks.
182222
183-
Users can change linestyle, color and marker using keyword arguments. Uncertainty ellipses
223+
Users can change linestyle, color and marker using keyword arguments. Uncertainty metrics
184224
will also be plotted with the user defined colour and any changes will apply to all tracks.
185225
186226
Parameters
@@ -189,13 +229,17 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
189229
Set of tracks which will be plotted. If not a set, and instead a single
190230
:class:`~.Track` type, the argument is modified to be a set to allow for iteration.
191231
mapping: list
192-
List of 2 items specifying the mapping of the x and y components of the state space.
232+
List of items specifying the mapping of the position
233+
components of the state space.
193234
uncertainty : bool
194-
If True, function plots uncertainty ellipses.
235+
If True, function plots uncertainty ellipses or bars.
195236
particle : bool
196237
If True, function plots particles.
197238
track_label: str
198239
Label to apply to all tracks for legend.
240+
err_freq: int
241+
Frequency of error bar plotting on tracks. Default value is 1, meaning
242+
error bars are plotted at every track step.
199243
\\*\\*kwargs: dict
200244
Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``,
201245
``marker='.'`` and ``color=None``.
@@ -209,9 +253,15 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
209253
# Plot tracks
210254
track_colors = {}
211255
for track in tracks:
212-
line = self.ax.plot([state.state_vector[mapping[0]] for state in track],
213-
[state.state_vector[mapping[1]] for state in track],
214-
**tracks_kwargs)
256+
if self.dimension is Dimension.TWO:
257+
line = self.ax.plot([state.state_vector[mapping[0]] for state in track],
258+
[state.state_vector[mapping[1]] for state in track],
259+
**tracks_kwargs)
260+
else:
261+
line = self.ax.plot([state.state_vector[mapping[0]] for state in track],
262+
[state.state_vector[mapping[1]] for state in track],
263+
[state.state_vector[mapping[2]] for state in track],
264+
**tracks_kwargs)
215265
track_colors[track] = plt.getp(line[0], 'color')
216266

217267
# Assuming a single track or all plotted as the same colour then the following will work.
@@ -221,55 +271,81 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
221271
# Generate legend items for track
222272
track_handle = Line2D([], [], linestyle=tracks_kwargs['linestyle'],
223273
marker=tracks_kwargs['marker'], color=tracks_kwargs['color'])
224-
self.handles_list.append(track_handle)
225-
self.labels_list.append(track_label)
226-
274+
self.legend_dict[track_label] = track_handle
227275
if uncertainty:
228-
# Plot uncertainty ellipses
229-
for track in tracks:
230-
HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix
231-
for state in track:
232-
w, v = np.linalg.eig(HH @ state.covar @ HH.T)
233-
max_ind = np.argmax(w)
234-
min_ind = np.argmin(w)
235-
orient = np.arctan2(v[1, max_ind], v[0, max_ind])
236-
ellipse = Ellipse(xy=state.state_vector[mapping[:2], 0],
237-
width=2 * np.sqrt(w[max_ind]),
238-
height=2 * np.sqrt(w[min_ind]),
239-
angle=np.rad2deg(orient), alpha=0.2,
240-
color=track_colors[track])
241-
self.ax.add_artist(ellipse)
242-
243-
# Generate legend items for uncertainty ellipses
244-
ellipse_handle = Ellipse((0.5, 0.5), 0.5, 0.5, alpha=0.2, color=tracks_kwargs['color'])
245-
ellipse_label = "Uncertainty"
246-
247-
self.handles_list.append(ellipse_handle)
248-
self.labels_list.append(ellipse_label)
249-
250-
# Generate legend
251-
self.ax.legend(handles=self.handles_list, labels=self.labels_list,
252-
handler_map={Ellipse: _HandlerEllipse()})
276+
if self.dimension is Dimension.TWO:
277+
# Plot uncertainty ellipses
278+
for track in tracks:
279+
HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix
280+
for state in track:
281+
w, v = np.linalg.eig(HH @ state.covar @ HH.T)
282+
max_ind = np.argmax(w)
283+
min_ind = np.argmin(w)
284+
orient = np.arctan2(v[1, max_ind], v[0, max_ind])
285+
ellipse = Ellipse(xy=state.state_vector[mapping[:2], 0],
286+
width=2 * np.sqrt(w[max_ind]),
287+
height=2 * np.sqrt(w[min_ind]),
288+
angle=np.rad2deg(orient), alpha=0.2,
289+
color=track_colors[track])
290+
self.ax.add_artist(ellipse)
291+
292+
# Generate legend items for uncertainty ellipses
293+
ellipse_handle = Ellipse((0.5, 0.5), 0.5, 0.5, alpha=0.2,
294+
color=tracks_kwargs['color'])
295+
ellipse_label = "Uncertainty"
296+
self.legend_dict[ellipse_label] = ellipse_handle
297+
# Generate legend
298+
self.ax.legend(handles=self.legend_dict.values(),
299+
labels=self.legend_dict.keys(),
300+
handler_map={Ellipse: _HandlerEllipse()})
301+
else:
302+
# Plot 3D error bars on tracks
303+
for track in tracks:
304+
HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix
305+
check = err_freq
306+
for state in track:
307+
if not check % err_freq:
308+
w, v = np.linalg.eig(HH @ state.covar @ HH.T)
309+
310+
xl = state.state_vector[mapping[0]]
311+
yl = state.state_vector[mapping[1]]
312+
zl = state.state_vector[mapping[2]]
313+
314+
x_err = w[0]
315+
y_err = w[1]
316+
z_err = w[2]
317+
318+
self.ax.plot3D([xl+x_err, xl-x_err], [yl, yl], [zl, zl],
319+
marker="_", color=tracks_kwargs['color'])
320+
self.ax.plot3D([xl, xl], [yl+y_err, yl-y_err], [zl, zl],
321+
marker="_", color=tracks_kwargs['color'])
322+
self.ax.plot3D([xl, xl], [yl, yl], [zl+z_err, zl-z_err],
323+
marker="_", color=tracks_kwargs['color'])
324+
check += 1
253325

254326
elif particle:
255-
# Plot particles
256-
for track in tracks:
257-
for state in track:
258-
data = state.particles.state_vector[mapping[:2], :]
259-
self.ax.plot(data[0], data[1], linestyle='', marker=".",
260-
markersize=1, alpha=0.5)
261-
262-
# Generate legend items for particles
263-
particle_handle = Line2D([], [], linestyle='', color="black", marker='.', markersize=1)
264-
particle_label = "Particles"
265-
self.handles_list.append(particle_handle)
266-
self.labels_list.append(particle_label)
267-
268-
# Generate legend
269-
self.ax.legend(handles=self.handles_list, labels=self.labels_list)
327+
if self.dimension is Dimension.TWO:
328+
# Plot particles
329+
for track in tracks:
330+
for state in track:
331+
data = state.particles.state_vector[mapping[:2], :]
332+
self.ax.plot(data[0], data[1], linestyle='', marker=".",
333+
markersize=1, alpha=0.5)
334+
335+
# Generate legend items for particles
336+
particle_handle = Line2D([], [], linestyle='', color="black", marker='.',
337+
markersize=1)
338+
particle_label = "Particles"
339+
self.legend_dict[particle_label] = particle_handle
340+
# Generate legend
341+
self.ax.legend(handles=self.legend_dict.values(),
342+
labels=self.legend_dict.keys()) # particle error legend
343+
else:
344+
raise NotImplementedError("""Particle plotting is not currently supported for
345+
3D visualization""")
270346

271347
else:
272-
self.ax.legend(handles=self.handles_list, labels=self.labels_list)
348+
self.ax.legend(handles=self.legend_dict.values(), labels=self.legend_dict.keys())
273349

274350
# Ellipse legend patch (used in Tutorial 3)
275351
@staticmethod

0 commit comments

Comments
 (0)