Skip to content

Commit ad13918

Browse files
committed
Better support for retreiving results and save / load results
1 parent e9873d2 commit ad13918

File tree

9 files changed

+403
-306
lines changed

9 files changed

+403
-306
lines changed
Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
keypoint_info:
2-
keypoint_file_path:
3-
body_orientation_keypoints:
4-
neck: "nose"
5-
tail_base: "tail_base"
6-
animal_center: "neck"
7-
head_orientation_keypoints:
8-
nose: "nose"
9-
neck: "neck"
2+
keypoint_file_path:
103
llm_info:
114
keep_last_n_messages: 2
125
object_info:
136
load_objects_from_disk: false
147
video_info:
158
scene_frame_number: 100
169
video_file_path:
10+
result_info:
11+
result_folder: "./results"

amadeusgpt/integration_module_hub.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import glob
21
import os
32
import pickle
43
from typing import Dict, List, Optional

amadeusgpt/main.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
SelfDebugLLM, VisualLLM)
1919
from amadeusgpt.integration_module_hub import IntegrationModuleHub
2020

21+
import pickle
22+
2123
class AMADEUS:
2224
def __init__(self, config: Dict[str, Any]):
2325
self.config = config
@@ -81,16 +83,42 @@ def get_analysis(self):
8183

8284
def run_task_program(self, task_program_name: str):
8385
return self.sandbox.run_task_program(task_program_name)
84-
86+
87+
def save_results(self, out_folder: str| None = None):
88+
"""
89+
Save the results of the qa message (since it has all the information needed)
90+
"""
91+
if out_folder is None:
92+
result_folder = self.sandbox.result_folder
93+
else:
94+
result_folder = out_folder
95+
# make sure it exists
96+
os.makedirs(result_folder, exist_ok=True)
97+
results = self.sandbox.result_cache
98+
99+
ret = {}
100+
for query in results:
101+
ret[query] = results[query].get_serializable()
102+
103+
# save results to a pickle file
104+
with open (os.path.join(result_folder, "results.pickle"), "wb") as f:
105+
pickle.dump(ret, f)
106+
107+
def load_results(self, result_folder: str | None = None ):
108+
if result_folder is None:
109+
result_folder = self.sandbox.result_folder
110+
else:
111+
result_folder = result_folder
112+
113+
with open (os.path.join(result_folder, "results.pickle"), "rb") as f:
114+
results = pickle.load(f)
115+
self.sandbox.result_cache = results
116+
117+
def get_results(self):
118+
return self.sandbox.result_cache
85119

86120
if __name__ == "__main__":
87121
from amadeusgpt.analysis_objects.llm import VisualLLM
88122
from amadeusgpt.config import Config
89-
90-
91123
config = Config("amadeusgpt/configs/EPM_template.yaml")
92-
93-
amadeus = AMADEUS(config)
94-
sandbox = amadeus.sandbox
95-
visualLLm = VisualLLM(config)
96-
visualLLm.speak(sandbox)
124+
amadeus = AMADEUS(config)

amadeusgpt/programs/mabe_social_mabe.py

Lines changed: 0 additions & 156 deletions
This file was deleted.

amadeusgpt/programs/sandbox.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
import traceback
77
import typing
88
from functools import wraps
9-
109
import matplotlib.pyplot as plt
1110
import numpy as np
12-
1311
from amadeusgpt.analysis_objects.analysis_factory import create_analysis
1412
from amadeusgpt.analysis_objects.event import BaseEvent
1513
from amadeusgpt.analysis_objects.relationship import Orientation
@@ -18,10 +16,42 @@
1816
INTEGRATION_API_REGISTRY)
1917
from amadeusgpt.programs.task_program_registry import (TaskProgram,
2018
TaskProgramLibrary)
19+
from pathlib import Path
20+
21+
22+
class QA_Message(dict):
23+
def __init__(self, *args, **kwargs):
24+
super(QA_Message, self).__init__(*args, **kwargs)
25+
26+
def get_masks(self):
27+
function_rets = self["function_rets"]
28+
# if function_ret is a list of events
29+
if (
30+
isinstance(function_rets, list)
31+
and len(function_rets) > 0
32+
and isinstance(function_rets[0], BaseEvent)
33+
):
34+
events = function_rets
35+
masks = []
36+
for event in events:
37+
masks.append(event.generate_mask())
38+
return np.array(masks)
39+
else:
40+
raise ValueError("No events found in the function_rets")
41+
42+
def get_serializable(self):
43+
"""
44+
Only part of qa messages are serializable.
45+
"""
46+
selected_keys = ['query', 'code', 'chain_of_thought', 'function_rets', 'meta_info']
47+
ret = {}
48+
for key in selected_keys:
49+
ret[key] = self[key]
50+
return ret
2151

2252

2353
def create_message(query, sandbox):
24-
return {
54+
return QA_Message({
2555
"query": query,
2656
"code": None,
2757
"chain_of_thought": None,
@@ -32,7 +62,7 @@ def create_message(query, sandbox):
3262
"out_videos": None,
3363
"pose_video": None,
3464
"meta_info": None,
35-
}
65+
})
3666

3767

3868
class SandboxBase:
@@ -192,16 +222,21 @@ def __init__(self, config):
192222
self.task_program_library = TaskProgramLibrary().get_task_programs()
193223
self.config = config
194224
self.messages = []
225+
# initialize the code execution namespace with builtins
195226
self.exec_namespace = {"__builtins__": __builtins__}
196227
# update_namespace initializes behavior analysis
197228
self.update_namespace()
198229
# then we can configure behavior analysis using vlm
199230
self.meta_info = None
200-
self.visual_cache = {}
231+
# where llms are stored
201232
self.llms = {}
202233
# just easier to pass this around
203234
self.query = None
204235
self.matched_modules = []
236+
# result cache keeps the qa_message using the query as the key:
237+
self.result_cache = {}
238+
# configure how to save the results to a result folder
239+
self.result_folder = Path(self.config["result_info"].get("result_folder", "./results"))
205240

206241
def configure_using_vlm(self):
207242
# example meta_info:
@@ -322,9 +357,6 @@ def update_namespace(self):
322357
# to allow the program to access existing task programs
323358
self.exec_namespace["task_programs"] = TaskProgramLibrary.get_task_programs()
324359

325-
def parse_function_results(self, function_rets):
326-
pass
327-
328360
def code_execution(self, qa_message):
329361
# add main function into the namespace
330362
self.update_namespace()
@@ -388,7 +420,8 @@ def register_llm(self, name, llm):
388420
def events_to_videos(self, events, function_name):
389421
behavior_analysis = self.exec_namespace["behavior_analysis"]
390422
visual_manager = behavior_analysis.visual_manager
391-
out_folder = "event_clips"
423+
# save video clips to the result folder
424+
out_folder = str(self.result_folder)
392425
os.makedirs(out_folder, exist_ok=True)
393426
behavior_name = "_".join(function_name.split(" "))
394427
video_file = self.config["video_info"]["video_file_path"]
@@ -449,9 +482,7 @@ def render_qa_message(self, qa_message):
449482
qa_message["out_videos"] = self.events_to_videos(
450483
function_rets, self.get_function_name_from_string(qa_message["code"])
451484
)
452-
453-
else:
454-
pass
485+
455486
qa_message["plots"].extend(plots)
456487
return qa_message
457488

@@ -463,22 +494,25 @@ def llm_step(self, user_query):
463494
qa_message["meta_info"] = self.meta_info
464495

465496
self.messages.append(qa_message)
497+
# there might be better way to set this
466498
self.query = user_query
467499
self.llms["code_generator"].speak(self)
468-
500+
self.result_cache[user_query] = qa_message
469501
return qa_message
470502

471503
def run_task_program(self, task_program_name):
472504
"""
473505
Sandbox is also responsible for running task program
474506
"""
475507
task_program = self.task_program_library[task_program_name]
476-
self.query = "run the task program"
508+
# there might be better way to set this
509+
self.query = task_program_name
477510
qa_message = create_message(self.query, self)
478511
qa_message["code"] = task_program["source_code"]
479512
self.messages.append(qa_message)
480513
self.code_execution(qa_message)
481514
qa_message = self.render_qa_message(qa_message)
515+
self.result_cache[task_program_name] = qa_message
482516
return qa_message
483517

484518
def step(self, user_query, number_of_debugs=1):
@@ -503,7 +537,7 @@ def step(self, user_query, number_of_debugs=1):
503537
qa_message = self.code_execution(qa_message)
504538

505539
qa_message = self.render_qa_message(qa_message)
506-
540+
self.result_cache[user_query] = qa_message
507541
return qa_message
508542

509543

0 commit comments

Comments
 (0)