Skip to content

Commit 5396098

Browse files
committed
Change result cache to defaultdict(dict) and added a new notebook for task program library
1 parent d6eeba6 commit 5396098

File tree

6 files changed

+392
-16
lines changed

6 files changed

+392
-16
lines changed

amadeusgpt/analysis_objects/analysis_factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from amadeusgpt.implementation import AnimalBehaviorAnalysis
22

3+
34
analysis_fac = {}
45

56
def create_analysis(config):

amadeusgpt/configs/MABe_template.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ llm_info:
1212
object_info:
1313
load_objects_from_disk: false
1414
video_info:
15-
scene_frame_number: 100
15+
scene_frame_number: 1400
1616
video_file_path: "examples/MABe/EGS8X2MN4SSUGFWAV976.mp4"

amadeusgpt/main.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from amadeusgpt.analysis_objects.llm import (CodeGenerationLLM, DiagnosisLLM,
1818
SelfDebugLLM, VisualLLM)
1919
from amadeusgpt.integration_module_hub import IntegrationModuleHub
20-
20+
from collections import defaultdict
2121
import pickle
2222

2323
class AMADEUS:
@@ -84,11 +84,16 @@ def get_analysis(self):
8484
analysis = self.sandbox.exec_namespace["behavior_analysis"]
8585
return analysis
8686

87-
def run_task_program(self, task_program_name: str):
87+
def run_task_program(self, config: Config, task_program_name: str):
8888
"""
8989
Execute the task program on the currently holding sandbox
90+
Parameters
91+
-----------
92+
config: a config specifies the movie file and the keypoint file to run task program
93+
task_program_name: the name of the task program to run
94+
9095
"""
91-
return self.sandbox.run_task_program(task_program_name)
96+
return self.sandbox.run_task_program(config, task_program_name)
9297

9398
def save_results(self, out_folder: str| None = None):
9499
"""
@@ -102,9 +107,10 @@ def save_results(self, out_folder: str| None = None):
102107
os.makedirs(result_folder, exist_ok=True)
103108
results = self.sandbox.result_cache
104109

105-
ret = {}
110+
ret = defaultdict(dict)
106111
for query in results:
107-
ret[query] = results[query].get_serializable()
112+
for video_file_path in results[query]:
113+
ret[query][video_file_path] = results[query][video_file_path].get_serializable()
108114

109115
# save results to a pickle file
110116
with open (os.path.join(result_folder, "results.pickle"), "wb") as f:

amadeusgpt/managers/visual_manager.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,8 +673,6 @@ def write_video(self, out_folder, video_file_path, out_name, events):
673673
# Release everything when job is finished
674674
cap.release()
675675
cv2.destroyAllWindows()
676-
print("out videos" * 10)
677-
print(out_videos)
678676
return out_videos
679677

680678
def generate_video_clips_from_events(

amadeusgpt/programs/sandbox.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from amadeusgpt.programs.task_program_registry import (TaskProgram,
1818
TaskProgramLibrary)
1919
from pathlib import Path
20-
20+
from collections import defaultdict
2121

2222
class QA_Message(dict):
2323
def __init__(self, *args, **kwargs):
@@ -233,8 +233,17 @@ def __init__(self, config):
233233
# just easier to pass this around
234234
self.query = None
235235
self.matched_modules = []
236-
# result cache keeps the qa_message using the query as the key:
237-
self.result_cache = {}
236+
# example result_cahe
237+
"""
238+
{'query' :
239+
{
240+
'file1.mp4': QA_Message(),
241+
'file2.mp4': QA_Message(),
242+
}
243+
}
244+
"""
245+
246+
self.result_cache = defaultdict(dict)
238247
# configure how to save the results to a result folder
239248
self.result_folder = Path(self.config["result_info"].get("result_folder", "./results"))
240249

@@ -497,24 +506,30 @@ def llm_step(self, user_query):
497506
self.messages.append(qa_message)
498507
# there might be better way to set this
499508
self.query = user_query
500-
self.llms["code_generator"].speak(self)
501-
self.result_cache[user_query] = qa_message
509+
self.llms["code_generator"].speak(self)
510+
self.result_cache[user_query][self.config['video_info']['video_file_path']] = qa_message
502511
return qa_message
503512

504-
def run_task_program(self, task_program_name):
513+
def run_task_program(self, config: Config, task_program_name: str):
505514
"""
506515
1) sandbox is also responsible for running task program
507516
2) self.task_program_library references to a singleton so a different sandbox still has reference to the task program
508517
"""
518+
# update the config
519+
self.config = config
520+
509521
task_program = self.task_program_library[task_program_name]
510522
# there might be better way to set this
511523
self.query = task_program_name
512524
qa_message = create_message(self.query, self)
513525
qa_message["code"] = task_program["source_code"]
514526
self.messages.append(qa_message)
527+
528+
# code execution will use the latest config, if updated
515529
self.code_execution(qa_message)
530+
516531
qa_message = self.render_qa_message(qa_message)
517-
self.result_cache[task_program_name] = qa_message
532+
self.result_cache[task_program_name][config['video_info']['video_file_path']] = qa_message
518533
return qa_message
519534

520535
def step(self, user_query, number_of_debugs=1):
@@ -539,7 +554,7 @@ def step(self, user_query, number_of_debugs=1):
539554
qa_message = self.code_execution(qa_message)
540555

541556
qa_message = self.render_qa_message(qa_message)
542-
self.result_cache[user_query] = qa_message
557+
self.result_cache[user_query][self.config['video_info']['video_file_path']] = qa_message
543558
return qa_message
544559

545560

notebooks/Use_Task_Program.ipynb

Lines changed: 356 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)