
Commit abbca4f

LLM-generated functions are now automatically registered as task programs
1 parent: 5396098

File tree

3 files changed: +362, −130 lines

amadeusgpt/main.py

Lines changed: 5 additions & 0 deletions

@@ -20,6 +20,8 @@
 from collections import defaultdict
 import pickle
 
+from amadeusgpt.programs.task_program_registry import TaskProgramLibrary
+
 class AMADEUS:
     def __init__(self, config: Dict[str, Any]):
         self.config = config
@@ -128,6 +130,9 @@ def load_results(self, result_folder: str | None = None):
 
     def get_results(self):
         return self.sandbox.result_cache
+
+    def get_task_programs(self):
+        return TaskProgramLibrary.get_task_programs()
 
 if __name__ == "__main__":
     from amadeusgpt.analysis_objects.llm import VisualLLM
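
Since the diff only shows the new accessor, here is a minimal, hypothetical sketch of how get_task_programs() might be called from user code. The config keys are placeholders (real configs carry more keys; see notebooks/MABe_demo.ipynb), and the return value is assumed to be the name-to-program mapping kept by TaskProgramLibrary:

from amadeusgpt.main import AMADEUS

# Hypothetical minimal config; a real config carries more keys.
config = {"video_info": {"video_file_path": "demo.mp4"}}
amadeus = AMADEUS(config)

# New in this commit: surface everything the registry has accumulated,
# including any programs the LLM wrote during earlier queries.
task_programs = amadeus.get_task_programs()
print(task_programs)  # assumed to be a name -> task-program mapping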

amadeusgpt/programs/sandbox.py

Lines changed: 27 additions & 20 deletions

@@ -506,8 +506,15 @@ def llm_step(self, user_query):
         self.messages.append(qa_message)
         # there might be a better way to set this
         self.query = user_query
-        self.llms["code_generator"].speak(self)
+        self.llms["code_generator"].speak(self)
+        # cache the resulting qa_message for future use
         self.result_cache[user_query][self.config['video_info']['video_file_path']] = qa_message
+
+
+        # a task program written by the LLM is automatically registered for future use
+
+        if qa_message['code'] is not None:
+            TaskProgramLibrary.register_task_program(creator="llm")(qa_message['code'])
         return qa_message
 
     def run_task_program(self, config: Config, task_program_name: str):
@@ -532,30 +539,30 @@ def run_task_program(self, config: Config, task_program_name: str):
         self.result_cache[task_program_name][config['video_info']['video_file_path']] = qa_message
         return qa_message
 
-    def step(self, user_query, number_of_debugs=1):
-        """
-        Currently not used. We tried to separate LLM inference and code execution.
-        """
-        qa_message = create_message(user_query, self)
+    # def step(self, user_query, number_of_debugs=1):
+    #     """
+    #     Currently not used. We tried to separate LLM inference and code execution.
+    #     """
+    #     qa_message = create_message(user_query, self)
 
-        if self.meta_info is not None:
-            qa_message["meta_info"] = self.meta_info
+    #     if self.meta_info is not None:
+    #         qa_message["meta_info"] = self.meta_info
 
-        self.messages.append(qa_message)
+    #     self.messages.append(qa_message)
 
-        self.query = user_query
-        self.llms["code_generator"].speak(self)
-        # all these llms collectively compose an amadeus_answer
-        qa_message = self.code_execution(qa_message)
+    #     self.query = user_query
+    #     self.llms["code_generator"].speak(self)
+    #     # all these llms collectively compose an amadeus_answer
+    #     qa_message = self.code_execution(qa_message)
 
-        if qa_message["error_message"] is not None:
-            for i in range(number_of_debugs):
-                self.llms["self_debug"].speak(self)
-                qa_message = self.code_execution(qa_message)
+    #     if qa_message["error_message"] is not None:
+    #         for i in range(number_of_debugs):
+    #             self.llms["self_debug"].speak(self)
+    #             qa_message = self.code_execution(qa_message)
 
-        qa_message = self.render_qa_message(qa_message)
-        self.result_cache[user_query][self.config['video_info']['video_file_path']] = qa_message
-        return qa_message
+    #     qa_message = self.render_qa_message(qa_message)
+    #     self.result_cache[user_query][self.config['video_info']['video_file_path']] = qa_message
+    #     return qa_message
 
 
 def save_figure_to_tempfile(fig):
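
The registration call in llm_step uses a decorator-factory shape: register_task_program(creator="llm") returns a callable that is then applied directly to the generated code. Below is a self-contained sketch of that pattern with a stand-in registry; the real TaskProgramLibrary in amadeusgpt.programs.task_program_registry differs in detail, and whether qa_message['code'] is a function object or source text is not visible in this diff (the sketch assumes a callable):

# Stand-in registry illustrating the decorator-factory shape; not the real class.
class TaskProgramLibrary:
    _programs = {}

    @classmethod
    def register_task_program(cls, creator="human"):
        def decorator(program):
            # record who authored the program so LLM- and human-written
            # entries can be told apart later
            cls._programs[program.__name__] = {"program": program, "creator": creator}
            return program
        return decorator

    @classmethod
    def get_task_programs(cls):
        return dict(cls._programs)

# Works as decorator syntax for human-written programs...
@TaskProgramLibrary.register_task_program(creator="human")
def count_approaches(analysis):
    ...

# ...and as a plain call on LLM-generated code, as in llm_step above:
def chases(analysis):
    ...

TaskProgramLibrary.register_task_program(creator="llm")(chases)

Because registration happens inside llm_step, every successful query permanently grows the library; the new get_task_programs() accessor in main.py is what exposes that growing collection to callers.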

notebooks/MABe_demo.ipynb

Lines changed: 330 additions & 110 deletions
Large diffs are not rendered by default.
