33import re
44import time
55import traceback
6- from amadeusgpt .utils import AmadeusLogger , search_generated_func
6+ from amadeusgpt .utils import AmadeusLogger
77from .base import AnalysisObject
88import openai
99from openai import OpenAI
10+ import base64
1011
1112class LLM (AnalysisObject ):
1213 total_tokens = 0
@@ -23,13 +24,11 @@ def __init__(self, config):
2324 self .context_window = []
2425 # only for logging and long-term memory usage.
2526 self .history = []
27+
2628
27- def whetehr_speak (self ):
28- """
29- Handcrafted rules to decide whether to speak
30- 1) If there is a error in the current chat channel
31- """
32- return False
29+ def encode_image (self , image_path ):
30+ with open (image_path , "rb" ) as image_file :
31+ return base64 .b64encode (image_file .read ()).decode ("utf-8" )
3332
3433 def speak (self ):
3534 """
@@ -41,7 +40,10 @@ def connect_gpt(self, messages, **kwargs):
4140 # if openai version is less than 1
4241 return self .connect_gpt_oai_1 (messages , ** kwargs )
4342
44- def connect_gpt_oai_1 (self , messages , ** kwargs ):
43+ def connect_gpt_oai_1 (self , messages , ** kwargs ):
44+ """
45+ This is routed to openai > 1.0 interfaces
46+ """
4547
4648 if self .config .get ("use_streamlit" , False ):
4749 if "OPENAI_API_KEY" in os .environ :
@@ -68,10 +70,6 @@ def connect_gpt_oai_1(self, messages, **kwargs):
6870 # the usage was recorded from the last run. However, since we have many LLMs that
6971 # share the call of this function, we will need to store usage and retrieve them from the database class
7072 num_retries = 3
71- print ('number of messages to send' , len (messages ))
72- print ('print the message' )
73- for message in messages :
74- print (message )
7573 for _ in range (num_retries ):
7674 try :
7775 json_data = {
@@ -113,7 +111,7 @@ def connect_gpt_oai_1(self, messages, **kwargs):
113111
114112 return response
115113
116- def update_history (self , role , content , replace = False ):
114+ def update_history (self , role , content , encoded_image = None , replace = False ):
117115 if role == "system" :
118116 if len (self .history ) > 0 :
119117 self .history [0 ]["content" ] = content
@@ -132,17 +130,27 @@ def update_history(self, role, content, replace=False):
132130 self .context_window .append ({"role" : role , "content" : content })
133131
134132 else :
135-
136- self .history .append ({"role" : role , "content" : content })
137-
138- num_AI_messages = (len (self .context_window ) - 1 ) // 2
139- if num_AI_messages == self .keep_last_n_messages :
140- print ("doing active forgetting" )
141- # we forget the oldest AI message and corresponding answer
142- self .context_window .pop (1 )
143- self .context_window .pop (1 )
144-
145- self .context_window .append ({"role" : role , "content" : content })
133+ if encoded_image is None :
134+ self .history .append ({"role" : role , "content" : content })
135+ num_AI_messages = (len (self .context_window ) - 1 ) // 2
136+ if num_AI_messages == self .keep_last_n_messages :
137+ print ("doing active forgetting" )
138+ # we forget the oldest AI message and corresponding answer
139+ self .context_window .pop (1 )
140+ self .context_window .pop (1 )
141+ self .context_window .append ({"role" : role , "content" : content })
142+ else :
143+ message = {
144+ "role" : "user" , "content" : [
145+ {"type" : "text" , "text" : content },
146+ {"type" : "image_url" , "image_url" : {
147+ "url" : f"data:image/png;base64,{ encoded_image } " }
148+ }]
149+ }
150+ self .context_window .append (message )
151+
152+
153+
146154
147155 def clean_context_window (self ):
148156 while len (self .context_window ) > 1 :
@@ -174,19 +182,34 @@ def parse_openai_response(cls, response):
174182 return text , function_code , thought_process
175183
176184
class VisualLLM(LLM):
    """LLM that comments on a single scene image.

    Per the original intent:
    1) what animal is present, how many, and which superanimal model to use
    2) report the background object list
    3) the answer is formatted in json
    """

    def __init__(self, config):
        super().__init__(config)

    def speak(self, sandbox):
        """Attach the scene image (plus the visual system prompt) to the
        conversation history as a multimodal user message.

        Parameters
        ----------
        sandbox : object
            Must expose ``exec_namespace["behavior_analysis"]`` whose
            ``visual_manager.get_scene_image()`` yields the scene image.
        """
        from amadeusgpt.system_prompts.visual import _get_system_prompt

        self.system_prompt = _get_system_prompt()
        analysis = sandbox.exec_namespace["behavior_analysis"]
        # NOTE(review): assumes get_scene_image() returns a file path that
        # encode_image() can open — confirm against VisualManager.
        scene_image = analysis.visual_manager.get_scene_image()
        encoded_image = self.encode_image(scene_image)
        # Bug fix: the base64 image must go through the `encoded_image` keyword
        # so update_history builds the multimodal image_url message. The old
        # call passed it as the positional `content` argument, which sent the
        # raw base64 blob as plain text and left the image branch unused.
        self.update_history("user", self.system_prompt, encoded_image=encoded_image)
203+
204+
177205class CodeGenerationLLM (LLM ):
178206 """
179207 Resource management for the behavior analysis part of the system
180208 """
181209
    def __init__(self, config):
        # Pure passthrough: all state (context window, history, config) is
        # initialized by the LLM base class.
        super().__init__(config)
184-
185- def whether_speak (self , sandbox ):
186- """
187- 1) if there is a error from last iteration, don't speak
188- """
189- return True
212+
190213
191214 def speak (self , sandbox ):
192215 """
@@ -244,17 +267,7 @@ def update_system_prompt(self, sandbox):
244267class MutationLLM (LLM ):
    def __init__(self, config):
        # Pure passthrough: all state (context window, history, config) is
        # initialized by the LLM base class.
        super().__init__(config)
247-
248- def whether_speak (self , chat_channel ):
249- """
250- 1) if there is a error from last iteration, don't speak
251- """
252-
253- error = chat_channel .get_last_message ().get ("error" , None )
254- if error is not None :
255- return False
256- else :
257- return True
270+
258271
259272 def update_system_prompt (self , sandbox ):
260273 from amadeusgpt .system_prompts .mutation import _get_system_prompt
@@ -281,17 +294,7 @@ def speak(self, sandbox):
281294class BreedLLM (LLM ):
    def __init__(self, config):
        # Pure passthrough: all state (context window, history, config) is
        # initialized by the LLM base class.
        super().__init__(config)
284-
285- def whether_speak (self , chat_channel ):
286- """
287- 1) if there is a error from last iteration, don't speak
288- """
289-
290- error = chat_channel .get_last_message ().get ("error" , None )
291- if error is not None :
292- return False
293- else :
294- return True
297+
295298
296299 def update_system_prompt (self , sandbox ):
297300 from amadeusgpt .system_prompts .breed import _get_system_prompt
@@ -326,18 +329,7 @@ class DiagnosisLLM(LLM):
326329 """
327330 Resource management for testing and error handling
328331 """
329-
330- def whether_speak (self , chat_channel ):
331- """
332- Handcrafted rules to decide whether to speak
333- 1) If there is a error in the current chat channel
334- """
335- if chat_channel .get_last_message () is None :
336- return False
337- else :
338- error = chat_channel .get_last_message ().get ("error" , None )
339-
340- return error is None
332+
341333
342334 @classmethod
343335 def get_system_prompt (
0 commit comments