Skip to content

Commit bf196b4

Browse files
committed
black and isort
1 parent 153382f commit bf196b4

File tree

16 files changed

+194
-143
lines changed

16 files changed

+194
-143
lines changed

amadeusgpt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
Apache-2.0 license
66
"""
77

8-
from amadeusgpt.integration_modules import *
98
from matplotlib import pyplot as plt
109

1110
from amadeusgpt.implementation import AnimalBehaviorAnalysis
11+
from amadeusgpt.integration_modules import *
1212
from amadeusgpt.main import AMADEUS
1313
from amadeusgpt.version import VERSION, __version__
1414

amadeusgpt/analysis_objects/event.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,13 @@ def event_negate(cls, events: List[BaseEvent]) -> List[BaseEvent]:
172172
mask |= event.generate_mask()
173173
negate_mask = ~mask
174174

175-
176-
177175
negate_events = Event.mask2events(
178-
negate_mask, video_file_path, sender_animal_name,
179-
set(), set(), smooth_window_size = 1
176+
negate_mask,
177+
video_file_path,
178+
sender_animal_name,
179+
set(),
180+
set(),
181+
smooth_window_size=1,
180182
)
181183

182184
return negate_events
@@ -540,7 +542,7 @@ def fuse_subgraph_by_kvs(
540542
For example, if there are two conditions to be met in the masks we look for locations that have overlap as 2
541543
"""
542544
# retrieve all events that satisfy the conditions (k=v)
543-
events = graph.traverse_by_kvs(merge_kvs)
545+
events = graph.traverse_by_kvs(merge_kvs)
544546
if not allow_more_than_2_overlap:
545547
assert (
546548
Event.check_max_in_sum(events) <= number_of_overlap_for_fusion

amadeusgpt/analysis_objects/llm.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
import base64
2+
import io
13
import json
24
import os
35
import re
46
import time
57
import traceback
6-
from amadeusgpt.utils import AmadeusLogger
7-
from .base import AnalysisObject
8+
9+
import cv2
810
import openai
911
from openai import OpenAI
10-
import base64
11-
import cv2
12-
import io
12+
13+
from amadeusgpt.utils import AmadeusLogger
14+
15+
from .base import AnalysisObject
16+
1317

1418
class LLM(AnalysisObject):
1519
total_tokens = 0
@@ -26,7 +30,6 @@ def __init__(self, config):
2630
self.context_window = []
2731
# only for logging and long-term memory usage.
2832
self.history = []
29-
3033

3134
def encode_image(self, image_path):
3235
with open(image_path, "rb") as image_file:
@@ -42,10 +45,10 @@ def connect_gpt(self, messages, **kwargs):
4245
# if openai version is less than 1
4346
return self.connect_gpt_oai_1(messages, **kwargs)
4447

45-
def connect_gpt_oai_1(self, messages, **kwargs):
48+
def connect_gpt_oai_1(self, messages, **kwargs):
4649
"""
4750
This is routed to openai > 1.0 interfaces
48-
"""
51+
"""
4952

5053
if self.config.get("use_streamlit", False):
5154
if "OPENAI_API_KEY" in os.environ:
@@ -82,10 +85,15 @@ def connect_gpt_oai_1(self, messages, **kwargs):
8285
}
8386
response = client.chat.completions.create(**json_data)
8487

85-
LLM.total_tokens = LLM.total_tokens + response.usage.prompt_tokens + response.usage.completion_tokens
88+
LLM.total_tokens = (
89+
LLM.total_tokens
90+
+ response.usage.prompt_tokens
91+
+ response.usage.completion_tokens
92+
)
8693
LLM.total_cost += (
8794
LLM.prices[self.gpt_model]["input"] * response.usage.prompt_tokens
88-
+ LLM.prices[self.gpt_model]["output"] * response.usage.completion_tokens
95+
+ LLM.prices[self.gpt_model]["output"]
96+
* response.usage.completion_tokens
8997
)
9098
print("current total cost", round(LLM.total_cost, 2), "$")
9199
print("current total tokens", LLM.total_tokens)
@@ -110,7 +118,7 @@ def connect_gpt_oai_1(self, messages, **kwargs):
110118

111119
return response
112120

113-
def update_history(self, role, content, encoded_image = None, replace=False):
121+
def update_history(self, role, content, encoded_image=None, replace=False):
114122
if role == "system":
115123
if len(self.history) > 0:
116124
self.history[0]["content"] = content
@@ -124,7 +132,7 @@ def update_history(self, role, content, encoded_image = None, replace=False):
124132
self.history.append({"role": role, "content": content})
125133
num_AI_messages = (len(self.context_window) - 1) // 2
126134
if num_AI_messages == self.keep_last_n_messages:
127-
print ("doing active forgetting")
135+
print("doing active forgetting")
128136
# we forget the oldest AI message and corresponding answer
129137
self.context_window.pop(1)
130138
self.context_window.pop(1)
@@ -134,23 +142,25 @@ def update_history(self, role, content, encoded_image = None, replace=False):
134142
self.history.append({"role": role, "content": content})
135143
num_AI_messages = (len(self.context_window) - 1) // 2
136144
if num_AI_messages == self.keep_last_n_messages:
137-
print ("doing active forgetting")
145+
print("doing active forgetting")
138146
# we forget the oldest AI message and corresponding answer
139147
self.context_window.pop(1)
140148
self.context_window.pop(1)
141149
self.context_window.append({"role": role, "content": content})
142150
else:
143151
message = {
144-
"role": "user", "content": [
145-
{"type": "text", "text": content},
146-
{"type": "image_url", "image_url": {
147-
"url": f"data:image/png;base64,{encoded_image}"}
148-
}]
149-
}
150-
self.context_window.append(message)
151-
152-
153-
152+
"role": "user",
153+
"content": [
154+
{"type": "text", "text": content},
155+
{
156+
"type": "image_url",
157+
"image_url": {
158+
"url": f"data:image/png;base64,{encoded_image}"
159+
},
160+
},
161+
],
162+
}
163+
self.context_window.append(message)
154164

155165
def clean_context_window(self):
156166
while len(self.context_window) > 1:
@@ -185,6 +195,7 @@ def parse_openai_response(cls, response):
185195
class VisualLLM(LLM):
186196
def __init__(self, config):
187197
super().__init__(config)
198+
188199
def speak(self, sandbox):
189200
"""
190201
Only to comment about one image
@@ -195,25 +206,29 @@ def speak(self, sandbox):
195206
"""
196207

197208
from amadeusgpt.system_prompts.visual_llm import _get_system_prompt
209+
198210
self.system_prompt = _get_system_prompt()
199211
analysis = sandbox.exec_namespace["behavior_analysis"]
200212
scene_image = analysis.visual_manager.get_scene_image()
201-
result, buffer = cv2.imencode('.jpeg', scene_image)
213+
result, buffer = cv2.imencode(".jpeg", scene_image)
202214
image_bytes = io.BytesIO(buffer)
203-
base64_image = base64.b64encode(image_bytes.getvalue()).decode('utf-8')
215+
base64_image = base64.b64encode(image_bytes.getvalue()).decode("utf-8")
204216

205217
self.update_history("system", self.system_prompt)
206-
self.update_history("user", "here is the image", encoded_image = base64_image, replace = True)
207-
response = self.connect_gpt(self.context_window, max_tokens=2000)
218+
self.update_history(
219+
"user", "here is the image", encoded_image=base64_image, replace=True
220+
)
221+
response = self.connect_gpt(self.context_window, max_tokens=2000)
208222
text = response.choices[0].message.content.strip()
209-
print (text)
223+
print(text)
210224
pattern = r"```json(.*?)```"
211225
if len(re.findall(pattern, text, re.DOTALL)) == 0:
212226
raise ValueError("can't parse the json string correctly", text)
213227
else:
214228
json_string = re.findall(pattern, text, re.DOTALL)[0]
215229
json_obj = json.loads(json_string)
216-
return json_obj
230+
return json_obj
231+
217232

218233
class CodeGenerationLLM(LLM):
219234
"""
@@ -222,7 +237,6 @@ class CodeGenerationLLM(LLM):
222237

223238
def __init__(self, config):
224239
super().__init__(config)
225-
226240

227241
def speak(self, sandbox):
228242
"""
@@ -265,10 +279,8 @@ def update_system_prompt(self, sandbox):
265279
task_program_docs = sandbox.get_task_program_docs()
266280
query_block = sandbox.get_query_block()
267281

268-
behavior_analysis = sandbox.exec_namespace[
269-
"behavior_analysis"
270-
]
271-
282+
behavior_analysis = sandbox.exec_namespace["behavior_analysis"]
283+
272284
self.system_prompt = _get_system_prompt(
273285
query_block, core_api_docs, task_program_docs, behavior_analysis
274286
)
@@ -280,7 +292,6 @@ def update_system_prompt(self, sandbox):
280292
class MutationLLM(LLM):
281293
def __init__(self, config):
282294
super().__init__(config)
283-
284295

285296
def update_system_prompt(self, sandbox):
286297
from amadeusgpt.system_prompts.mutation import _get_system_prompt
@@ -307,7 +318,6 @@ def speak(self, sandbox):
307318
class BreedLLM(LLM):
308319
def __init__(self, config):
309320
super().__init__(config)
310-
311321

312322
def update_system_prompt(self, sandbox):
313323
from amadeusgpt.system_prompts.breed import _get_system_prompt
@@ -342,7 +352,6 @@ class DiagnosisLLM(LLM):
342352
"""
343353
Resource management for testing and error handling
344354
"""
345-
346355

347356
@classmethod
348357
def get_system_prompt(
@@ -410,8 +419,9 @@ def speak(self, sandbox):
410419

411420

412421
if __name__ == "__main__":
413-
from amadeusgpt.config import Config
422+
from amadeusgpt.config import Config
414423
from amadeusgpt.main import create_amadeus
424+
415425
config = Config("amadeusgpt/configs/EPM_template.yaml")
416426

417427
amadeus = create_amadeus(config)

amadeusgpt/analysis_objects/object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def set_body_orientation_keypoints(
298298

299299
def set_head_orientation_keypoints(
300300
self, head_orientation_keypoints: Dict[str, Any]
301-
):
301+
):
302302
self.nose_name = head_orientation_keypoints["nose"]
303303
self.neck_name = head_orientation_keypoints["neck"]
304304
self.support_head_orientation = True

amadeusgpt/analysis_objects/relationship.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def calc_angle_in_egocentric_animal(mouse_cs_inv, p):
7171
p_in_mouse[:, 1], p_in_mouse[:, 0]
7272
) # relative angle between the object and the mouse body axis
7373
theta = np.rad2deg(theta % (2 * np.pi))
74-
74+
7575
return theta
7676

7777

@@ -253,9 +253,9 @@ def _animal_animal_relationship(
253253
sender_pos = sender_animal.get_center()
254254
receiver_pos = receiver_animal.get_center()
255255
direction_vector = receiver_pos - sender_pos
256-
sender_velocity = np.nanmean(sender_animal.get_velocity(), axis = 1)
256+
sender_velocity = np.nanmean(sender_animal.get_velocity(), axis=1)
257257
norm_direction_vector = direction_vector / np.linalg.norm(direction_vector)
258-
relative_speed = np.einsum('ij,ij->i', sender_velocity, norm_direction_vector)
258+
relative_speed = np.einsum("ij,ij->i", sender_velocity, norm_direction_vector)
259259

260260
closest_distance = np.nanmin(
261261
get_pairwise_distance(sender_animal.keypoints, receiver_animal.keypoints),
@@ -274,6 +274,6 @@ def _animal_animal_relationship(
274274
if angles is not None:
275275
ret["relative_angle"] = angles
276276
if orientation is not None:
277-
ret["orientation"] = orientation
277+
ret["orientation"] = orientation
278278

279279
return ret

amadeusgpt/app.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def main():
2323
st.session_state["exist_valid_openai_api_key"] = True
2424
else:
2525
st.session_state["exist_valid_openai_api_key"] = False
26-
2726

2827
example_to_page = {}
2928

amadeusgpt/app_utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(self, amadeus_answer=None, json_entry=None):
112112
amadeus_answer = amadeus_answer.to_dict()
113113
self.data.update(amadeus_answer)
114114

115-
def render(self, debug = False):
115+
def render(self, debug=False):
116116
"""
117117
We use the getter for better encapsulation
118118
overall structure of what to be rendered
@@ -181,9 +181,9 @@ def render(self, debug = False):
181181
# Remind users we are fixing the error by self debugging
182182
st.markdown(f"Let me try to fix the error by self-debugging\n ")
183183
if not debug:
184-
sandbox.llms["self_debug"].speak(sandbox)
184+
sandbox.llms["self_debug"].speak(sandbox)
185185
qa_message = sandbox.code_execution(qa_message)
186-
self.render(debug = True)
186+
self.render(debug=True)
187187
# do not need to execute the block one more time
188188
if not self.rendered:
189189
self.rendered = True
@@ -558,7 +558,9 @@ def render_page_by_example(example):
558558
if uploaded_keypoint_file is not None:
559559
path = save_uploaded_file(uploaded_keypoint_file, save_dir)
560560
st.session_state["uploaded_keypoint_file"] = path
561-
config["keypoint_info"]["keypoint_file_path"] = st.session_state["uploaded_keypoint_file"]
561+
config["keypoint_info"]["keypoint_file_path"] = st.session_state[
562+
"uploaded_keypoint_file"
563+
]
562564

563565
if "uploaded_video_file" not in st.session_state:
564566
uploaded_video_file = st.file_uploader(
@@ -569,7 +571,9 @@ def render_page_by_example(example):
569571
if uploaded_video_file is not None:
570572
path = save_uploaded_file(uploaded_video_file, save_dir)
571573
st.session_state["uploaded_video_file"] = uploaded_video_file
572-
config["video_info"]["video_file_path"] = st.session_state["uploaded_video_file"]
574+
config["video_info"]["video_file_path"] = st.session_state[
575+
"uploaded_video_file"
576+
]
573577

574578
###### USER INPUT PANEL ######
575579
# get user input once getting the uploaded files
@@ -675,12 +679,15 @@ def render_page_by_example(example):
675679
st.caption("Raw video from Horse-30")
676680
else:
677681
st.caption("DeepLabCut-SuperAnimal tracked video")
678-
if config["video_info"]["video_file_path"] and config["video_info"]["video_file_path"] is not None:
682+
if (
683+
config["video_info"]["video_file_path"]
684+
and config["video_info"]["video_file_path"] is not None
685+
):
679686
st.video(config["video_info"]["video_file_path"])
680687

681688
if "uploaded_video_file" in st.session_state:
682689
st.video(st.session_state["uploaded_video_file"])
683-
690+
684691
# we only show objects for MausHaus for demo
685692
# if sam_image is not None:
686693
# st.caption("SAM segmentation results")

amadeusgpt/main.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, config: Dict[str, Any]):
5656

5757
# can only do this after the register process
5858
self.sandbox.configure_using_vlm()
59-
59+
6060
def match_integration_module(self, user_query: str):
6161
"""
6262
Return a list of matched integration modules
@@ -87,17 +87,18 @@ def step(self, user_query):
8787

8888
def get_analysis(self):
8989
sandbox = self.sandbox
90-
analysis = sandbox.exec_namespace['behavior_analysis']
90+
analysis = sandbox.exec_namespace["behavior_analysis"]
9191
return analysis
92-
92+
9393
def run_task_program(self, task_program_name: str):
9494
return self.sandbox.run_task_program(task_program_name)
9595

9696

9797
if __name__ == "__main__":
98-
from amadeusgpt.config import Config
99-
from amadeusgpt.main import create_amadeus
10098
from amadeusgpt.analysis_objects.llm import VisualLLM
99+
from amadeusgpt.config import Config
100+
from amadeusgpt.main import create_amadeus
101+
101102
config = Config("amadeusgpt/configs/EPM_template.yaml")
102103

103104
amadeus = create_amadeus(config)

0 commit comments

Comments
 (0)