1+ import base64
2+ import io
13import json
24import os
35import re
46import time
57import traceback
6- from amadeusgpt . utils import AmadeusLogger
7- from . base import AnalysisObject
8+
9+ import cv2
810import openai
911from openai import OpenAI
10- import base64
11- import cv2
12- import io
12+
13+ from amadeusgpt .utils import AmadeusLogger
14+
15+ from .base import AnalysisObject
16+
1317
1418class LLM (AnalysisObject ):
1519 total_tokens = 0
@@ -26,7 +30,6 @@ def __init__(self, config):
2630 self .context_window = []
2731 # only for logging and long-term memory usage.
2832 self .history = []
29-
3033
3134 def encode_image (self , image_path ):
3235 with open (image_path , "rb" ) as image_file :
@@ -42,10 +45,10 @@ def connect_gpt(self, messages, **kwargs):
4245 # if openai version is less than 1
4346 return self .connect_gpt_oai_1 (messages , ** kwargs )
4447
45- def connect_gpt_oai_1 (self , messages , ** kwargs ):
48+ def connect_gpt_oai_1 (self , messages , ** kwargs ):
4649 """
4750 This is routed to openai > 1.0 interfaces
48- """
51+ """
4952
5053 if self .config .get ("use_streamlit" , False ):
5154 if "OPENAI_API_KEY" in os .environ :
@@ -82,10 +85,15 @@ def connect_gpt_oai_1(self, messages, **kwargs):
8285 }
8386 response = client .chat .completions .create (** json_data )
8487
85- LLM .total_tokens = LLM .total_tokens + response .usage .prompt_tokens + response .usage .completion_tokens
88+ LLM .total_tokens = (
89+ LLM .total_tokens
90+ + response .usage .prompt_tokens
91+ + response .usage .completion_tokens
92+ )
8693 LLM .total_cost += (
8794 LLM .prices [self .gpt_model ]["input" ] * response .usage .prompt_tokens
88- + LLM .prices [self .gpt_model ]["output" ] * response .usage .completion_tokens
95+ + LLM .prices [self .gpt_model ]["output" ]
96+ * response .usage .completion_tokens
8997 )
9098 print ("current total cost" , round (LLM .total_cost , 2 ), "$" )
9199 print ("current total tokens" , LLM .total_tokens )
@@ -110,7 +118,7 @@ def connect_gpt_oai_1(self, messages, **kwargs):
110118
111119 return response
112120
113- def update_history (self , role , content , encoded_image = None , replace = False ):
121+ def update_history (self , role , content , encoded_image = None , replace = False ):
114122 if role == "system" :
115123 if len (self .history ) > 0 :
116124 self .history [0 ]["content" ] = content
@@ -124,7 +132,7 @@ def update_history(self, role, content, encoded_image = None, replace=False):
124132 self .history .append ({"role" : role , "content" : content })
125133 num_AI_messages = (len (self .context_window ) - 1 ) // 2
126134 if num_AI_messages == self .keep_last_n_messages :
127- print ("doing active forgetting" )
135+ print ("doing active forgetting" )
128136 # we forget the oldest AI message and corresponding answer
129137 self .context_window .pop (1 )
130138 self .context_window .pop (1 )
@@ -134,23 +142,25 @@ def update_history(self, role, content, encoded_image = None, replace=False):
134142 self .history .append ({"role" : role , "content" : content })
135143 num_AI_messages = (len (self .context_window ) - 1 ) // 2
136144 if num_AI_messages == self .keep_last_n_messages :
137- print ("doing active forgetting" )
145+ print ("doing active forgetting" )
138146 # we forget the oldest AI message and corresponding answer
139147 self .context_window .pop (1 )
140148 self .context_window .pop (1 )
141149 self .context_window .append ({"role" : role , "content" : content })
142150 else :
143151 message = {
144- "role" : "user" , "content" : [
145- {"type" : "text" , "text" : content },
146- {"type" : "image_url" , "image_url" : {
147- "url" : f"data:image/png;base64,{ encoded_image } " }
148- }]
149- }
150- self .context_window .append (message )
151-
152-
153-
152+ "role" : "user" ,
153+ "content" : [
154+ {"type" : "text" , "text" : content },
155+ {
156+ "type" : "image_url" ,
157+ "image_url" : {
158+ "url" : f"data:image/png;base64,{ encoded_image } "
159+ },
160+ },
161+ ],
162+ }
163+ self .context_window .append (message )
154164
155165 def clean_context_window (self ):
156166 while len (self .context_window ) > 1 :
@@ -185,6 +195,7 @@ def parse_openai_response(cls, response):
185195class VisualLLM (LLM ):
186196 def __init__ (self , config ):
187197 super ().__init__ (config )
198+
188199 def speak (self , sandbox ):
189200 """
190201 Only to comment about one image
@@ -195,25 +206,29 @@ def speak(self, sandbox):
195206 """
196207
197208 from amadeusgpt .system_prompts .visual_llm import _get_system_prompt
209+
198210 self .system_prompt = _get_system_prompt ()
199211 analysis = sandbox .exec_namespace ["behavior_analysis" ]
200212 scene_image = analysis .visual_manager .get_scene_image ()
201- result , buffer = cv2 .imencode (' .jpeg' , scene_image )
213+ result , buffer = cv2 .imencode (" .jpeg" , scene_image )
202214 image_bytes = io .BytesIO (buffer )
203- base64_image = base64 .b64encode (image_bytes .getvalue ()).decode (' utf-8' )
215+ base64_image = base64 .b64encode (image_bytes .getvalue ()).decode (" utf-8" )
204216
205217 self .update_history ("system" , self .system_prompt )
206- self .update_history ("user" , "here is the image" , encoded_image = base64_image , replace = True )
207- response = self .connect_gpt (self .context_window , max_tokens = 2000 )
218+ self .update_history (
219+ "user" , "here is the image" , encoded_image = base64_image , replace = True
220+ )
221+ response = self .connect_gpt (self .context_window , max_tokens = 2000 )
208222 text = response .choices [0 ].message .content .strip ()
209- print (text )
223+ print (text )
210224 pattern = r"```json(.*?)```"
211225 if len (re .findall (pattern , text , re .DOTALL )) == 0 :
212226 raise ValueError ("can't parse the json string correctly" , text )
213227 else :
214228 json_string = re .findall (pattern , text , re .DOTALL )[0 ]
215229 json_obj = json .loads (json_string )
216- return json_obj
230+ return json_obj
231+
217232
218233class CodeGenerationLLM (LLM ):
219234 """
@@ -222,7 +237,6 @@ class CodeGenerationLLM(LLM):
222237
223238 def __init__ (self , config ):
224239 super ().__init__ (config )
225-
226240
227241 def speak (self , sandbox ):
228242 """
@@ -265,10 +279,8 @@ def update_system_prompt(self, sandbox):
265279 task_program_docs = sandbox .get_task_program_docs ()
266280 query_block = sandbox .get_query_block ()
267281
268- behavior_analysis = sandbox .exec_namespace [
269- "behavior_analysis"
270- ]
271-
282+ behavior_analysis = sandbox .exec_namespace ["behavior_analysis" ]
283+
272284 self .system_prompt = _get_system_prompt (
273285 query_block , core_api_docs , task_program_docs , behavior_analysis
274286 )
@@ -280,7 +292,6 @@ def update_system_prompt(self, sandbox):
280292class MutationLLM (LLM ):
281293 def __init__ (self , config ):
282294 super ().__init__ (config )
283-
284295
285296 def update_system_prompt (self , sandbox ):
286297 from amadeusgpt .system_prompts .mutation import _get_system_prompt
@@ -307,7 +318,6 @@ def speak(self, sandbox):
307318class BreedLLM (LLM ):
308319 def __init__ (self , config ):
309320 super ().__init__ (config )
310-
311321
312322 def update_system_prompt (self , sandbox ):
313323 from amadeusgpt .system_prompts .breed import _get_system_prompt
@@ -342,7 +352,6 @@ class DiagnosisLLM(LLM):
342352 """
343353 Resource management for testing and error handling
344354 """
345-
346355
347356 @classmethod
348357 def get_system_prompt (
@@ -410,8 +419,9 @@ def speak(self, sandbox):
410419
411420
412421if __name__ == "__main__" :
413- from amadeusgpt .config import Config
422+ from amadeusgpt .config import Config
414423 from amadeusgpt .main import create_amadeus
424+
415425 config = Config ("amadeusgpt/configs/EPM_template.yaml" )
416426
417427 amadeus = create_amadeus (config )
0 commit comments