Skip to content

Commit cccbb36

Browse files
committed
conflict solved
2 parents c3a6747 + 76e139a commit cccbb36

File tree

13 files changed

+394
-288
lines changed

13 files changed

+394
-288
lines changed
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
from amadeusgpt.analysis_objects.object import Object
2+
from numpy import ndarray
3+
from scipy.spatial import ConvexHull
4+
from typing import List, Dict, Any
5+
import matplotlib.path as mpath
6+
import numpy as np
7+
8+
9+
class Animal(Object):
10+
def get_keypoint_names(self):
11+
"""
12+
keypoint names should be the basic attributes
13+
"""
14+
pass
15+
16+
def summary(self):
17+
print(self.__class__.__name__)
18+
for attr_name in self.__dict__:
19+
print(f"{attr_name} has {self.__dict__[attr_name]}")
20+
21+
22+
class AnimalSeq(Animal):
23+
"""
24+
Because we support passing bodyparts indices for initializing an AnimalSeq object,
25+
body center, left, right, above, top are relative to the subset of keypoints.
26+
Attributes
27+
----------
28+
self._coords: arr potentially subset of keypoints
29+
self.wholebody: full set of keypoints. This is important for overlap relationship
30+
"""
31+
32+
def __init__(self, animal_name: str, keypoints: ndarray, keypoint_names: List[str]):
33+
self.name = animal_name
34+
self.whole_body: ndarray = keypoints
35+
self.keypoint_names = keypoint_names
36+
self._paths = []
37+
self.state = {}
38+
self.kinematics_types = ["speed", "acceleration"]
39+
self.bodypart_relation = ["bodypart_pairwise_distance"]
40+
self.support_body_orientation = False
41+
self.support_head_orientation = False
42+
# self.keypoints are updated by indices of keypoint names given
43+
keypoint_indices = [keypoint_names.index(name) for name in keypoint_names]
44+
self.keypoints = self.whole_body[:, keypoint_indices]
45+
self.center = np.nanmedian(self.whole_body, axis=1)
46+
47+
def update_roi_keypoint_by_names(self, keypoint_names: List[str]):
48+
# update self.keypoints based on keypoint names given
49+
if keypoint_names is None:
50+
return
51+
keypoint_indices = [self.keypoint_names.index(name) for name in keypoint_names]
52+
self.keypoints = self.whole_body[:, keypoint_indices]
53+
54+
def restore_roi_keypoint(self):
55+
self.keypoints = self.whole_body
56+
57+
def set_body_orientation_keypoints(
58+
self, body_orientation_keypoints: Dict[str, Any]
59+
):
60+
self.neck_name = body_orientation_keypoints["neck"]
61+
self.tail_base_name = body_orientation_keypoints["tail_base"]
62+
self.animal_center_name = body_orientation_keypoints["animal_center"]
63+
self.support_body_orientation = True
64+
65+
def set_head_orientation_keypoints(
66+
self, head_orientation_keypoints: Dict[str, Any]
67+
):
68+
self.nose_name = head_orientation_keypoints["nose"]
69+
self.neck_name = head_orientation_keypoints["neck"]
70+
self.support_head_orientation = True
71+
72+
# all the properties cannot be cached because update could happen
73+
def get_paths(self):
74+
paths = []
75+
for ind in range(self.whole_body.shape[0]):
76+
paths.append(self.get_path(ind))
77+
return paths
78+
79+
def get_path(self, ind):
80+
xy = self.whole_body[ind]
81+
xy = np.nan_to_num(xy)
82+
if np.all(xy == 0):
83+
return None
84+
85+
hull = ConvexHull(xy)
86+
vertices = hull.vertices
87+
path_data = []
88+
path_data.append((mpath.Path.MOVETO, xy[vertices[0]]))
89+
for point in xy[vertices[1:]]:
90+
path_data.append((mpath.Path.LINETO, point))
91+
path_data.append((mpath.Path.CLOSEPOLY, xy[vertices[0]]))
92+
codes, verts = zip(*path_data)
93+
return mpath.Path(verts, codes)
94+
95+
def get_keypoints(self) -> ndarray:
96+
# the shape should be (n_frames, n_keypoints, 2)
97+
# extending to 3D?
98+
assert len(self.keypoints.shape) == 3, f"keypoints shape is {self.keypoints.shape}"
99+
return self.keypoints
100+
101+
def get_center(self):
102+
"""
103+
median is more robust than mean
104+
"""
105+
return np.nanmedian(self.keypoints, axis=1).squeeze()
106+
107+
def get_xmin(self):
108+
return np.nanmin(self.keypoints[..., 0], axis=1)
109+
110+
def get_xmax(self):
111+
return np.nanmax(self.keypoints[..., 0], axis=1)
112+
113+
def get_ymin(self):
114+
return np.nanmin(self.keypoints[..., 1], axis=1)
115+
116+
def get_ymax(self):
117+
return np.nanmax(self.keypoints[..., 1], axis=1)
118+
119+
def get_keypoint_names(self):
120+
return self.keypoint_names
121+
122+
def query_states(self, query: str) -> ndarray:
123+
assert query in [
124+
"speed",
125+
"acceleration_mag",
126+
"bodypart_pairwise_distance",
127+
], f"{query} is not supported"
128+
129+
if query == "speed":
130+
self.state[query] = self.get_speed()
131+
elif query == "acceleration_mag":
132+
self.state[query] = self.get_acceleration_mag()
133+
elif query == "bodypart_pairwise_distance":
134+
self.state[query] = self.get_bodypart_wise_relation()
135+
136+
return self.state[query]
137+
138+
def get_velocity(self) -> ndarray:
139+
keypoints = self.get_keypoints()
140+
velocity = np.diff(keypoints, axis=0) / 30
141+
velocity = np.concatenate([np.zeros((1,) + velocity.shape[1:]), velocity])
142+
assert len(velocity.shape) == 3
143+
return velocity
144+
145+
def get_speed(self) -> ndarray:
146+
keypoints = self.get_keypoints()
147+
velocity = (
148+
np.diff(keypoints, axis=0) / 30
149+
) # divided by frame rate to get speed in pixels/second
150+
# Pad velocities to match the original shape
151+
velocity = np.concatenate([np.zeros((1,) + velocity.shape[1:]), velocity])
152+
# Compute the speed from the velocity
153+
speed = np.linalg.norm(velocity, axis=-1)
154+
speed = np.expand_dims(speed, axis=-1)
155+
assert len(speed.shape) == 3
156+
return speed
157+
158+
def get_acceleration(self) -> ndarray:
159+
velocities = self.get_velocity()
160+
accelerations = (
161+
np.diff(velocities, axis=0) / 30
162+
) # divided by frame rate to get acceleration in pixels/second^2
163+
# Pad accelerations to match the original shape
164+
accelerations = np.concatenate(
165+
[np.zeros((1,) + accelerations.shape[1:]), accelerations], axis=0
166+
)
167+
assert len(accelerations.shape) == 3
168+
return accelerations
169+
170+
def get_acceleration_mag(self) -> ndarray:
171+
"""
172+
Returns the magnitude of the acceleration vector
173+
"""
174+
accelerations = self.get_acceleration()
175+
acceleration_mag = np.linalg.norm(accelerations, axis=-1)
176+
acceleration_mag = np.expand_dims(acceleration_mag, axis=-1)
177+
assert len(acceleration_mag.shape) == 3
178+
return acceleration_mag
179+
180+
def get_bodypart_wise_relation(self):
181+
keypoints = self.get_keypoints()
182+
diff = keypoints[..., np.newaxis, :, :] - keypoints[..., :, np.newaxis, :]
183+
sq_dist = np.sum(diff**2, axis=-1)
184+
distances = np.sqrt(sq_dist)
185+
return distances
186+
187+
def get_body_cs(
188+
self,
189+
):
190+
# this only works for topview
191+
neck_index = self.keypoint_names.index(self.neck_name)
192+
tailbase_index = self.keypoint_names.index(self.tail_base_name)
193+
neck = self.whole_body[:, neck_index]
194+
tailbase = self.whole_body[:, tailbase_index]
195+
body_axis = neck - tailbase
196+
body_axis_norm = body_axis / np.linalg.norm(body_axis, axis=1, keepdims=True)
197+
# Get a normal vector pointing left
198+
mediolat_axis_norm = body_axis_norm[:, [1, 0]].copy()
199+
mediolat_axis_norm[:, 0] *= -1
200+
nrows = len(body_axis_norm)
201+
animal_cs = np.zeros((nrows, 3, 3))
202+
rot = np.stack((body_axis_norm, mediolat_axis_norm), axis=2)
203+
animal_cs[:, :2, :2] = rot
204+
animal_center_index = self.keypoint_names.index(self.animal_center_name)
205+
animal_cs[:, :, 2] = np.c_[
206+
self.whole_body[:, animal_center_index], np.ones(nrows)
207+
] # center back
208+
209+
return animal_cs
210+
211+
def calc_head_cs(self):
212+
nose_index = self.keypoint_names.index(self.nose_name)
213+
nose = self.whole_body[:, nose_index]
214+
neck_index = self.keypoint_names.index(self.neck_name)
215+
neck = self.whole_body[:, neck_index]
216+
head_axis = nose - neck
217+
head_axis_norm = head_axis / np.linalg.norm(head_axis, axis=1, keepdims=True)
218+
# Get a normal vector pointing left
219+
mediolat_axis_norm = head_axis_norm[:, [1, 0]].copy()
220+
mediolat_axis_norm[:, 0] *= -1
221+
nrows = len(head_axis_norm)
222+
mouse_cs = np.zeros((nrows, 3, 3))
223+
rot = np.stack((head_axis_norm, mediolat_axis_norm), axis=2)
224+
mouse_cs[:, :2, :2] = rot
225+
mouse_cs[:, :, 2] = np.c_[neck, np.ones(nrows)]
226+
return mouse_cs
227+
228+
229+
230+
if __name__ == "__main__":
231+
# unit testing the shape of kinematics data
232+
# acceleration, acceleration_mag, velocity, speed, and keypoints
233+
234+
from amadeusgpt.config import Config
235+
from amadeusgpt.main import AMADEUS
236+
config = Config("/Users/shaokaiye/AmadeusGPT-dev/amadeusgpt/configs/MausHaus_template.yaml")
237+
amadeus = AMADEUS(config)
238+
analysis = amadeus.get_analysis()
239+
# get an instance of animal
240+
animal = analysis.animal_manager.get_animals()[0]
241+
242+
print ("velocity shape", animal.get_velocity().shape)
243+
print ("speed shape", animal.get_speed().shape)
244+
print ("acceleration shape", animal.get_acceleration().shape)
245+
print ("acceleration_mag shape", animal.get_acceleration_mag().shape)
246+
247+
print(animal.query_states("acceleration_mag").shape)
248+

amadeusgpt/analysis_objects/event.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ def init_from_list(cls, events: List[BaseEvent]) -> "EventGraph":
384384

385385
return graph
386386

387+
388+
387389
@classmethod
388390
def init_from_mask(
389391
cls,

0 commit comments

Comments
 (0)