Skip to content

Commit b4bc285

Browse files
committed
separated animal from object
1 parent dbdcf05 commit b4bc285

File tree

7 files changed

+298
-238
lines changed

7 files changed

+298
-238
lines changed
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
from amadeusgpt.analysis_objects.object import Object
2+
from numpy import ndarray
3+
from scipy.spatial import ConvexHull
4+
from typing import List, Dict, Any
5+
import matplotlib.path as mpath
6+
import numpy as np
7+
8+
9+
class Animal(Object):
10+
def get_keypoint_names(self):
11+
"""
12+
keypoint names should be the basic attributes
13+
"""
14+
pass
15+
16+
def summary(self):
17+
print(self.__class__.__name__)
18+
for attr_name in self.__dict__:
19+
print(f"{attr_name} has {self.__dict__[attr_name]}")
20+
21+
22+
class AnimalSeq(Animal):
23+
"""
24+
Because we support passing bodyparts indices for initializing an AnimalSeq object,
25+
body center, left, right, above, top are relative to the subset of keypoints.
26+
Attributes
27+
----------
28+
self._coords: arr potentially subset of keypoints
29+
self.wholebody: full set of keypoints. This is important for overlap relationship
30+
"""
31+
32+
def __init__(self, animal_name: str, keypoints: ndarray, keypoint_names: List[str]):
33+
self.name = animal_name
34+
self.whole_body: ndarray = keypoints
35+
self.keypoint_names = keypoint_names
36+
self._paths = []
37+
self.state = {}
38+
self.kinematics_types = ["speed", "acceleration"]
39+
self.bodypart_relation = ["bodypart_pairwise_distance"]
40+
self.support_body_orientation = False
41+
self.support_head_orientation = False
42+
# self.keypoints are updated by indices of keypoint names given
43+
keypoint_indices = [keypoint_names.index(name) for name in keypoint_names]
44+
self.keypoints = self.whole_body[:, keypoint_indices]
45+
self.center = np.nanmedian(self.whole_body, axis=1)
46+
47+
def update_roi_keypoint_by_names(self, keypoint_names: List[str]):
48+
# update self.keypoints based on keypoint names given
49+
if keypoint_names is None:
50+
return
51+
keypoint_indices = [self.keypoint_names.index(name) for name in keypoint_names]
52+
self.keypoints = self.whole_body[:, keypoint_indices]
53+
54+
def restore_roi_keypoint(self):
55+
self.keypoints = self.whole_body
56+
57+
def set_body_orientation_keypoints(
58+
self, body_orientation_keypoints: Dict[str, Any]
59+
):
60+
self.neck_name = body_orientation_keypoints["neck"]
61+
self.tail_base_name = body_orientation_keypoints["tail_base"]
62+
self.animal_center_name = body_orientation_keypoints["animal_center"]
63+
self.support_body_orientation = True
64+
65+
def set_head_orientation_keypoints(
66+
self, head_orientation_keypoints: Dict[str, Any]
67+
):
68+
self.nose_name = head_orientation_keypoints["nose"]
69+
self.neck_name = head_orientation_keypoints["neck"]
70+
self.support_head_orientation = True
71+
72+
# all the properties cannot be cached because update could happen
73+
def get_paths(self):
74+
paths = []
75+
for ind in range(self.whole_body.shape[0]):
76+
paths.append(self.get_path(ind))
77+
return paths
78+
79+
def get_path(self, ind):
80+
xy = self.whole_body[ind]
81+
xy = np.nan_to_num(xy)
82+
if np.all(xy == 0):
83+
return None
84+
85+
hull = ConvexHull(xy)
86+
vertices = hull.vertices
87+
path_data = []
88+
path_data.append((mpath.Path.MOVETO, xy[vertices[0]]))
89+
for point in xy[vertices[1:]]:
90+
path_data.append((mpath.Path.LINETO, point))
91+
path_data.append((mpath.Path.CLOSEPOLY, xy[vertices[0]]))
92+
codes, verts = zip(*path_data)
93+
return mpath.Path(verts, codes)
94+
95+
def get_keypoints(self, average_keypoints=False) -> ndarray:
96+
if average_keypoints:
97+
return np.nanmedian(self.keypoints, axis=1)
98+
return self.keypoints
99+
100+
def get_center(self):
101+
return np.nanmedian(self.keypoints, axis=1).squeeze()
102+
103+
def get_xmin(self):
104+
return np.nanmin(self.keypoints[..., 0], axis=1)
105+
106+
def get_xmax(self):
107+
return np.nanmax(self.keypoints[..., 0], axis=1)
108+
109+
def get_ymin(self):
110+
return np.nanmin(self.keypoints[..., 1], axis=1)
111+
112+
def get_ymax(self):
113+
return np.nanmax(self.keypoints[..., 1], axis=1)
114+
115+
def get_keypoint_names(self):
116+
return self.keypoint_names
117+
118+
def query_states(self, query: str) -> ndarray:
119+
assert query in [
120+
"speed",
121+
"acceleration",
122+
"bodypart_pairwise_distance",
123+
], f"{query} is not supported"
124+
125+
if query == "speed":
126+
self.state[query] = self.get_speed()
127+
elif query == "acceleration_mag":
128+
self.state[query] = self.get_acceleration_mag()
129+
elif query == "bodypart_pairwise_distance":
130+
self.state[query] = self.get_bodypart_wise_relation()
131+
132+
return self.state[query]
133+
134+
def get_velocity(self) -> ndarray:
135+
keypoints = self.get_keypoints()
136+
velocity = np.diff(keypoints, axis=0) / 30
137+
velocity = np.concatenate([np.zeros((1,) + velocity.shape[1:]), velocity])
138+
assert len(velocity.shape) == 3
139+
return velocity
140+
141+
def get_speed(self) -> ndarray:
142+
keypoints = self.get_keypoints()
143+
velocity = (
144+
np.diff(keypoints, axis=0) / 30
145+
) # divided by frame rate to get speed in pixels/second
146+
# Pad velocities to match the original shape
147+
velocity = np.concatenate([np.zeros((1,) + velocity.shape[1:]), velocity])
148+
# Compute the speed from the velocity
149+
speed = np.linalg.norm(velocity, axis=-1)
150+
speed = np.expand_dims(speed, axis=-1)
151+
assert len(speed.shape) == 3
152+
return speed
153+
154+
def get_acceleration(self) -> ndarray:
155+
velocities = self.get_velocity()
156+
accelerations = (
157+
np.diff(velocities, axis=0) / 30
158+
) # divided by frame rate to get acceleration in pixels/second^2
159+
# Pad accelerations to match the original shape
160+
accelerations = np.concatenate(
161+
[np.zeros((1,) + accelerations.shape[1:]), accelerations], axis=0
162+
)
163+
assert len(accelerations.shape) == 3
164+
return accelerations
165+
166+
def get_acceleration_mag(self) -> ndarray:
167+
"""
168+
Returns the magnitude of the acceleration vector
169+
"""
170+
accelerations = self.get_acceleration()
171+
acceleration_mag = np.linalg.norm(accelerations, axis=-1)
172+
assert len(acceleration_mag.shape) == 2
173+
return acceleration_mag
174+
175+
def get_bodypart_wise_relation(self):
176+
keypoints = self.get_keypoints()
177+
diff = keypoints[..., np.newaxis, :, :] - keypoints[..., :, np.newaxis, :]
178+
sq_dist = np.sum(diff**2, axis=-1)
179+
distances = np.sqrt(sq_dist)
180+
return distances
181+
182+
def get_body_cs(
183+
self,
184+
):
185+
# this only works for topview
186+
neck_index = self.keypoint_names.index(self.neck_name)
187+
tailbase_index = self.keypoint_names.index(self.tail_base_name)
188+
neck = self.whole_body[:, neck_index]
189+
tailbase = self.whole_body[:, tailbase_index]
190+
body_axis = neck - tailbase
191+
body_axis_norm = body_axis / np.linalg.norm(body_axis, axis=1, keepdims=True)
192+
# Get a normal vector pointing left
193+
mediolat_axis_norm = body_axis_norm[:, [1, 0]].copy()
194+
mediolat_axis_norm[:, 0] *= -1
195+
nrows = len(body_axis_norm)
196+
animal_cs = np.zeros((nrows, 3, 3))
197+
rot = np.stack((body_axis_norm, mediolat_axis_norm), axis=2)
198+
animal_cs[:, :2, :2] = rot
199+
animal_center_index = self.keypoint_names.index(self.animal_center_name)
200+
animal_cs[:, :, 2] = np.c_[
201+
self.whole_body[:, animal_center_index], np.ones(nrows)
202+
] # center back
203+
204+
return animal_cs
205+
206+
def calc_head_cs(self):
207+
nose_index = self.keypoint_names.index(self.nose_name)
208+
nose = self.whole_body[:, nose_index]
209+
neck_index = self.keypoint_names.index(self.neck_name)
210+
neck = self.whole_body[:, neck_index]
211+
head_axis = nose - neck
212+
head_axis_norm = head_axis / np.linalg.norm(head_axis, axis=1, keepdims=True)
213+
# Get a normal vector pointing left
214+
mediolat_axis_norm = head_axis_norm[:, [1, 0]].copy()
215+
mediolat_axis_norm[:, 0] *= -1
216+
nrows = len(head_axis_norm)
217+
mouse_cs = np.zeros((nrows, 3, 3))
218+
rot = np.stack((head_axis_norm, mediolat_axis_norm), axis=2)
219+
mouse_cs[:, :2, :2] = rot
220+
mouse_cs[:, :, 2] = np.c_[neck, np.ones(nrows)]
221+
return mouse_cs
222+
223+
224+

amadeusgpt/analysis_objects/event.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ def init_from_list(cls, events: List[BaseEvent]) -> "EventGraph":
384384

385385
return graph
386386

387+
388+
387389
@classmethod
388390
def init_from_mask(
389391
cls,

amadeusgpt/analysis_objects/llm.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def connect_gpt_oai_1(self, messages, **kwargs):
6464
if "gpt_model" in st.session_state:
6565
self.gpt_model = st.session_state["gpt_model"]
6666

67-
configurable_params = ["gpt_model", "max_tokens"]
67+
configurable_params = ["gpt_model", "max_tokens", "temperature"]
6868

6969
for param in configurable_params:
7070
if param in kwargs:
@@ -281,19 +281,17 @@ def speak(self, sandbox):
281281

282282
qa_message["chain_of_thought"] = thought_process
283283

284-
def update_system_prompt(self, sandbox):
284+
def get_system_prompt(self, sandbox):
285285
from amadeusgpt.system_prompts.code_generator import _get_system_prompt
286286

287-
# get the formatted docs / blocks from the sandbox
288-
core_api_docs = sandbox.get_core_api_docs()
289-
task_program_docs = sandbox.get_task_program_docs()
290-
query_block = sandbox.get_query_block()
287+
return _get_system_prompt(
288+
sandbox
289+
)
291290

292-
behavior_analysis = sandbox.exec_namespace["behavior_analysis"]
291+
def update_system_prompt(self, sandbox):
293292

294-
self.system_prompt = _get_system_prompt(
295-
query_block, core_api_docs, task_program_docs, behavior_analysis
296-
)
293+
# get the formatted docs / blocks from the sandbox
294+
self.system_prompt = self.get_system_prompt(sandbox)
297295

298296
# update both history and context window
299297
self.update_history("system", self.system_prompt)

0 commit comments

Comments
 (0)