Skip to content

Commit e22d746

Browse files
committed
Fixed some sort of rebase bug
1 parent a1ab4c3 commit e22d746

File tree

10 files changed

+61
-86
lines changed

10 files changed

+61
-86
lines changed

amadeusgpt/analysis_objects/llm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,17 +198,22 @@ def speak(self, sandbox):
198198
self.system_prompt = _get_system_prompt()
199199
analysis = sandbox.exec_namespace["behavior_analysis"]
200200
scene_image = analysis.visual_manager.get_scene_image()
201-
encoded_image = self.encode_image(scene_image)
202-
self.update_history("user", encoded_image)
201+
result, buffer = cv2.imencode('.jpeg', scene_image)
202+
image_bytes = io.BytesIO(buffer)
203+
base64_image = base64.b64encode(image_bytes.getvalue()).decode('utf-8')
203204

205+
self.update_history("system", self.system_prompt)
206+
self.update_history("user", "here is the image", encoded_image = base64_image, replace = True)
207+
response = self.connect_gpt(self.context_window, max_tokens=2000)
208+
text = response.choices[0].message.content.strip()
204209
print (text)
205210
pattern = r"```json(.*?)```"
206211
if len(re.findall(pattern, text, re.DOTALL)) == 0:
207212
raise ValueError("can't parse the json string correctly", text)
208213
else:
209214
json_string = re.findall(pattern, text, re.DOTALL)[0]
210215
json_obj = json.loads(json_string)
211-
return json_obj
216+
return json_obj
212217

213218
class CodeGenerationLLM(LLM):
214219
"""

amadeusgpt/analysis_objects/object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def set_body_orientation_keypoints(
298298

299299
def set_head_orientation_keypoints(
300300
self, head_orientation_keypoints: Dict[str, Any]
301-
):
301+
):
302302
self.nose_name = head_orientation_keypoints["nose"]
303303
self.neck_name = head_orientation_keypoints["neck"]
304304
self.support_head_orientation = True

amadeusgpt/analysis_objects/relationship.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def get_name(self):
9292
return self.__name__
9393

9494
def query_relationship(self, query_name: str) -> ndarray:
95-
9695
ret = self.data[query_name]
9796
return ret
9897

@@ -244,16 +243,25 @@ def _animal_animal_relationship(
244243
head_cs_inv, receiver_animal.get_center()
245244
)
246245

247-
relative_velocity = (
248-
sender_animal.get_velocity() - receiver_animal.get_velocity()
249-
)
250-
relative_velocity_magnitude = np.linalg.norm(relative_velocity, axis=2)
251-
# Then, average these magnitudes over all keypoints for each frame
252-
relative_speed = np.nanmean(relative_velocity_magnitude, axis=1)
246+
# relative_velocity = (
247+
# sender_animal.get_velocity() - receiver_animal.get_velocity()
248+
# )
249+
# relative_velocity_magnitude = np.linalg.norm(relative_velocity, axis=2)
250+
# # Then, average these magnitudes over all keypoints for each frame
251+
# relative_speed = np.nanmean(relative_velocity_magnitude, axis=1)
252+
253+
sender_pos = sender_animal.get_center()
254+
receiver_pos = receiver_animal.get_center()
255+
direction_vector = receiver_pos - sender_pos
256+
sender_velocity = np.nanmean(sender_animal.get_velocity(), axis = 1)
257+
norm_direction_vector = direction_vector / np.linalg.norm(direction_vector)
258+
relative_speed = np.einsum('ij,ij->i', sender_velocity, norm_direction_vector)
259+
253260
closest_distance = np.nanmin(
254261
get_pairwise_distance(sender_animal.keypoints, receiver_animal.keypoints),
255262
axis=(1, 2),
256263
)
264+
print ('relative_speed', relative_speed.mean())
257265
ret = {
258266
"distance": distance,
259267
"overlap": overlap,
@@ -267,6 +275,6 @@ def _animal_animal_relationship(
267275
if angles is not None:
268276
ret["relative_angle"] = angles
269277
if orientation is not None:
270-
ret["orientation"] = orientation
278+
ret["orientation"] = orientation
271279

272280
return ret

amadeusgpt/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def get_analysis(self):
8989
sandbox = self.sandbox
9090
analysis = sandbox.exec_namespace['behavior_analysis']
9191
return analysis
92+
93+
def run_task_program(self, task_program_name: str):
94+
return self.sandbox.run_task_program(task_program_name)
9295

9396

9497
if __name__ == "__main__":

amadeusgpt/managers/animal_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def init_pose(self):
123123
animalseq.set_body_orientation_keypoints(
124124
self.config["keypoint_info"]["body_orientation_keypoints"]
125125
)
126-
if "head_orientation_keypoints" in self.config["keypoint_info"]:
126+
127+
if "head_orientation_keypoints" in self.config["keypoint_info"]:
127128
animalseq.set_head_orientation_keypoints(
128129
self.config["keypoint_info"]["head_orientation_keypoints"]
129130
)

amadeusgpt/managers/event_manager.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import re
22
from typing import Any, Dict, List, Literal, Optional, Set, Union
3-
43
import numpy as np
5-
6-
from amadeusgpt.analysis_objects import event
74
from amadeusgpt.analysis_objects.event import BaseEvent, Event, EventGraph
85
from amadeusgpt.analysis_objects.relationship import Orientation, Relationship
96
from amadeusgpt.programs.api_registry import (register_class_methods,
107
register_core_api)
11-
from amadeusgpt.utils import timer_decorator
128

139
from .animal_manager import AnimalManager
1410
from .base import Manager, cache_decorator
@@ -182,8 +178,9 @@ def get_animals_state_events(
182178
self.animal_manager.update_roi_keypoint_by_names(bodypart_names)
183179
ret_events = []
184180
pattern = r"(==|<=|>=|<|>)"
185-
comparison_operator = re.findall(pattern, query)[0]
186-
query_name = query.split(comparison_operator)[0]
181+
# note we need to strip off the spaces
182+
comparison_operator = re.findall(pattern, query)[0].strip()
183+
query_name = query.split(comparison_operator)[0].strip()
187184
comparison = comparison_operator + "".join(query.split(comparison_operator)[1:])
188185

189186
for sender_animal_name in self.animal_manager.get_animal_names():
@@ -233,7 +230,6 @@ def get_events_from_relationship(
233230

234231
mask = relationship.query_relationship(relation_query)
235232

236-
#print ('mask', mask)
237233
# determine whether the mask is a numpy of float or numpy of boolean
238234

239235
if mask.dtype != bool:
@@ -318,9 +314,10 @@ def get_animals_animals_events(
318314
for query in cross_animal_query_list:
319315
# assert that query must contain one of the operators
320316
# find the operator
317+
# note we need to strip the spaces
318+
comparison_operator = re.findall(pattern, query)[0].strip()
319+
_query = query.split(comparison_operator)[0].strip()
321320

322-
comparison_operator = re.findall(pattern, query)[0]
323-
_query = query.split(comparison_operator)[0]
324321
_comparison = comparison_operator + "".join(
325322
query.split(comparison_operator)[1:]
326323
)

amadeusgpt/managers/visual_manager.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
import glob
22
import os
3-
from calendar import c
4-
from sched import Event
5-
from typing import Any, Dict, List, Optional, Union
3+
from typing import Any, Dict, List, Optional
64

75
import cv2
86
import matplotlib.pyplot as plt
97
import numpy as np
10-
import streamlit as st
118
from matplotlib.patches import Wedge
12-
139
from amadeusgpt.analysis_objects.event import BaseEvent
1410
from amadeusgpt.analysis_objects.object import AnimalSeq
1511
from amadeusgpt.analysis_objects.visualization import (EventVisualization,

amadeusgpt/programs/sandbox.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,12 @@
66
import traceback
77
import typing
88
from functools import wraps
9-
109
import matplotlib.pyplot as plt
1110
import numpy as np
12-
1311
from amadeusgpt.analysis_objects.analysis_factory import create_analysis
1412
from amadeusgpt.analysis_objects.event import BaseEvent
1513
from amadeusgpt.analysis_objects.relationship import Orientation
1614
from amadeusgpt.config import Config
17-
from amadeusgpt.implementation import AnimalBehaviorAnalysis
18-
from amadeusgpt.managers import visual_manager
1915
from amadeusgpt.programs.api_registry import (CORE_API_REGISTRY,
2016
INTEGRATION_API_REGISTRY)
2117
from amadeusgpt.programs.task_program_registry import (TaskProgram,
@@ -462,12 +458,24 @@ def llm_step(self, user_query):
462458

463459

464460
return qa_message
461+
462+
def run_task_program(self, task_program_name):
463+
"""
464+
Sandbox is also responsible for running task program
465+
"""
466+
task_program = self.task_program_library[task_program_name]
467+
self.query = "run the task program"
468+
qa_message = create_message(self.query, self)
469+
qa_message["code"] = task_program["source_code"]
470+
self.messages.append(qa_message)
471+
self.code_execution(qa_message)
472+
qa_message = self.render_qa_message(qa_message)
473+
return qa_message
465474

466475
def step(self, user_query, number_of_debugs=1):
467476
qa_message = create_message(user_query, self)
468477
self.messages.append(qa_message)
469478

470-
post_process_llm = ["self_debug"]
471479
self.query = user_query
472480
self.llms["code_generator"].speak(self)
473481
# all these llms collectively compose a amadeus_answer

amadeusgpt/programs/task_program_registry.py

Lines changed: 10 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,8 @@
11
import ast
22
import json
3-
import os
4-
import typing
53
from collections import defaultdict
6-
from typing import Any, Callable, Dict, List
7-
8-
from numpy import ndarray
9-
10-
from amadeusgpt.config import Config
11-
from amadeusgpt.implementation import AnimalBehaviorAnalysis
4+
from typing import Any, Callable
125
from amadeusgpt.utils import func2json
13-
14-
required_classes = {"AnimalBehaviorAnalysis": AnimalBehaviorAnalysis, "Config": Config}
15-
required_types = {
16-
name: getattr(typing, name) for name in dir(typing) if not name.startswith("_")
17-
}
18-
required_types.update({"ndarray": ndarray})
19-
20-
216
class TaskProgram:
227
"""
238
The task program in the system should be uniquely tracked by the id
@@ -46,9 +31,7 @@ def task_program_name(config) -> List[BaseEvent]:
4631
__call__(): should take the context and run the program in a sandbox.
4732
In the future we use docker container to run it
4833
49-
"""
50-
51-
exec_namespace = None
34+
"""
5235
cache = defaultdict(dict)
5336

5437
def __init__(
@@ -76,10 +59,17 @@ def __init__(
7659
self.json_obj["parents"] = parents
7760
self.json_obj["mutation_from"] = mutation_from
7861
self.json_obj["generation"] = generation
62+
7963

8064
def __setitem__(self, key, value):
8165
self.json_obj[key] = value
8266

67+
def display(self):
68+
print (self.json_obj['name'],
69+
self.json_obj['source_code'],
70+
self.json_obj['description'])
71+
72+
8373
def __getitem__(self, key):
8474
"""
8575
{
@@ -130,40 +120,15 @@ def deserialize(self):
130120

131121
def validate(self):
132122
pass
133-
134-
def __call__(self, config) -> Any:
135-
namespace = TaskProgram.exec_namespace
136-
function_name = self.json_obj["name"]
137-
keypoint_file = config["keypoint_info"]["keypoint_file_path"]
138-
keypoint_type = keypoint_file.split(".")[-1]
139-
videoname = keypoint_file.split("/")[-1].replace(keypoint_type, "")
140-
if self.json_obj["source_code"] is not None:
141-
exec(self.json_obj["source_code"], namespace)
142-
function = namespace[function_name]
143-
else:
144-
assert self.json_obj["func_pointer"] is not None
145-
function = self.json_obj["func_pointer"]
146-
namespace[function_name] = function
147-
call_str = f"{function_name}(config)"
148-
if videoname in TaskProgram.cache[function_name]:
149-
return TaskProgram.cache[function_name][videoname]
150-
else:
151-
exec(f"result = {call_str}", namespace)
152-
result = namespace["result"]
153-
154-
return result
155-
123+
156124

157125
class TaskProgramLibrary:
158126
"""
159127
Keep track of the task programs
160128
There are following types of task programs:
161129
1) Custom task programs that are created by the user (can be loaded from disk)
162130
2) Task programs that are created by LLMs
163-
164-
165131
"""
166-
167132
LIBRARY = {}
168133

169134
@classmethod
@@ -217,13 +182,6 @@ def get_task_programs(cls):
217182
"""
218183
return cls.LIBRARY
219184

220-
@classmethod
221-
def bind_exec_namespace(cls, exec_namespace):
222-
"""
223-
For task programs to execute, we need to bind the namespace
224-
"""
225-
TaskProgram.exec_namespace = exec_namespace
226-
227185
@classmethod
228186
def save(cls, out_path):
229187
ret = []

notebooks/EPM_demo.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
"import amadeusgpt\n",
1515
"from pathlib import Path\n",
1616
"import matplotlib.pyplot as plt\n",
17-
"import cv2\n",
18-
"from amadeusgpt.managers.gui_manager import ROISelector"
17+
"import cv2"
1918
]
2019
},
2120
{

0 commit comments

Comments
 (0)