Skip to content

Commit 7bc04ae

Browse files
committed
Merge branch 'shaokai/add_notebooks'
2 parents d0ca881 + 39a00e1 commit 7bc04ae

23 files changed

+3060
-122
lines changed

amadeusgpt/analysis_objects/llm.py

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import openai
99
from openai import OpenAI
1010
import base64
11+
import cv2
12+
import io
1113

1214
class LLM(AnalysisObject):
1315
total_tokens = 0
@@ -76,11 +78,8 @@ def connect_gpt_oai_1(self, messages, **kwargs):
7678
"model": self.gpt_model,
7779
"messages": messages,
7880
"max_tokens": self.max_tokens,
79-
"stop": None,
80-
"top_p": 1,
8181
"temperature": 0.0,
8282
}
83-
8483
response = client.chat.completions.create(**json_data)
8584

8685
LLM.total_tokens = LLM.total_tokens + response.usage.prompt_tokens + response.usage.completion_tokens
@@ -121,36 +120,32 @@ def update_history(self, role, content, encoded_image = None, replace=False):
121120
self.context_window.append({"role": role, "content": content})
122121
else:
123122

123+
if encoded_image is None:
124+
self.history.append({"role": role, "content": content})
125+
num_AI_messages = (len(self.context_window) - 1) // 2
126+
if num_AI_messages == self.keep_last_n_messages:
127+
print ("doing active forgetting")
128+
# we forget the oldest AI message and corresponding answer
129+
self.context_window.pop(1)
130+
self.context_window.pop(1)
131+
new_message = {"role": role, "content": content}
132+
else:
133+
new_message = {"role": "user", "content": [
134+
{"type": "text", "text": ""},
135+
{"type": "image_url", "image_url": {
136+
"url": f"data:image/jpeg;base64,{encoded_image}"}
137+
}
138+
]}
139+
140+
self.history.append(new_message)
141+
124142
if replace == True:
125-
if len(self.history) == 2:
126-
self.history[1]["content"] = content
127-
self.context_window[1]["content"] = content
143+
if len(self.context_window) == 2:
144+
self.context_window[1] = new_message
128145
else:
129-
self.history.append({"role": role, "content": content})
130-
self.context_window.append({"role": role, "content": content})
146+
self.context_window.append(new_message)
147+
131148

132-
else:
133-
if encoded_image is None:
134-
self.history.append({"role": role, "content": content})
135-
num_AI_messages = (len(self.context_window) - 1) // 2
136-
if num_AI_messages == self.keep_last_n_messages:
137-
print ("doing active forgetting")
138-
# we forget the oldest AI message and corresponding answer
139-
self.context_window.pop(1)
140-
self.context_window.pop(1)
141-
self.context_window.append({"role": role, "content": content})
142-
else:
143-
message = {
144-
"role": "user", "content": [
145-
{"type": "text", "text": content},
146-
{"type": "image_url", "image_url": {
147-
"url": f"data:image/png;base64,{encoded_image}"}
148-
}]
149-
}
150-
self.context_window.append(message)
151-
152-
153-
154149

155150
def clean_context_window(self):
156151
while len(self.context_window) > 1:
@@ -194,13 +189,26 @@ def speak(self, sandbox):
194189
195190
"""
196191

197-
from amadeusgpt.system_prompts.visual import _get_system_prompt
192+
from amadeusgpt.system_prompts.visual_llm import _get_system_prompt
198193
self.system_prompt = _get_system_prompt()
199194
analysis = sandbox.exec_namespace["behavior_analysis"]
200195
scene_image = analysis.visual_manager.get_scene_image()
201-
encoded_image = self.encode_image(scene_image)
202-
self.update_history("user", encoded_image)
203196

197+
result, buffer = cv2.imencode('.jpeg', scene_image)
198+
image_bytes = io.BytesIO(buffer)
199+
base64_image = base64.b64encode(image_bytes.getvalue()).decode('utf-8')
200+
self.update_history("system", self.system_prompt)
201+
self.update_history("user", "here is the image", encoded_image = base64_image, replace = True)
202+
response = self.connect_gpt(self.context_window, max_tokens=2000)
203+
text = response.choices[0].message.content.strip()
204+
print (text)
205+
pattern = r"```json(.*?)```"
206+
if len(re.findall(pattern, text, re.DOTALL)) == 0:
207+
raise ValueError("can't parse the json string correctly", text)
208+
else:
209+
json_string = re.findall(pattern, text, re.DOTALL)[0]
210+
json_obj = json.loads(json_string)
211+
return json_obj
204212

205213
class CodeGenerationLLM(LLM):
206214
"""
@@ -394,3 +402,14 @@ def speak(self, sandbox):
394402
function_code = re.findall(pattern, text, re.DOTALL)[0]
395403
qa_message["code"] = function_code
396404
qa_message["chain_of_thought"] = thought_process
405+
406+
407+
if __name__ == "__main__":
    # Manual smoke test: build an AMADEUS instance from the EPM template
    # config and run the visual LLM once against its sandbox.
    # Imports stay local to the script entry point, as in the original.
    from amadeusgpt.config import Config
    from amadeusgpt.main import create_amadeus

    config = Config("amadeusgpt/configs/EPM_template.yaml")

    amadeus = create_amadeus(config)
    sandbox = amadeus.sandbox
    # AMADEUS.__init__ constructs every LLM (VisualLLM included) from the
    # "llm_info" section of the config; build this one the same way so the
    # standalone run uses the same model settings as the application.
    # NOTE(review): assumes Config exposes dict-style .get, as AMADEUS.__init__
    # relies on for the config it receives -- confirm.
    visual_llm = VisualLLM(config.get("llm_info", {}))
    visual_llm.speak(sandbox)

amadeusgpt/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def main():
2323
st.session_state["exist_valid_openai_api_key"] = True
2424
else:
2525
st.session_state["exist_valid_openai_api_key"] = False
26+
2627

2728
example_to_page = {}
2829

amadeusgpt/app_utils.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -548,25 +548,28 @@ def render_page_by_example(example):
548548

549549
if example == "Custom":
550550
st.markdown("Provide your own video and keypoint file (in pairs)")
551-
uploaded_keypoint_file = st.file_uploader(
552-
"Choose keypoint files",
553-
["h5"],
554-
accept_multiple_files=False,
555-
)
556-
uploaded_video_file = st.file_uploader(
557-
"Choose video files",
558-
VIDEO_EXTS,
559-
accept_multiple_files=False,
560-
)
561-
562-
if uploaded_keypoint_file is not None:
563-
save_dir = os.path.join("examples", example)
564-
config["keypoint_info"]["keypoint_file_path"] = save_uploaded_file(
565-
uploaded_keypoint_file, save_dir
551+
save_dir = os.path.join("examples", example)
552+
if "uploaded_keypoint_file" not in st.session_state:
553+
uploaded_keypoint_file = st.file_uploader(
554+
"Choose keypoint files",
555+
["h5"],
556+
accept_multiple_files=False,
566557
)
567-
config["video_info"]["video_file_path"] = save_uploaded_file(
568-
uploaded_video_file, save_dir
558+
if uploaded_keypoint_file is not None:
559+
path = save_uploaded_file(uploaded_keypoint_file, save_dir)
560+
st.session_state["uploaded_keypoint_file"] = path
561+
config["keypoint_info"]["keypoint_file_path"] = st.session_state["uploaded_keypoint_file"]
562+
563+
if "uploaded_video_file" not in st.session_state:
564+
uploaded_video_file = st.file_uploader(
565+
"Choose video files",
566+
VIDEO_EXTS,
567+
accept_multiple_files=False,
569568
)
569+
if uploaded_video_file is not None:
570+
path = save_uploaded_file(uploaded_video_file, save_dir)
571+
st.session_state["uploaded_video_file"] = uploaded_video_file
572+
config["video_info"]["video_file_path"] = st.session_state["uploaded_video_file"]
570573

571574
###### USER INPUT PANEL ######
572575
# get user input once getting the uploaded files
@@ -648,7 +651,6 @@ def render_page_by_example(example):
648651
st.session_state["example"] = example
649652

650653
scene_image_path = get_scene_image(config)
651-
video_file = config["video_info"]["video_file_path"]
652654

653655
if scene_image_path is not None:
654656
img_data = base64.b64decode(scene_image_path)
@@ -673,8 +675,12 @@ def render_page_by_example(example):
673675
st.caption("Raw video from Horse-30")
674676
else:
675677
st.caption("DeepLabCut-SuperAnimal tracked video")
676-
if video_file:
677-
st.video(video_file)
678+
if config["video_info"]["video_file_path"] and config["video_info"]["video_file_path"] is not None:
679+
st.video(config["video_info"]["video_file_path"])
680+
681+
if "uploaded_video_file" in st.session_state:
682+
st.video(st.session_state["uploaded_video_file"])
683+
678684
# we only show objects for MausHaus for demo
679685
# if sam_image is not None:
680686
# st.caption("SAM segmentation results")

amadeusgpt/main.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import os
1212

1313
from amadeusgpt.analysis_objects.llm import (CodeGenerationLLM, DiagnosisLLM,
14-
SelfDebugLLM)
14+
SelfDebugLLM, VisualLLM)
1515
from amadeusgpt.integration_module_hub import IntegrationModuleHub
1616

1717
amadeus_fac = {}
@@ -31,6 +31,7 @@ def __init__(self, config: Dict[str, Any]):
3131
self.code_generator_llm = CodeGenerationLLM(config.get("llm_info", {}))
3232
self.self_debug_llm = SelfDebugLLM(config.get("llm_info", {}))
3333
self.diagnosis_llm = DiagnosisLLM(config.get("llm_info", {}))
34+
self.visual_llm = VisualLLM(config.get("llm_info", {}))
3435
### fields that decide the behavior of the application
3536
self.use_self_debug = True
3637
self.use_diagnosis = False
@@ -47,11 +48,15 @@ def __init__(self, config: Dict[str, Any]):
4748

4849
## register the llm to the sandbox
4950
self.sandbox.register_llm("code_generator", self.code_generator_llm)
51+
self.sandbox.register_llm("visual_llm", self.visual_llm)
5052
if self.use_self_debug:
5153
self.sandbox.register_llm("self_debug", self.self_debug_llm)
5254
if self.use_diagnosis:
5355
self.sandbox.register_llm("diagnosis", self.diagnosis_llm)
5456

57+
# can only do this after the register process
58+
self.sandbox.configure_using_vlm()
59+
5560
def match_integration_module(self, user_query: str):
5661
"""
5762
Return a list of matched integration modules
@@ -80,21 +85,19 @@ def step(self, user_query):
8085
result = self.sandbox.llm_step(user_query)
8186
return result
8287

88+
def get_analysis(self):
    """Return the ``behavior_analysis`` entry of the sandbox's exec namespace."""
    return self.sandbox.exec_namespace['behavior_analysis']
8392

84-
if __name__ == "__main__":
85-
config = Config("amadeusgpt/configs/MausHaus_template.yaml")
8693

87-
# amadeus = AMADEUS(config)
88-
# query = "Give me events when mice are close"
89-
# amadeus.step(query)
94+
if __name__ == "__main__":
95+
from amadeusgpt.config import Config
96+
from amadeusgpt.main import create_amadeus
97+
from amadeusgpt.analysis_objects.llm import VisualLLM
98+
config = Config("amadeusgpt/configs/EPM_template.yaml")
9099

91-
query = "Plot the trajectory with the keypoint butt"
92100
amadeus = create_amadeus(config)
93101
sandbox = amadeus.sandbox
94-
analysis = sandbox.exec_namespace["behavior_analysis"]
95-
96-
analysis.object_manager.load_roi_objects("temp_roi_objects.pickle")
97-
98-
from amadeusgpt.programs.sandbox import render_temp_message
99-
100-
render_temp_message(query, sandbox)
102+
visualLLm = VisualLLM(config)
103+
visualLLm.speak(sandbox)

amadeusgpt/managers/animal_manager.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,24 @@ def __init__(self, config: Dict[str, str], model_manager: ModelManager):
8080
self.model_manager = model_manager
8181
self.animals: List[AnimalSeq] = []
8282
self.full_keypoint_names = []
83-
83+
self.superanimal_predicted_video = None
8484
self.init_pose()
8585

86+
def configure_animal_from_meta(self, meta_info):
    """
    Configure the individual count and the SuperAnimal model from metadata.

    Reads ``meta_info['individuals']`` (converted with ``int``) into
    ``self.max_individuals`` and maps ``meta_info['species']`` onto
    ``self.superanimal_name``. Any species other than the two listed
    below leaves ``self.superanimal_name`` set to ``None``.
    """
    self.max_individuals = int(meta_info['individuals'])
    species = meta_info['species']
    # Checked in order with ==, so an unrecognised value falls through to None.
    known_models = (
        ('topview_mouse', 'superanimal_topviewmouse_hrnetw32'),
        ('sideview_quadruped', 'superanimal_quadruped_hrnetw32'),
    )
    self.superanimal_name = next(
        (model for label, model in known_models if species == label),
        None,
    )
99+
100+
86101
def init_pose(self):
87102
keypoint_info = self.config["keypoint_info"]
88103

@@ -224,18 +239,26 @@ def get_keypoints(self) -> ndarray:
224239
video_file_path = self.config['video_info']['video_file_path']
225240
if os.path.exists(video_file_path) and keypoint_file_path is None:
226241

242+
if self.superanimal_name is None:
243+
raise ValueError("Couldn't determine the species of the animal from the image. Change the scene index")
244+
245+
# only import here because people who choose the minimal installation might not have deeplabcut
227246
import deeplabcut
228-
from deeplabcut.modelzoo.video_inference import video_inference_superanimal
229-
superanimal_name = 'superanimal_topviewmouse_hrnetw32'
247+
from deeplabcut.modelzoo.video_inference import video_inference_superanimal
248+
video_suffix = Path(video_file_path).suffix
249+
250+
keypoint_file_path = video_file_path.replace(video_suffix, '_' + self.superanimal_name + '.h5')
251+
self.superanimal_predicted_video = keypoint_file_path.replace('.h5', '_labeled.mp4')
230252

231-
keypoint_file_path = video_file_path.replace('.mp4', '_' + superanimal_name + '.h5')
232253
if not os.path.exists(keypoint_file_path):
254+
print (f"going to inference video with {self.superanimal_name}")
233255
video_inference_superanimal(videos = [self.config['video_info']['video_file_path']],
234-
superanimal_name = superanimal_name,
256+
superanimal_name = self.superanimal_name,
257+
max_individuals=self.max_individuals,
235258
video_adapt = False)
259+
236260

237261
if os.path.exists(keypoint_file_path):
238-
239262
self.config['keypoint_info']['keypoint_file_path'] = keypoint_file_path
240263
self.init_pose()
241264

amadeusgpt/managers/gui_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
register_core_api)
1212

1313
from .base import Manager
14+
from amadeusgpt.analysis_objects.object import ROIObject
1415
from .object_manager import ObjectManager
1516

1617

@@ -50,9 +51,9 @@ def onselect(self, vertices):
5051

5152
# Here you can add any further processing of the polygons
5253
self.object_manager.roi_objects = []
53-
self.object_manager.add_roi_object(self.paths)
54+
for idx, path in enumerate(self.paths):
55+
self.object_manager.add_roi_object(ROIObject(f'ROI{idx}', path))
5456

55-
print(len(self.object_manager.roi_objects))
5657
# Assuming the object_manager's add_roi_object is meant to handle the completed polygons
5758

5859

amadeusgpt/managers/object_manager.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def load_objects_from_disk(self):
6060
pass
6161

6262
def get_roi_object_names(self) -> List[str]:
    """Return the ``name`` of every entry in ``self.roi_objects``, in list order."""
    names = []
    for roi in self.roi_objects:
        names.append(roi.name)
    return names
6465

6566
def get_roi_objects(self) -> List[Object]:
@@ -80,9 +81,7 @@ def add_roi_object(self, object) -> None:
8081
self.roi_objects = self.filter_duplicates(self.roi_objects)
8182

8283
def save_roi_objects(self, path: str) -> None:
83-
roi_obects = self.get_roi_objects()
84-
for roi in roi_obects:
85-
print(roi.name)
84+
roi_obects = self.get_roi_objects()
8685
data = {}
8786
for obj in roi_obects:
8887
data[obj.name] = {"Path": obj.Path}

0 commit comments

Comments
 (0)