-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwindow.py
More file actions
216 lines (144 loc) · 4.92 KB
/
window.py
File metadata and controls
216 lines (144 loc) · 4.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import warnings

import numpy as np

import stats_module
import plot_module as pm
class window:
    """A subset of a GPS track selected by an index array.

    Slices the per-point arrays of a data object at ``idx`` and computes
    mean/std summaries over them.
    """

    def __init__(self, idx, data):
        """Slices the track arrays of ``data`` at ``idx``.

        Args:
            idx (numpy.ndarray or list): Indices of the points in this window.
            data: Object exposing indexable arrays ``x``, ``y``, ``d``, ``dt``,
                ``v`` and a sequence ``angles``.
        """
        self.idx = idx
        self.x = data.x[idx]
        self.y = data.y[idx]
        self.d = data.d[idx]
        self.dt = data.dt[idx]
        self.v = data.v[idx]
        # The last index is dropped for angles; presumably data.angles holds
        # one entry fewer than the per-point arrays -- TODO confirm.
        self.angles = np.array(data.angles)[idx[:-1]]

    def computeAllStats(self):
        """Computes all key stats.

        Stores mean and std of distances (``meand``/``stdd``), time deltas
        (``meandt``/``stddt``), velocities (``meanv``/``stdv``) and movement
        angles (``meanangles``/``stdangles``) as attributes.

        Emits a ``RuntimeWarning`` (instead of halting) if any value returned
        by :meth:`getKeyStats` is NaN.
        """
        self.meand, self.stdd = self.computeStats(self.d)
        self.meandt, self.stddt = self.computeStats(self.dt)
        self.meanv, self.stdv = self.computeStats(self.v)
        self.meanangles, self.stdangles = self.computeStats(self.angles)
        # Report bad windows without blocking on stdin, so batch runs over
        # many windows keep going.
        if np.isnan(self.getKeyStats()).any():
            warnings.warn("NaN in key stats of window with idx=%r" % (self.idx,),
                          RuntimeWarning)

    def computeStats(self, x):
        """Computes basic stats, that is mean and std, of a variable.

        Args:
            x (numpy.ndarray): Array of variable.

        Returns:
            tuple: Tuple containing:
                * float: Mean of variable.
                * float: Std of variable.
        """
        return np.mean(x), np.std(x)

    def getKeyStats(self):
        """Returns key stats of window.

        Requires :meth:`computeAllStats` to have been called first.

        Returns:
            tuple: (mean distance, mean time delta, mean velocity, std of angles).
        """
        return self.meand, self.meandt, self.meanv, self.stdangles

    def plotCartTraj(self, ax=None, vel=None, showVel=False, centered=False, color='b',
                     center=None, showCenter=False):
        """Plots trajectory of gps data in window.

        Keyword Args:
            ax (matplotlib.axes): Axes to plot in.
            vel (list): List of velocities or other values. Defaults to ``[]``.
            showVel (bool): Show velocities.
            centered (bool): Show centered coordinates.
            color (str): Color of plot.
            center (list): Center of plot. Defaults to ``[0, 0, 0]``.
            showCenter (bool): Display center.

        Returns:
            matplotlib.axes: Modified axes.
        """
        # Fresh lists per call: mutable defaults would be shared between calls.
        if vel is None:
            vel = []
        if center is None:
            center = [0, 0, 0]
        return pm.plotCartTraj(self.x, self.y, ax=ax, vel=vel, showVel=showVel,
                               centered=centered, showCenter=showCenter,
                               center=center, color=color)
class windowSet:
    """Collection of windows whose key stats can be clustered and plotted.

    On construction every window's stats are computed and stacked into
    ``self.stats`` (one row per window). For ``window`` objects the columns
    follow ``window.getKeyStats``: mean distance, mean deltaT, mean velocity,
    std of angles.
    """

    # Maps the single-letter codes accepted by ``plotClusters(var=...)`` to
    # columns of ``self.stats``.
    _STAT_COLUMNS = {'d': 0, 't': 1, 'v': 2, 'a': 3}

    def __init__(self, windows, idxs, data):
        """Stores the windows and immediately computes/collects their stats.

        Args:
            windows (list): Window objects providing ``computeAllStats`` and
                ``getKeyStats``.
            idxs: Index arrays of the windows (stored, not used here).
            data: Underlying track data (stored, not used here).
        """
        self.windows = windows
        self.idxs = idxs
        self.data = data
        self.computeAllStats()
        self.collectKeyStats()

    def computeAllStats(self):
        """Computes all key stats for each window."""
        for w in self.windows:
            w.computeAllStats()

    def collectKeyStats(self):
        """Collects the key stats of every window into the numpy array ``self.stats``."""
        self.stats = np.array([w.getKeyStats() for w in self.windows])

    def performKMeans(self, nbouts=3):
        """Performs kmeans clustering of the key stats into ``nbouts`` bouts.

        Stores the fitted model in ``self.kMeans``, its ``inertia_`` in
        ``self.kMeansScore`` and its ``labels_`` in ``self.kMeansLabels``
        (presumably a scikit-learn ``KMeans`` returned by stats_module).

        Keyword Args:
            nbouts (int): Number of clusters.
        """
        self.kMeans = stats_module.performKMeans(self.stats, nbouts)
        self.kMeansScore = self.kMeans.inertia_
        self.kMeansLabels = self.kMeans.labels_

    @classmethod
    def _selectAxes(cls, var):
        """Translates ``var`` into stats-column indices and axis labels, keeping ``var``'s order.

        Characters that are not one of 'd', 't', 'v', 'a' are ignored.

        Args:
            var (str): Variable codes, e.g. 'dta'.

        Returns:
            tuple: (list of column indices, list of the matching codes).

        Raises:
            ValueError: If fewer than three known codes are given.
        """
        codes = [c for c in var if c in cls._STAT_COLUMNS]
        if len(codes) < 3:
            raise ValueError("var must name at least three of 'd', 't', 'v', 'a'; got %r"
                             % (var,))
        return [cls._STAT_COLUMNS[c] for c in codes], codes

    def plotClusters(self, ax=None, var='dta', showLabels=False, cpick='jet', alg='kMeans',
                     nbouts=3):
        """Shows a 3D scatter plot of the windows' key stats.

        If algorithm is 'kMeans', performs kMeans first and colors windows by bout;
        with 'none' all windows share one label.
        ``var`` defines which variables are shown on the x/y/z-axis, in the order
        they are specified (the first three valid codes are used):

            * 'd': distances
            * 't': deltaT
            * 'a': angles
            * 'v': velocities

        Keyword Args:
            ax (matplotlib.axes): Axes to plot in (a 3D axes is created if None).
            var (str): Define which variables are plotted.
            showLabels (bool): Currently unused; kept for interface compatibility.
            cpick (str): Colormap used for coloring.
            alg (str): Algorithm used for clustering, 'kMeans' or 'none'.
            nbouts (int): Number of bouts to cluster data in.

        Returns:
            matplotlib.axes: Modified axes.

        Raises:
            ValueError: If ``var`` names fewer than three variables or ``alg``
                is unknown.
        """
        cols, codes = self._selectAxes(var)
        if alg == 'none':
            labels = np.zeros(np.shape(self.stats)[0])
        elif alg == 'kMeans':
            self.performKMeans(nbouts=nbouts)
            labels = self.kMeansLabels
        else:
            raise ValueError("Unknown clustering algorithm %r; expected 'kMeans' or 'none'."
                             % (alg,))
        if ax is None:
            _, axes = pm.makeSubplot([1, 1], proj=['3d'])
            ax = axes[0]
        pm.labeledScatter(self.stats[:, cols[0]], self.stats[:, cols[1]],
                          self.stats[:, cols[2]], labels, ax=ax, cmap=cpick)
        # Axis labels use the same (ordered) codes as the plotted columns.
        ax.set_xlabel(codes[0])
        ax.set_ylabel(codes[1])
        ax.set_zlabel(codes[2])
        ax.get_figure().canvas.draw()
        return ax

    def showClustersOnTracks(self, ax=None, centered=True, nbouts=3, showCenter=False,
                             center=None):
        """Plots the trajectories of all windows colored by their kmeans cluster.

        An existing clustering (e.g. from :meth:`plotClusters`) is reused;
        kmeans with ``nbouts`` clusters is run only if none exists yet.

        Keyword Args:
            ax (matplotlib.axes): Axes to plot in (created if None).
            centered (bool): Show centered coordinates.
            nbouts (int): Number of bouts to cluster data in, if clustering is needed.
            showCenter (bool): Display center.
            center (list): Center of plot. Defaults to ``[0, 0, 0]``.

        Returns:
            matplotlib.axes: Modified axes.
        """
        if center is None:
            center = [0, 0, 0]
        if ax is None:
            _, axes = pm.makeSubplot([1, 1])
            ax = axes[0]
        if not hasattr(self, 'kMeansLabels'):
            self.performKMeans(nbouts=nbouts)
        labels = self.kMeansLabels
        colors = pm.getColors(len(np.unique(labels)))
        for w, label in zip(self.windows, labels):
            # Each window gets its own copy of center, as the original passed a
            # fresh list per call.
            w.plotCartTraj(ax=ax, vel=[], showVel=False, centered=centered,
                           color=colors[label], center=list(center),
                           showCenter=showCenter)
        return ax