 import openai
 from openai import OpenAI
 import base64
+import cv2

 class LLM(AnalysisObject):
     total_tokens = 0
@@ -76,11 +78,8 @@ def connect_gpt_oai_1(self, messages, **kwargs):
             "model": self.gpt_model,
             "messages": messages,
             "max_tokens": self.max_tokens,
-            "stop": None,
-            "top_p": 1,
             "temperature": 0.0,
         }
-
         response = client.chat.completions.create(**json_data)

         LLM.total_tokens = LLM.total_tokens + response.usage.prompt_tokens + response.usage.completion_tokens
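Stripped of the class plumbing, the request-and-accounting pattern above amounts to the following minimal sketch (assumes `openai>=1.0` with an `OPENAI_API_KEY` in the environment; the model name and prompt are placeholders):

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

total_tokens = 0  # running counter, mirroring LLM.total_tokens

def ask(messages, model="gpt-4o", max_tokens=256):
    """One chat-completions call with the same token accounting as above."""
    global total_tokens
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0.0,  # deterministic-leaning decoding for analysis code
    )
    # usage reports prompt and completion tokens separately
    total_tokens += response.usage.prompt_tokens + response.usage.completion_tokens
    return response.choices[0].message.content

print(ask([{"role": "user", "content": "Say hello."}]))
```

Dropping `stop=None` and `top_p=1` is harmless here: both match the API defaults, and `temperature=0.0` already makes decoding as deterministic as the API allows.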
@@ -121,36 +120,33 @@ def update_history(self, role, content, encoded_image = None, replace=False):
             self.context_window.append({"role": role, "content": content})
         else:

+            if encoded_image is None:
+                num_AI_messages = (len(self.context_window) - 1) // 2
+                if num_AI_messages == self.keep_last_n_messages:
+                    print("doing active forgetting")
+                    # we forget the oldest AI message and corresponding answer
+                    self.context_window.pop(1)
+                    self.context_window.pop(1)
+                new_message = {"role": role, "content": content}
+            else:
+                new_message = {"role": "user", "content": [
+                    {"type": "text", "text": content},
+                    {"type": "image_url", "image_url": {
+                        "url": f"data:image/jpeg;base64,{encoded_image}"}
+                    }
+                ]}
+
+            self.history.append(new_message)
+
             if replace == True:
-                if len(self.history) == 2:
-                    self.history[1]["content"] = content
-                    self.context_window[1]["content"] = content
+                if len(self.context_window) == 2:
+                    self.context_window[1] = new_message
+                else:
+                    self.context_window.append(new_message)
             else:
-                self.history.append({"role": role, "content": content})
-                self.context_window.append({"role": role, "content": content})
+                self.context_window.append(new_message)
+

-            else:
-                if encoded_image is None:
-                    self.history.append({"role": role, "content": content})
-                    num_AI_messages = (len(self.context_window) - 1) // 2
-                    if num_AI_messages == self.keep_last_n_messages:
-                        print ("doing active forgetting")
-                        # we forget the oldest AI message and corresponding answer
-                        self.context_window.pop(1)
-                        self.context_window.pop(1)
-                    self.context_window.append({"role": role, "content": content})
-                else:
-                    message = {
-                        "role": "user", "content": [
-                        {"type": "text", "text": content},
-                        {"type": "image_url", "image_url": {
-                            "url": f"data:image/png;base64,{encoded_image}"}
-                        }]
-                    }
-                    self.context_window.append(message)
-
-
-

     def clean_context_window(self):
         while len(self.context_window) > 1:
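The added branch is the "active forgetting" policy: index 0 of the context window always holds the system prompt, and once `keep_last_n_messages` user/assistant exchanges have accumulated, the oldest pair is popped before the new message is appended. A standalone sketch of just that policy (plain lists and illustrative names, not the class's actual API):

```python
# Sliding context window: slot 0 is the system prompt; every exchange
# appends one user message and one assistant reply.
keep_last_n_messages = 2
context_window = [{"role": "system", "content": "You are a vision analyst."}]

def remember(role, content):
    num_ai_messages = (len(context_window) - 1) // 2
    if num_ai_messages == keep_last_n_messages:
        # Forget the oldest user/assistant pair, never the system prompt.
        context_window.pop(1)
        context_window.pop(1)
    context_window.append({"role": role, "content": content})

for i in range(4):
    remember("user", f"question {i}")
    remember("assistant", f"answer {i}")

# Only the system prompt and the two most recent exchanges survive.
print([m["content"] for m in context_window])
# ['You are a vision analyst.', 'question 2', 'answer 2', 'question 3', 'answer 3']
```

Note that `self.history` keeps every message untrimmed; only `self.context_window` is bounded, so the full transcript stays available for logging.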
@@ -194,13 +189,27 @@ def speak(self, sandbox):

         """

-        from amadeusgpt.system_prompts.visual import _get_system_prompt
+        from amadeusgpt.system_prompts.visual_llm import _get_system_prompt
         self.system_prompt = _get_system_prompt()
         analysis = sandbox.exec_namespace["behavior_analysis"]
         scene_image = analysis.visual_manager.get_scene_image()
-        encoded_image = self.encode_image(scene_image)
-        self.update_history("user", encoded_image)

+        success, buffer = cv2.imencode(".jpeg", scene_image)
+        if not success:
+            raise ValueError("failed to encode the scene image as JPEG")
+        base64_image = base64.b64encode(buffer).decode("utf-8")
+        self.update_history("system", self.system_prompt)
+        self.update_history("user", "here is the image", encoded_image=base64_image, replace=True)
+        response = self.connect_gpt(self.context_window, max_tokens=2000)
+        text = response.choices[0].message.content.strip()
+        print(text)
+        pattern = r"```json(.*?)```"
+        matches = re.findall(pattern, text, re.DOTALL)
+        if len(matches) == 0:
+            raise ValueError("can't parse the json string correctly", text)
+        json_string = matches[0]
+        json_obj = json.loads(json_string)
+        return json_obj

 class CodeGenerationLLM(LLM):
     """
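`speak` now chains three steps: JPEG-encode the scene with OpenCV, attach it as a base64 `image_url` content part, and extract the fenced JSON block from the model's reply. The same flow as a self-contained sketch (`gpt-4o` is a stand-in for any vision-capable chat model, and `scene.png` is a hypothetical input file):

```python
import base64
import json
import re

import cv2
from openai import OpenAI

def describe_scene(image_path: str) -> dict:
    """JPEG-encode an image, send it to a vision model, and parse the
    fenced JSON block out of the reply (mirrors VisualLLM.speak)."""
    image = cv2.imread(image_path)
    success, buffer = cv2.imencode(".jpeg", image)
    if not success:
        raise ValueError("JPEG encoding failed")
    b64 = base64.b64encode(buffer).decode("utf-8")

    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o",  # stand-in; any vision-capable chat model works
        max_tokens=2000,
        messages=[
            {"role": "system",
             "content": "Describe the scene; answer inside a ```json block."},
            {"role": "user", "content": [
                {"type": "text", "text": "here is the image"},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            ]},
        ],
    )
    text = response.choices[0].message.content
    matches = re.findall(r"```json(.*?)```", text, re.DOTALL)
    if not matches:
        raise ValueError("no fenced JSON block in the reply", text)
    return json.loads(matches[0])

print(describe_scene("scene.png"))
```

Raising on a missing fence, rather than guessing at the reply shape, keeps the failure loud; this matches how the diff handles an unparseable answer.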
@@ -394,3 +402,14 @@ def speak(self, sandbox):
         function_code = re.findall(pattern, text, re.DOTALL)[0]
         qa_message["code"] = function_code
         qa_message["chain_of_thought"] = thought_process
+
+
+if __name__ == "__main__":
+    from amadeusgpt.config import Config
+    from amadeusgpt.main import create_amadeus
+    config = Config("amadeusgpt/configs/EPM_template.yaml")
+
+    amadeus = create_amadeus(config)
+    sandbox = amadeus.sandbox
+    visual_llm = VisualLLM(config)
+    visual_llm.speak(sandbox)