66import traceback
77import typing
88from functools import wraps
9-
109import matplotlib .pyplot as plt
1110import numpy as np
12-
1311from amadeusgpt .analysis_objects .analysis_factory import create_analysis
1412from amadeusgpt .analysis_objects .event import BaseEvent
1513from amadeusgpt .analysis_objects .relationship import Orientation
1816 INTEGRATION_API_REGISTRY )
1917from amadeusgpt .programs .task_program_registry import (TaskProgram ,
2018 TaskProgramLibrary )
19+ from pathlib import Path
20+
21+
22+ class QA_Message (dict ):
23+ def __init__ (self , * args , ** kwargs ):
24+ super (QA_Message , self ).__init__ (* args , ** kwargs )
25+
26+ def get_masks (self ):
27+ function_rets = self ["function_rets" ]
28+ # if function_ret is a list of events
29+ if (
30+ isinstance (function_rets , list )
31+ and len (function_rets ) > 0
32+ and isinstance (function_rets [0 ], BaseEvent )
33+ ):
34+ events = function_rets
35+ masks = []
36+ for event in events :
37+ masks .append (event .generate_mask ())
38+ return np .array (masks )
39+ else :
40+ raise ValueError ("No events found in the function_rets" )
41+
42+ def get_serializable (self ):
43+ """
44+ Only part of qa messages are serializable.
45+ """
46+ selected_keys = ['query' , 'code' , 'chain_of_thought' , 'function_rets' , 'meta_info' ]
47+ ret = {}
48+ for key in selected_keys :
49+ ret [key ] = self [key ]
50+ return ret
2151
2252
2353def create_message (query , sandbox ):
24- return {
54+ return QA_Message ( {
2555 "query" : query ,
2656 "code" : None ,
2757 "chain_of_thought" : None ,
@@ -32,7 +62,7 @@ def create_message(query, sandbox):
3262 "out_videos" : None ,
3363 "pose_video" : None ,
3464 "meta_info" : None ,
35- }
65+ })
3666
3767
3868class SandboxBase :
@@ -192,16 +222,21 @@ def __init__(self, config):
192222 self .task_program_library = TaskProgramLibrary ().get_task_programs ()
193223 self .config = config
194224 self .messages = []
225+ # initialize the code execution namespace with builtins
195226 self .exec_namespace = {"__builtins__" : __builtins__ }
196227 # update_namespace initializes behavior analysis
197228 self .update_namespace ()
198229 # then we can configure behavior analysis using vlm
199230 self .meta_info = None
200- self . visual_cache = {}
231+ # where llms are stored
201232 self .llms = {}
202233 # just easier to pass this around
203234 self .query = None
204235 self .matched_modules = []
236+ # result cache keeps the qa_message using the query as the key:
237+ self .result_cache = {}
238+ # configure how to save the results to a result folder
239+ self .result_folder = Path (self .config ["result_info" ].get ("result_folder" , "./results" ))
205240
206241 def configure_using_vlm (self ):
207242 # example meta_info:
@@ -322,9 +357,6 @@ def update_namespace(self):
322357 # to allow the program to access existing task programs
323358 self .exec_namespace ["task_programs" ] = TaskProgramLibrary .get_task_programs ()
324359
325- def parse_function_results (self , function_rets ):
326- pass
327-
328360 def code_execution (self , qa_message ):
329361 # add main function into the namespace
330362 self .update_namespace ()
@@ -388,7 +420,8 @@ def register_llm(self, name, llm):
388420 def events_to_videos (self , events , function_name ):
389421 behavior_analysis = self .exec_namespace ["behavior_analysis" ]
390422 visual_manager = behavior_analysis .visual_manager
391- out_folder = "event_clips"
423+ # save video clips to the result folder
424+ out_folder = str (self .result_folder )
392425 os .makedirs (out_folder , exist_ok = True )
393426 behavior_name = "_" .join (function_name .split (" " ))
394427 video_file = self .config ["video_info" ]["video_file_path" ]
@@ -449,9 +482,7 @@ def render_qa_message(self, qa_message):
449482 qa_message ["out_videos" ] = self .events_to_videos (
450483 function_rets , self .get_function_name_from_string (qa_message ["code" ])
451484 )
452-
453- else :
454- pass
485+
455486 qa_message ["plots" ].extend (plots )
456487 return qa_message
457488
@@ -463,22 +494,25 @@ def llm_step(self, user_query):
463494 qa_message ["meta_info" ] = self .meta_info
464495
465496 self .messages .append (qa_message )
497+ # there might be better way to set this
466498 self .query = user_query
467499 self .llms ["code_generator" ].speak (self )
468-
500+ self . result_cache [ user_query ] = qa_message
469501 return qa_message
470502
471503 def run_task_program (self , task_program_name ):
472504 """
473505 Sandbox is also responsible for running task program
474506 """
475507 task_program = self .task_program_library [task_program_name ]
476- self .query = "run the task program"
508+ # there might be better way to set this
509+ self .query = task_program_name
477510 qa_message = create_message (self .query , self )
478511 qa_message ["code" ] = task_program ["source_code" ]
479512 self .messages .append (qa_message )
480513 self .code_execution (qa_message )
481514 qa_message = self .render_qa_message (qa_message )
515+ self .result_cache [task_program_name ] = qa_message
482516 return qa_message
483517
484518 def step (self , user_query , number_of_debugs = 1 ):
@@ -503,7 +537,7 @@ def step(self, user_query, number_of_debugs=1):
503537 qa_message = self .code_execution (qa_message )
504538
505539 qa_message = self .render_qa_message (qa_message )
506-
540+ self . result_cache [ user_query ] = qa_message
507541 return qa_message
508542
509543
0 commit comments