Skip to content

Commit d6eeba6

Browse files
committed
some clean up and some docs
1 parent d5c1d70 commit d6eeba6

File tree

5 files changed

+31
-188
lines changed

5 files changed

+31
-188
lines changed

amadeusgpt/analysis_objects/analysis_factory.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
analysis_fac = {}
44

5-
65
def create_analysis(config):
76
if str(config) not in analysis_fac:
87
analysis_fac[str(config)] = AnimalBehaviorAnalysis(config)

amadeusgpt/main.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,17 @@ def step(self, user_query):
7777
return result
7878

7979
def get_analysis(self):
80-
sandbox = self.sandbox
81-
analysis = sandbox.exec_namespace["behavior_analysis"]
80+
"""
81+
Every sandbox stores a unique "behavior analysis" instance in its namespace
82+
Therefore, get analysis gets the current sandbox's analysis.
83+
"""
84+
analysis = self.sandbox.exec_namespace["behavior_analysis"]
8285
return analysis
8386

8487
def run_task_program(self, task_program_name: str):
88+
"""
89+
Execute the task program on the currently holding sandbox
90+
"""
8591
return self.sandbox.run_task_program(task_program_name)
8692

8793
def save_results(self, out_folder: str| None = None):

amadeusgpt/managers/visual_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,14 +678,15 @@ def write_video(self, out_folder, video_file_path, out_name, events):
678678
return out_videos
679679

680680
def generate_video_clips_from_events(
681-
self, out_folder, video_file, events: List[BaseEvent], behavior_name
681+
self, out_folder, events: List[BaseEvent], behavior_name
682682
):
683683
"""
684684
This function takes a list of events and generates video clips from the events
685685
1) For the same events, we first group events based on the video
686686
2) For the same event on the same video, we plot the animal name and the "sender" of the event
687687
3) Then we write those videos to the disk
688688
"""
689+
video_file = self.config["video_info"]["video_file_path"]
689690

690691
videoname = video_file.split("/")[-1].replace(".mp4", "").replace(".avi", "")
691692
video_name = f"{videoname}_{behavior_name}_video.mp4"

amadeusgpt/programs/sandbox.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -305,15 +305,6 @@ def update_config(self, config):
305305
def copy(self):
306306
return Sandbox(self.config, self.api_registry)
307307

308-
def visual_validate(self, video_file, events, behavior_name):
309-
# change video and keypoint file
310-
analysis = create_analysis(self.config)
311-
out_folder = os.path.join(self.config["evo_info"]["data_folder"], "inspection")
312-
discovered_behaviors = []
313-
for name, task_program in self.task_program_library.items():
314-
if task_program["creator"] != "human":
315-
discovered_behaviors.append(name)
316-
317308
def update_matched_integration_modules(self, matched_modules):
318309
self.matched_modules = matched_modules
319310

@@ -358,7 +349,8 @@ def update_namespace(self):
358349
self.exec_namespace["task_programs"] = TaskProgramLibrary.get_task_programs()
359350

360351
def code_execution(self, qa_message):
361-
# add main function into the namespace
352+
# update the namespace in the beginning of code execution makes sure that
353+
# if there is a change in the config, we always use the newest config
362354
self.update_namespace()
363355
code = qa_message["code"]
364356
# not need to do further if there was no code found
@@ -424,12 +416,16 @@ def events_to_videos(self, events, function_name):
424416
out_folder = str(self.result_folder)
425417
os.makedirs(out_folder, exist_ok=True)
426418
behavior_name = "_".join(function_name.split(" "))
427-
video_file = self.config["video_info"]["video_file_path"]
428419
return visual_manager.generate_video_clips_from_events(
429-
out_folder, video_file, events, behavior_name
420+
out_folder, events, behavior_name
430421
)
431422

432423
def render_qa_message(self, qa_message):
424+
"""
425+
To be called after code execution.
426+
If the function returns a list of events, we visualize those events to keypoint plot, ethogram plot and videos
427+
if the function returns is a tuple of axe and figure, we put them into the plots filed
428+
"""
433429
function_rets = qa_message["function_rets"]
434430
behavior_analysis = self.exec_namespace["behavior_analysis"]
435431
bodypart_names = behavior_analysis.animal_manager.get_keypoint_names()
@@ -487,6 +483,11 @@ def render_qa_message(self, qa_message):
487483
return qa_message
488484

489485
def llm_step(self, user_query):
486+
"""
487+
1) We first use gpt-4o to create meta_info describing the scene
488+
2) We then ask LLM to generate code based on the query
489+
3) We also cache the qa_message for future reference
490+
"""
490491
qa_message = create_message(user_query, self)
491492

492493
# so that the frontend can display it too
@@ -502,7 +503,8 @@ def llm_step(self, user_query):
502503

503504
def run_task_program(self, task_program_name):
504505
"""
505-
Sandbox is also responsible for running task program
506+
1) sandbox is also responsible for running task program
507+
2) self.task_program_library references to a singleton so a different sandbox still has reference to the task program
506508
"""
507509
task_program = self.task_program_library[task_program_name]
508510
# there might be better way to set this
@@ -542,6 +544,9 @@ def step(self, user_query, number_of_debugs=1):
542544

543545

544546
def save_figure_to_tempfile(fig):
547+
"""
548+
Only used for debug
549+
"""
545550
import tempfile
546551

547552
# save the figure
@@ -566,6 +571,9 @@ def save_figure_to_tempfile(fig):
566571

567572

568573
def render_temp_message(query, sandbox):
574+
"""
575+
Only used for debug
576+
"""
569577
import streamlit as st
570578

571579
qa_message = create_message("random query", sandbox)

amadeusgpt/utils.py

Lines changed: 0 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
import ast
2-
import copy
32
import inspect
4-
import re
53
import sys
64
import time
75
import traceback
86
from itertools import groupby
97
from operator import itemgetter
10-
from pydoc import doc
118
from typing import Any, Dict, Sequence
12-
139
import cv2
14-
import matplotlib.pyplot as plt
1510
import numpy as np
16-
from numpy import ndarray
1711
from scipy.ndimage.filters import uniform_filter1d
18-
1912
from amadeusgpt.logger import AmadeusLogger
2013

2114

@@ -63,23 +56,6 @@ def moving_average(x: Sequence, window_size: int, pos: str = "centered"):
6356
return uniform_filter1d(x, window_size, mode="constant", origin=origin)
6457

6558

66-
def moving_variance(x, window_size):
67-
"""
68-
Blazing fast implementation of a moving variance.
69-
:param x: ndarray, 1D input
70-
:param window_size: int, window length
71-
:return: 1D ndarray of length len(x)-window+1
72-
"""
73-
nrows = x.size - window_size + 1
74-
n = x.strides[0]
75-
mat = np.lib.stride_tricks.as_strided(
76-
x,
77-
shape=(nrows, window_size),
78-
strides=(n, n),
79-
)
80-
return np.var(mat, axis=1)
81-
82-
8359
def smooth_boolean_mask(x: Sequence, window_size: int):
8460
# `window_size` should be at least twice as large as the
8561
# minimal number of consecutive frames to be smoothed out.
@@ -109,136 +85,6 @@ def get_video_length(video_path):
10985
return int(n_frames)
11086

11187

112-
def frame_number_to_minute_seconds(frame_number, video_path):
113-
fps = get_fps(video_path)
114-
temp = frame_number / fps
115-
minutes = int(temp // 60)
116-
seconds = int(temp % 60)
117-
ret = f"{minutes:02d}:{seconds:02d}"
118-
return ret
119-
120-
121-
def search_generated_func(text):
122-
functions = []
123-
lines = text.split("\n")
124-
func_names = []
125-
i = 0
126-
127-
while i < len(lines):
128-
line = lines[i]
129-
func_signature = "def task_program"
130-
if line.startswith(func_signature):
131-
start = line.index("def ") + 4
132-
end = line.index("(")
133-
func_name = line[start:end]
134-
func_names.append(func_name)
135-
function_lines = [line]
136-
nesting_level = 0
137-
i += 1
138-
139-
while i < len(lines):
140-
line = lines[i]
141-
# Check for nested function definitions
142-
if line.lstrip().startswith("def "):
143-
nesting_level += 1
144-
elif line.lstrip().startswith("return") and nesting_level > 0:
145-
nesting_level -= 1
146-
elif line.lstrip().startswith("return") and nesting_level == 0:
147-
function_lines.append(line)
148-
break
149-
150-
function_lines.append(line)
151-
i += 1
152-
153-
functions.append("\n".join(function_lines))
154-
i += 1
155-
156-
return functions, func_names
157-
158-
159-
def search_external_module_for_context_window(text):
160-
"""
161-
just include everything
162-
"""
163-
functions = []
164-
i = 0
165-
lines = text.split("\n")
166-
func_names = []
167-
while i < len(lines):
168-
line = lines[i].strip()
169-
func_signature = "def "
170-
if line.strip() == "":
171-
i += 1
172-
continue
173-
if line.startswith(func_signature):
174-
start = line.index("def ") + 4
175-
end = line.index("(")
176-
func_name = line[start:end]
177-
func_names.append(func_name)
178-
function_lines = [line]
179-
in_function = True
180-
while in_function:
181-
i += 1
182-
if i == len(lines):
183-
break
184-
next_line = lines[i].rstrip()
185-
if not next_line.startswith((" ", "\t")):
186-
in_function = False
187-
continue
188-
function_lines.append(next_line)
189-
functions.append("\n".join(function_lines))
190-
else:
191-
i += 1
192-
return functions, func_names
193-
194-
195-
def search_external_module_for_task_program_table(text):
196-
"""
197-
in this case, just include everything
198-
"""
199-
functions = []
200-
i = 0
201-
lines = text.split("\n")
202-
lines_copy = copy.deepcopy(lines)
203-
func_names = []
204-
example_indentation = " " * 4
205-
while i < len(lines):
206-
if lines[i].startswith("def "):
207-
start = lines[i].index("def ") + 4
208-
end = lines[i].index("(")
209-
func_name = lines[i][start:end]
210-
func_names.append(func_name)
211-
if "" not in lines[i]:
212-
i += 1
213-
continue
214-
else:
215-
lines[i] = lines[i].replace("", "").strip()
216-
line = lines[i].strip()
217-
func_signature = "def "
218-
if line.strip() == "":
219-
i += 1
220-
continue
221-
if line.startswith("def "):
222-
function_lines = [line]
223-
in_function = True
224-
while in_function:
225-
i += 1
226-
if i == len(lines):
227-
break
228-
next_line = lines[i].rstrip()
229-
if "" not in lines_copy[i]:
230-
in_function = False
231-
continue
232-
function_lines.append(
233-
next_line.replace("", "").replace(example_indentation, "", 1)
234-
)
235-
functions.append("\n".join(function_lines))
236-
else:
237-
i += 1
238-
239-
return functions, func_names
240-
241-
24288
def filter_kwargs_for_function(func, kwargs):
24389
sig = inspect.signature(func)
24490
return {k: v for k, v in kwargs.items() if k in sig.parameters}
@@ -385,20 +231,3 @@ def func2json(func):
385231
}
386232
return json_obj
387233

388-
389-
def get_func_name_from_func_string(function_string: str):
390-
import ast
391-
392-
# Parse the string into an AST
393-
parsed_ast = ast.parse(function_string)
394-
395-
# Initialize a variable to hold the function name
396-
function_name = None
397-
398-
# Traverse the AST
399-
for node in ast.walk(parsed_ast):
400-
if isinstance(node, ast.FunctionDef):
401-
function_name = node.name
402-
break
403-
404-
return function_name

0 commit comments

Comments
 (0)