1717from amadeusgpt .programs .task_program_registry import (TaskProgram ,
1818 TaskProgramLibrary )
1919from pathlib import Path
20-
20+ from collections import defaultdict
2121
2222class QA_Message (dict ):
2323 def __init__ (self , * args , ** kwargs ):
@@ -233,8 +233,17 @@ def __init__(self, config):
233233 # just easier to pass this around
234234 self .query = None
235235 self .matched_modules = []
236- # result cache keeps the qa_message using the query as the key:
237- self .result_cache = {}
236+ # example result_cahe
237+ """
238+ {'query' :
239+ {
240+ 'file1.mp4': QA_Message(),
241+ 'file2.mp4': QA_Message(),
242+ }
243+ }
244+ """
245+
246+ self .result_cache = defaultdict (dict )
238247 # configure how to save the results to a result folder
239248 self .result_folder = Path (self .config ["result_info" ].get ("result_folder" , "./results" ))
240249
@@ -497,24 +506,30 @@ def llm_step(self, user_query):
497506 self .messages .append (qa_message )
498507 # there might be better way to set this
499508 self .query = user_query
500- self .llms ["code_generator" ].speak (self )
501- self .result_cache [user_query ] = qa_message
509+ self .llms ["code_generator" ].speak (self )
510+ self .result_cache [user_query ][ self . config [ 'video_info' ][ 'video_file_path' ]] = qa_message
502511 return qa_message
503512
504- def run_task_program (self , task_program_name ):
513+ def run_task_program (self , config : Config , task_program_name : str ):
505514 """
506515 1) sandbox is also responsible for running task program
507516 2) self.task_program_library references to a singleton so a different sandbox still has reference to the task program
508517 """
518+ # update the config
519+ self .config = config
520+
509521 task_program = self .task_program_library [task_program_name ]
510522 # there might be better way to set this
511523 self .query = task_program_name
512524 qa_message = create_message (self .query , self )
513525 qa_message ["code" ] = task_program ["source_code" ]
514526 self .messages .append (qa_message )
527+
528+ # code execution will use the latest config, if updated
515529 self .code_execution (qa_message )
530+
516531 qa_message = self .render_qa_message (qa_message )
517- self .result_cache [task_program_name ] = qa_message
532+ self .result_cache [task_program_name ][ config [ 'video_info' ][ 'video_file_path' ]] = qa_message
518533 return qa_message
519534
520535 def step (self , user_query , number_of_debugs = 1 ):
@@ -539,7 +554,7 @@ def step(self, user_query, number_of_debugs=1):
539554 qa_message = self .code_execution (qa_message )
540555
541556 qa_message = self .render_qa_message (qa_message )
542- self .result_cache [user_query ] = qa_message
557+ self .result_cache [user_query ][ self . config [ 'video_info' ][ 'video_file_path' ]] = qa_message
543558 return qa_message
544559
545560
0 commit comments