Skip to content

Commit 6860338

Browse files
Shaokai/release 0.1.2 (#52)
* corrected broken links and broken references * Update setup.cfg - bump to stable v1 and v0.1.1 * Update pyproject.toml * Update version.py * Corrected typo. Added config yamls in setup * Removed config files that are no longer needed * changed work from to pull the repo from git * Added comments to remind people to pay attentino to data folder in the demo notebooks * fixed pypi typo * Fixed a bug in create_project. Changed default use_vlm to False. Updated demo notebooks * removed WIP 3d keypoints * Fixed one more * WIP * enforcing the use of create_project in demo notebooks and modified the test * 3D supported. Better tests. More flexible identifier * black and isort * added dlc to test requirement * Made test use stronger gpt. Added dummy video * easier superanimal test * Better 3D prompt and fixed self-debug * preventing infinite loop * better prompt for 3D * better prompt for 3D * better prompt * updates * fixed serialization * extension to support animation. Made self-debugging work with bigger output. Allowing to skip code execution in parse result * better interpolation and corrected x,y,z convention * incorporated suggestions --------- Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent af4a0a5 commit 6860338

34 files changed

+609
-331
lines changed

.github/workflows/pytest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ jobs:
3636
run: |
3737
python -m pip install --upgrade pip
3838
pip install pytest numpy==1.23.5 tables==3.8.0
39+
pip install deeplabcut==3.0.0rc4
3940
pip install pytest
4041
pip install pytest-timeout
4142
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

amadeusgpt/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
from amadeusgpt.integration_modules import *
1111
from amadeusgpt.main import AMADEUS
12-
from amadeusgpt.version import VERSION, __version__
1312
from amadeusgpt.project import create_project
13+
from amadeusgpt.version import VERSION, __version__
14+
1415
params = {
1516
"axes.labelsize": 10,
1617
"legend.fontsize": 10,

amadeusgpt/analysis_objects/animal.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
from numpy import ndarray
66
from scipy.spatial import ConvexHull
7-
87
from amadeusgpt.analysis_objects.object import Object
98

109

@@ -27,8 +26,8 @@ class AnimalSeq(Animal):
2726
body center, left, right, above, top are relative to the subset of keypoints.
2827
Attributes
2928
----------
30-
self._coords: arr potentially subset of keypoints
31-
self.wholebody: full set of keypoints. This is important for overlap relationship
29+
self.wholebody: np.ndarray of keypoints of all bodyparts
30+
self.keypoint
3231
"""
3332

3433
def __init__(self, animal_name: str, keypoints: ndarray, keypoint_names: List[str]):
@@ -95,8 +94,6 @@ def get_path(self, ind):
9594
return mpath.Path(verts, codes)
9695

9796
def get_keypoints(self, average_keypoints=False) -> ndarray:
98-
# the shape should be (n_frames, n_keypoints, 2)
99-
# extending to 3D?
10097
assert (
10198
len(self.keypoints.shape) == 3
10299
), f"keypoints shape is {self.keypoints.shape}"
@@ -123,8 +120,15 @@ def get_ymin(self):
123120
def get_ymax(self):
124121
return np.nanmax(self.keypoints[..., 1], axis=1)
125122

123+
def get_zmin(self):
124+
return np.nanmin(self.keypoints[..., 2], axis=1)
125+
126+
def get_zmax(self):
127+
return np.nanmax(self.keypoints[..., 2], axis=1)
128+
126129
def get_keypoint_names(self):
127130
return self.keypoint_names
131+
128132

129133
def query_states(self, query: str) -> ndarray:
130134
assert query in [

amadeusgpt/analysis_objects/llm.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ class LLM(AnalysisObject):
2727

2828
def __init__(self, config):
2929
self.config = config
30-
self.max_tokens = config.get("max_tokens", 4096)
31-
self.gpt_model = config.get("gpt_model", "gpt-4o-mini")
30+
31+
self.max_tokens = config["llm_info"].get("max_tokens", 4096)
32+
self.gpt_model = config["llm_info"].get("gpt_model", "gpt-4o-mini")
3233
self.keep_last_n_messages = config.get("keep_last_n_messages", 2)
3334

3435
# the list that is actually sent to gpt
@@ -261,8 +262,8 @@ def speak(self, sandbox: Sandbox, image: np.ndarray):
261262
response = self.connect_gpt(self.context_window, max_tokens=2000)
262263
text = response.choices[0].message.content.strip()
263264

264-
print ('description of the image frame provided')
265-
print (text)
265+
print("description of the image frame provided")
266+
print(text)
266267

267268
pattern = r"```json(.*?)```"
268269
if len(re.findall(pattern, text, re.DOTALL)) == 0:
@@ -293,24 +294,26 @@ def speak(
293294
task_program_docs = sandbox.get_task_program_docs()
294295

295296
if share_video_file:
296-
video_file_path = sandbox.video_file_paths[0]
297+
identifier = sandbox.identifiers[0]
297298
else:
298299
raise NotImplementedError("This is not implemented yet")
299300

300-
behavior_analysis = sandbox.analysis_dict[video_file_path]
301+
behavior_analysis = sandbox.analysis_dict[identifier]
301302
scene_image = behavior_analysis.visual_manager.get_scene_image()
302303
keypoint_names = behavior_analysis.animal_manager.get_keypoint_names()
303304
object_names = behavior_analysis.object_manager.get_object_names()
304-
animal_names = behavior_analysis.animal_manager.get_animal_names()
305-
305+
animal_names = behavior_analysis.animal_manager.get_animal_names()
306+
use_3d = sandbox.config['keypoint_info'].get('use_3d', False)
307+
306308
self.system_prompt = _get_system_prompt(
307309
core_api_docs,
308310
task_program_docs,
309311
scene_image,
310312
keypoint_names,
311313
object_names,
312314
animal_names,
313-
)
315+
use_3d=use_3d,
316+
)
314317

315318
self.update_history("system", self.system_prompt)
316319

@@ -338,6 +341,13 @@ def speak(
338341
with open("temp_answer.json", "w") as f:
339342
obj = {}
340343
obj["chain_of_thought"] = text
344+
obj["code"] = function_code
345+
obj["video_file_paths"] = sandbox.video_file_paths
346+
obj["keypoint_file_paths"] = sandbox.keypoint_file_paths
347+
if not isinstance(sandbox.config, dict):
348+
obj["config"] = sandbox.config.to_dict()
349+
else:
350+
obj["config"] = sandbox.config
341351
json.dump(obj, f, indent=4)
342352

343353
return qa_message
@@ -361,21 +371,18 @@ def speak(self, qa_message):
361371
query = f""" The code that caused error was {code}
362372
And the error message was {error_message}.
363373
All the modules were already imported so you don't need to import them again.
364-
Can you correct the code?
374+
Can you correct the code? Make sure you only write one function which is the updated function.
365375
"""
366376
self.update_history("user", query)
367-
response = self.connect_gpt(self.context_window, max_tokens=700)
377+
response = self.connect_gpt(self.context_window, max_tokens=4096)
368378
text = response.choices[0].message.content.strip()
369-
370379
print(text)
371-
372380
pattern = r"```python(.*?)```"
373381
function_code = re.findall(pattern, text, re.DOTALL)[0]
374-
375382
qa_message.code = function_code
376-
377383
qa_message.chain_of_thought = text
378384

385+
return qa_message
379386

380387
if __name__ == "__main__":
381388
from amadeusgpt.config import Config

amadeusgpt/analysis_objects/object.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import matplotlib.path as mpath
22
import numpy as np
3+
34
from .base import AnalysisObject
45

56

@@ -141,6 +142,7 @@ def __init__(self, name: str, masks: dict):
141142
_seg: dict = self.masks.get("segmentation")
142143
# this is rle format
143144
from pycocotools import mask as mask_decoder
145+
144146
if "counts" in _seg:
145147
self.segmentation = mask_decoder.decode(_seg)
146148
else:

amadeusgpt/analysis_objects/relationship.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ def calc_angle_between_2d_coordinate_systems(cs1, cs2):
4848
return np.rad2deg(np.arccos(dot))
4949

5050

51-
def get_pairwise_distance(arr1, arr2):
51+
def get_pairwise_distance(arr1: np.ndarray, arr2: np.ndarray):
5252
# we want to make sure this uses a fast implementation
53-
# (n_frame, n_kpts, 2)
53+
# arr: (n_frame, n_kpts, 2)
5454
assert len(arr1.shape) == 3 and len(arr2.shape) == 3
5555
# pariwise distance (n_frames, n_kpts, n_kpts)
5656
pairwise_distances = np.ones((arr1.shape[0], arr1.shape[1], arr2.shape[1])) * 100000
5757
for frame_id in range(arr1.shape[0]):
58+
# should we use the mean of all keypoints for the distance?
5859
pairwise_distances[frame_id] = cdist(arr1[frame_id], arr2[frame_id])
5960

6061
return pairwise_distances

amadeusgpt/analysis_objects/visualization.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from matplotlib.figure import Figure
1313
from matplotlib.ticker import FuncFormatter
1414
from mpl_toolkits.axes_grid1 import make_axes_locatable
15+
from mpl_toolkits.mplot3d import Axes3D
1516
from PIL import Image
1617
from scipy.signal import medfilt
1718

@@ -125,7 +126,8 @@ def draw(self, **kwargs) -> None:
125126

126127
self._draw_seg_objects()
127128
self._draw_roi_objects()
128-
self.axs.imshow(self.scene_frame)
129+
if self.scene_frame is not None:
130+
self.axs.imshow(self.scene_frame)
129131

130132

131133
class KeypointVisualization(MatplotlibVisualization):
@@ -284,37 +286,46 @@ def _event_plot_trajectory(self, **kwargs):
284286
masked_data = medfilt(masked_data, kernel_size=(k, 1))
285287
if masked_data.shape[0] == 0:
286288
continue
287-
x, y = masked_data[:, 0], masked_data[:, 1]
288-
x = x[x.nonzero()]
289-
y = y[y.nonzero()]
290-
if len(x) < 1:
291-
continue
292289

293-
scatter = self.axs.plot(
294-
x,
295-
y,
296-
label=f"event{event_id}",
297-
color=line_colors[event_id],
298-
alpha=0.5,
299-
)
300-
scatter = self.axs.scatter(
301-
x[0],
302-
y[0],
303-
marker="*",
304-
s=100,
305-
color=line_colors[event_id],
306-
alpha=0.5,
307-
**kwargs,
308-
)
309-
self.axs.scatter(
310-
x[-1],
311-
y[-1],
312-
marker="x",
313-
s=100,
314-
color=line_colors[event_id],
315-
alpha=0.5,
316-
**kwargs,
317-
)
290+
if not kwargs.get("use_3d", False):
291+
x, y = masked_data[:, 0], masked_data[:, 1]
292+
_mask = (x != 0) & (y != 0)
293+
294+
x = x[_mask]
295+
y = y[_mask]
296+
if len(x) < 1:
297+
continue
298+
299+
scatter = self.axs.plot(
300+
x,
301+
y,
302+
label=f"event{event_id}",
303+
color=line_colors[event_id],
304+
alpha=0.5,
305+
)
306+
scatter = self.axs.scatter(
307+
x[0],
308+
y[0],
309+
marker="*",
310+
s=100,
311+
color=line_colors[event_id],
312+
alpha=0.5,
313+
**kwargs,
314+
)
315+
self.axs.scatter(
316+
x[-1],
317+
y[-1],
318+
marker="x",
319+
s=100,
320+
color=line_colors[event_id],
321+
alpha=0.5,
322+
**kwargs,
323+
)
324+
else:
325+
# TODO
326+
# implement 3d event plot
327+
pass
328+
318329
return self.axs
319330

320331
def display(self):

amadeusgpt/behavior_analysis/analysis_factory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77

88
def create_analysis(identifier: Identifier):
99

10-
if str(identifier) not in analysis_fac:
11-
analysis_fac[str(identifier)] = AnimalBehaviorAnalysis(identifier)
12-
return analysis_fac[str(identifier)]
10+
if identifier not in analysis_fac:
11+
analysis_fac[identifier] = AnimalBehaviorAnalysis(identifier)
12+
return analysis_fac[identifier]

amadeusgpt/behavior_analysis/identifier.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,34 @@ class Identifier:
1111
Can be more in the future
1212
"""
1313

14-
def __init__(self, config: Config, video_file_path: str, keypoint_file_path: str):
14+
def __init__(
15+
self, config: Config | dict, video_file_path: str, keypoint_file_path: str
16+
):
1517

1618
self.config = config
1719
self.video_file_path = video_file_path
1820
self.keypoint_file_path = keypoint_file_path
1921

2022
def __str__(self):
21-
return os.path.abspath(self.video_file_path)
23+
return f"""------
24+
video_file_path: {self.video_file_path}
25+
keypoint_file_path: {self.keypoint_file_path}
26+
config: {self.config}
27+
------
28+
"""
2229

2330
def __eq__(self, other):
24-
return self.video_file_path == other.video_file_path
31+
if os.path.exists(self.video_file_path):
32+
return os.path.abspath(self.video_file_path) == os.path.abspath(
33+
other.video_file_path
34+
)
35+
else:
36+
return os.path.abspath(self.keypoint_file_path) == os.path.abspath(
37+
other.keypoint_file_path
38+
)
2539

2640
def __hash__(self):
27-
return hash(self.video_file_path)
41+
if os.path.exists(self.video_file_path):
42+
return hash(os.path.abspath(self.video_file_path))
43+
else:
44+
return hash(os.path.abspath(self.keypoint_file_path))

amadeusgpt/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def __repr__(self):
2222
def __setitem__(self, key, value):
2323
self.data[key] = value
2424

25+
def to_dict(self):
26+
return self.data
27+
2528
def load_config(self):
2629
# Load the YAML config file
2730
if os.path.exists(self.config_file_path):

0 commit comments

Comments
 (0)