|
| 1 | +from amadeusgpt.analysis_objects.object import Object |
| 2 | +from numpy import ndarray |
| 3 | +from scipy.spatial import ConvexHull |
| 4 | +from typing import List, Dict, Any |
| 5 | +import matplotlib.path as mpath |
| 6 | +import numpy as np |
| 7 | + |
| 8 | + |
| 9 | +class Animal(Object): |
| 10 | + def get_keypoint_names(self): |
| 11 | + """ |
| 12 | + keypoint names should be the basic attributes |
| 13 | + """ |
| 14 | + pass |
| 15 | + |
| 16 | + def summary(self): |
| 17 | + print(self.__class__.__name__) |
| 18 | + for attr_name in self.__dict__: |
| 19 | + print(f"{attr_name} has {self.__dict__[attr_name]}") |
| 20 | + |
| 21 | + |
| 22 | +class AnimalSeq(Animal): |
| 23 | + """ |
| 24 | + Because we support passing bodyparts indices for initializing an AnimalSeq object, |
| 25 | + body center, left, right, above, top are relative to the subset of keypoints. |
| 26 | + Attributes |
| 27 | + ---------- |
| 28 | + self._coords: arr potentially subset of keypoints |
| 29 | + self.wholebody: full set of keypoints. This is important for overlap relationship |
| 30 | + """ |
| 31 | + |
| 32 | + def __init__(self, animal_name: str, keypoints: ndarray, keypoint_names: List[str]): |
| 33 | + self.name = animal_name |
| 34 | + self.whole_body: ndarray = keypoints |
| 35 | + self.keypoint_names = keypoint_names |
| 36 | + self._paths = [] |
| 37 | + self.state = {} |
| 38 | + self.kinematics_types = ["speed", "acceleration"] |
| 39 | + self.bodypart_relation = ["bodypart_pairwise_distance"] |
| 40 | + self.support_body_orientation = False |
| 41 | + self.support_head_orientation = False |
| 42 | + # self.keypoints are updated by indices of keypoint names given |
| 43 | + keypoint_indices = [keypoint_names.index(name) for name in keypoint_names] |
| 44 | + self.keypoints = self.whole_body[:, keypoint_indices] |
| 45 | + self.center = np.nanmedian(self.whole_body, axis=1) |
| 46 | + |
| 47 | + def update_roi_keypoint_by_names(self, keypoint_names: List[str]): |
| 48 | + # update self.keypoints based on keypoint names given |
| 49 | + if keypoint_names is None: |
| 50 | + return |
| 51 | + keypoint_indices = [self.keypoint_names.index(name) for name in keypoint_names] |
| 52 | + self.keypoints = self.whole_body[:, keypoint_indices] |
| 53 | + |
| 54 | + def restore_roi_keypoint(self): |
| 55 | + self.keypoints = self.whole_body |
| 56 | + |
| 57 | + def set_body_orientation_keypoints( |
| 58 | + self, body_orientation_keypoints: Dict[str, Any] |
| 59 | + ): |
| 60 | + self.neck_name = body_orientation_keypoints["neck"] |
| 61 | + self.tail_base_name = body_orientation_keypoints["tail_base"] |
| 62 | + self.animal_center_name = body_orientation_keypoints["animal_center"] |
| 63 | + self.support_body_orientation = True |
| 64 | + |
| 65 | + def set_head_orientation_keypoints( |
| 66 | + self, head_orientation_keypoints: Dict[str, Any] |
| 67 | + ): |
| 68 | + self.nose_name = head_orientation_keypoints["nose"] |
| 69 | + self.neck_name = head_orientation_keypoints["neck"] |
| 70 | + self.support_head_orientation = True |
| 71 | + |
| 72 | + # all the properties cannot be cached because update could happen |
| 73 | + def get_paths(self): |
| 74 | + paths = [] |
| 75 | + for ind in range(self.whole_body.shape[0]): |
| 76 | + paths.append(self.get_path(ind)) |
| 77 | + return paths |
| 78 | + |
| 79 | + def get_path(self, ind): |
| 80 | + xy = self.whole_body[ind] |
| 81 | + xy = np.nan_to_num(xy) |
| 82 | + if np.all(xy == 0): |
| 83 | + return None |
| 84 | + |
| 85 | + hull = ConvexHull(xy) |
| 86 | + vertices = hull.vertices |
| 87 | + path_data = [] |
| 88 | + path_data.append((mpath.Path.MOVETO, xy[vertices[0]])) |
| 89 | + for point in xy[vertices[1:]]: |
| 90 | + path_data.append((mpath.Path.LINETO, point)) |
| 91 | + path_data.append((mpath.Path.CLOSEPOLY, xy[vertices[0]])) |
| 92 | + codes, verts = zip(*path_data) |
| 93 | + return mpath.Path(verts, codes) |
| 94 | + |
| 95 | + def get_keypoints(self) -> ndarray: |
| 96 | + # the shape should be (n_frames, n_keypoints, 2) |
| 97 | + # extending to 3D? |
| 98 | + assert len(self.keypoints.shape) == 3, f"keypoints shape is {self.keypoints.shape}" |
| 99 | + return self.keypoints |
| 100 | + |
| 101 | + def get_center(self): |
| 102 | + """ |
| 103 | + median is more robust than mean |
| 104 | + """ |
| 105 | + return np.nanmedian(self.keypoints, axis=1).squeeze() |
| 106 | + |
| 107 | + def get_xmin(self): |
| 108 | + return np.nanmin(self.keypoints[..., 0], axis=1) |
| 109 | + |
| 110 | + def get_xmax(self): |
| 111 | + return np.nanmax(self.keypoints[..., 0], axis=1) |
| 112 | + |
| 113 | + def get_ymin(self): |
| 114 | + return np.nanmin(self.keypoints[..., 1], axis=1) |
| 115 | + |
| 116 | + def get_ymax(self): |
| 117 | + return np.nanmax(self.keypoints[..., 1], axis=1) |
| 118 | + |
| 119 | + def get_keypoint_names(self): |
| 120 | + return self.keypoint_names |
| 121 | + |
| 122 | + def query_states(self, query: str) -> ndarray: |
| 123 | + assert query in [ |
| 124 | + "speed", |
| 125 | + "acceleration_mag", |
| 126 | + "bodypart_pairwise_distance", |
| 127 | + ], f"{query} is not supported" |
| 128 | + |
| 129 | + if query == "speed": |
| 130 | + self.state[query] = self.get_speed() |
| 131 | + elif query == "acceleration_mag": |
| 132 | + self.state[query] = self.get_acceleration_mag() |
| 133 | + elif query == "bodypart_pairwise_distance": |
| 134 | + self.state[query] = self.get_bodypart_wise_relation() |
| 135 | + |
| 136 | + return self.state[query] |
| 137 | + |
| 138 | + def get_velocity(self) -> ndarray: |
| 139 | + keypoints = self.get_keypoints() |
| 140 | + velocity = np.diff(keypoints, axis=0) / 30 |
| 141 | + velocity = np.concatenate([np.zeros((1,) + velocity.shape[1:]), velocity]) |
| 142 | + assert len(velocity.shape) == 3 |
| 143 | + return velocity |
| 144 | + |
| 145 | + def get_speed(self) -> ndarray: |
| 146 | + keypoints = self.get_keypoints() |
| 147 | + velocity = ( |
| 148 | + np.diff(keypoints, axis=0) / 30 |
| 149 | + ) # divided by frame rate to get speed in pixels/second |
| 150 | + # Pad velocities to match the original shape |
| 151 | + velocity = np.concatenate([np.zeros((1,) + velocity.shape[1:]), velocity]) |
| 152 | + # Compute the speed from the velocity |
| 153 | + speed = np.linalg.norm(velocity, axis=-1) |
| 154 | + speed = np.expand_dims(speed, axis=-1) |
| 155 | + assert len(speed.shape) == 3 |
| 156 | + return speed |
| 157 | + |
| 158 | + def get_acceleration(self) -> ndarray: |
| 159 | + velocities = self.get_velocity() |
| 160 | + accelerations = ( |
| 161 | + np.diff(velocities, axis=0) / 30 |
| 162 | + ) # divided by frame rate to get acceleration in pixels/second^2 |
| 163 | + # Pad accelerations to match the original shape |
| 164 | + accelerations = np.concatenate( |
| 165 | + [np.zeros((1,) + accelerations.shape[1:]), accelerations], axis=0 |
| 166 | + ) |
| 167 | + assert len(accelerations.shape) == 3 |
| 168 | + return accelerations |
| 169 | + |
| 170 | + def get_acceleration_mag(self) -> ndarray: |
| 171 | + """ |
| 172 | + Returns the magnitude of the acceleration vector |
| 173 | + """ |
| 174 | + accelerations = self.get_acceleration() |
| 175 | + acceleration_mag = np.linalg.norm(accelerations, axis=-1) |
| 176 | + acceleration_mag = np.expand_dims(acceleration_mag, axis=-1) |
| 177 | + assert len(acceleration_mag.shape) == 3 |
| 178 | + return acceleration_mag |
| 179 | + |
| 180 | + def get_bodypart_wise_relation(self): |
| 181 | + keypoints = self.get_keypoints() |
| 182 | + diff = keypoints[..., np.newaxis, :, :] - keypoints[..., :, np.newaxis, :] |
| 183 | + sq_dist = np.sum(diff**2, axis=-1) |
| 184 | + distances = np.sqrt(sq_dist) |
| 185 | + return distances |
| 186 | + |
| 187 | + def get_body_cs( |
| 188 | + self, |
| 189 | + ): |
| 190 | + # this only works for topview |
| 191 | + neck_index = self.keypoint_names.index(self.neck_name) |
| 192 | + tailbase_index = self.keypoint_names.index(self.tail_base_name) |
| 193 | + neck = self.whole_body[:, neck_index] |
| 194 | + tailbase = self.whole_body[:, tailbase_index] |
| 195 | + body_axis = neck - tailbase |
| 196 | + body_axis_norm = body_axis / np.linalg.norm(body_axis, axis=1, keepdims=True) |
| 197 | + # Get a normal vector pointing left |
| 198 | + mediolat_axis_norm = body_axis_norm[:, [1, 0]].copy() |
| 199 | + mediolat_axis_norm[:, 0] *= -1 |
| 200 | + nrows = len(body_axis_norm) |
| 201 | + animal_cs = np.zeros((nrows, 3, 3)) |
| 202 | + rot = np.stack((body_axis_norm, mediolat_axis_norm), axis=2) |
| 203 | + animal_cs[:, :2, :2] = rot |
| 204 | + animal_center_index = self.keypoint_names.index(self.animal_center_name) |
| 205 | + animal_cs[:, :, 2] = np.c_[ |
| 206 | + self.whole_body[:, animal_center_index], np.ones(nrows) |
| 207 | + ] # center back |
| 208 | + |
| 209 | + return animal_cs |
| 210 | + |
| 211 | + def calc_head_cs(self): |
| 212 | + nose_index = self.keypoint_names.index(self.nose_name) |
| 213 | + nose = self.whole_body[:, nose_index] |
| 214 | + neck_index = self.keypoint_names.index(self.neck_name) |
| 215 | + neck = self.whole_body[:, neck_index] |
| 216 | + head_axis = nose - neck |
| 217 | + head_axis_norm = head_axis / np.linalg.norm(head_axis, axis=1, keepdims=True) |
| 218 | + # Get a normal vector pointing left |
| 219 | + mediolat_axis_norm = head_axis_norm[:, [1, 0]].copy() |
| 220 | + mediolat_axis_norm[:, 0] *= -1 |
| 221 | + nrows = len(head_axis_norm) |
| 222 | + mouse_cs = np.zeros((nrows, 3, 3)) |
| 223 | + rot = np.stack((head_axis_norm, mediolat_axis_norm), axis=2) |
| 224 | + mouse_cs[:, :2, :2] = rot |
| 225 | + mouse_cs[:, :, 2] = np.c_[neck, np.ones(nrows)] |
| 226 | + return mouse_cs |
| 227 | + |
| 228 | + |
| 229 | + |
| 230 | +if __name__ == "__main__": |
| 231 | + # unit testing the shape of kinematics data |
| 232 | + # acceleration, acceleration_mag, velocity, speed, and keypoints |
| 233 | + |
| 234 | + from amadeusgpt.config import Config |
| 235 | + from amadeusgpt.main import AMADEUS |
| 236 | + config = Config("/Users/shaokaiye/AmadeusGPT-dev/amadeusgpt/configs/MausHaus_template.yaml") |
| 237 | + amadeus = AMADEUS(config) |
| 238 | + analysis = amadeus.get_analysis() |
| 239 | + # get an instance of animal |
| 240 | + animal = analysis.animal_manager.get_animals()[0] |
| 241 | + |
| 242 | + print ("velocity shape", animal.get_velocity().shape) |
| 243 | + print ("speed shape", animal.get_speed().shape) |
| 244 | + print ("acceleration shape", animal.get_acceleration().shape) |
| 245 | + print ("acceleration_mag shape", animal.get_acceleration_mag().shape) |
| 246 | + |
| 247 | + print(animal.query_states("acceleration_mag").shape) |
| 248 | + |
0 commit comments