Skip to content

Commit 0dd509a

Browse files
authored
Merge pull request #29 from AdaptiveMotorControlLab/shaokai/fix_issues
Shaokai/fix issues
2 parents 17de779 + ac9bbbf commit 0dd509a

File tree

4 files changed

+59
-25
lines changed

4 files changed

+59
-25
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export streamlit_app=True
22
app:
33

4-
streamlit run amadeusgpt/app.py --server.fileWatcherType none
4+
streamlit run amadeusgpt/app.py --server.fileWatcherType none --server.maxUploadSize 1000

amadeusgpt/app_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,6 @@ def update_df_data(new_item, index_to_update, df, csv_file):
581581
return csv_file, df
582582

583583

584-
@st.cache_data(persist="disk")
585584
def get_scene_image(example):
586585
if AnimalBehaviorAnalysis.get_video_file_path() is not None:
587586
scene_image = Scene.get_scene_frame()
@@ -599,7 +598,6 @@ def get_scene_image(example):
599598
return get_scene_image(example)
600599

601600

602-
@st.cache_data(persist="disk")
603601
def get_sam_image(example):
604602
if AnimalBehaviorAnalysis.get_video_file_path():
605603
seg_objects = AnimalBehaviorAnalysis.get_seg_objects()

amadeusgpt/implementation.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
scene_frame_number = 0
6161

6262

63+
6364
class Database:
6465
"""
6566
A singleton that stores all data. Should be easy to integrate with a Nonsql database
@@ -1218,9 +1219,19 @@ def get_objects(self, video_file_path):
12181219
else:
12191220
return self.pickledata
12201221

1222+
1223+
12211224

12221225
class AnimalBehaviorAnalysis:
1226+
"""
1227+
This class holds methods and objects that are useful for analyzing animal behavior.
1228+
It no longer holds the states of objects directly. Instead, it references to the Database
1229+
singleton object. This is to make the class more stateless and easier to use in a web app.
1230+
"""
1231+
1232+
# to be deprecated
12231233
task_programs = {}
1234+
# to be deprecated
12241235
task_program_results = {}
12251236
# if a function has a parameter, it assumes the result_buffer has it
12261237
# special dataset flags set to be False
@@ -1251,7 +1262,7 @@ def result_buffer(cls):
12511262
@classmethod
12521263
def release_cache_objects(cls):
12531264
"""
1254-
For web app, switching from one example to the other requires a release of cached objects
1265+
For web app, switching from one example to the another requires a release of cached objects
12551266
"""
12561267
if Database.exist(cls.__name__, "animal_objects"):
12571268
Database.delete(cls.__name__, "animal_objects")
@@ -1914,21 +1925,32 @@ def reject_outlier_keypoints(cls, keypoints, threshold_in_stds=2):
19141925
return temp
19151926

19161927
@classmethod
1917-
def ast_fillna_2d(cls, arr):
1928+
def ast_fillna_2d(cls, arr: np.ndarray) -> np.ndarray:
1929+
"""
1930+
Fills NaN values in a 4D keypoints array using linear interpolation.
1931+
1932+
Parameters:
1933+
arr (np.ndarray): A 4D numpy array of shape (n_frames, n_individuals, n_kpts, n_dims).
1934+
1935+
Returns:
1936+
np.ndarray: The 4D array with NaN values filled.
1937+
"""
19181938
n_frames, n_individuals, n_kpts, n_dims = arr.shape
19191939
arr_reshaped = arr.reshape(n_frames, -1)
19201940
x = np.arange(n_frames)
19211941
for i in range(arr_reshaped.shape[1]):
19221942
valid_mask = ~np.isnan(arr_reshaped[:, i])
19231943
if np.all(valid_mask):
19241944
continue
1925-
arr_reshaped[:, i] = np.interp(
1926-
x, x[valid_mask], arr_reshaped[valid_mask, i]
1927-
)
1928-
# Reshape the array back to 4D
1929-
arr = arr_reshaped.reshape(n_frames, n_individuals, n_kpts, n_dims)
1945+
elif np.any(valid_mask):
1946+
# Perform interpolation when there are some valid points
1947+
arr_reshaped[:, i] = np.interp(x, x[valid_mask], arr_reshaped[valid_mask, i])
1948+
else:
1949+
# Handle the case where all values are NaN
1950+
# Replace with a default value or another suitable handling
1951+
arr_reshaped[:, i].fill(0) # Example: filling with 0
19301952

1931-
return arr
1953+
return arr_reshaped.reshape(n_frames, n_individuals, n_kpts, n_dims)
19321954

19331955
@classmethod
19341956
@timer_decorator

amadeusgpt/main.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,11 @@ class AMADEUS:
100100
code_generator_brain.enforce_prompt = ""
101101
usage = 0
102102
behavior_modules_in_context = True
103-
# to save the behavior module strings for context window use
103+
# load the integration modules to context
104104
smart_loading = True
105+
# number of topk integration modules to load
105106
load_module_top_k = 3
107+
module_threshold = 0.7
106108
context_window_dict = {}
107109
plot = False
108110
use_rephraser = True
@@ -124,6 +126,7 @@ def release_cache_objects(cls):
124126

125127
@classmethod
126128
def load_module_smartly(cls, user_input):
129+
# TODO: need to improve the module matching by vector database
127130
sorted_query_results = match_module(user_input)
128131
if len(sorted_query_results) == 0:
129132
return None
@@ -134,7 +137,7 @@ def load_module_smartly(cls, user_input):
134137
query_module = query_result[0]
135138
query_score = query_result[1][0][0]
136139

137-
if query_score > 0.7:
140+
if query_score > cls.module_threshold:
138141
modules.append(query_module)
139142
# parse the query result by loading active loading
140143
module_path = os.sep.join(query_module.split(os.sep)[-2:]).replace(
@@ -158,7 +161,7 @@ def magic_command(cls, user_input):
158161
AmadeusLogger.info(result.stdout.decode("utf-8"))
159162

160163
@classmethod
161-
def save_state(cls):
164+
def save_state(cls, output_path = 'soul.pickle'):
162165
# save the class attributes of all classes that are under state_list.
163166
def get_class_variables(_class):
164167
return {
@@ -171,15 +174,14 @@ def get_class_variables(_class):
171174

172175
state = {k.__name__: get_class_variables(k) for k in cls.state_list}
173176

174-
output_filename = "soul.pickle"
175-
with open(output_filename, "wb") as f:
177+
with open(output_path, "wb") as f:
176178
pickle.dump(state, f)
177-
AmadeusLogger.info(f"memory saved to {output_filename}")
179+
AmadeusLogger.info(f"memory saved to {output_path}")
178180

179181
@classmethod
180-
def load_state(cls):
182+
def load_state(cls, ckpt_path = 'soul.pickle'):
181183
# load the class variables into 3 class
182-
memory_filename = "soul.pickle"
184+
memory_filename = ckpt_path
183185
AmadeusLogger.info(f"loading memory from {memory_filename}")
184186
with open(memory_filename, "rb") as f:
185187
state = pickle.load(f)
@@ -296,6 +298,7 @@ def chat(
296298
cls.interface_str, cls.behavior_modules_str
297299
)
298300
cls.code_generator_brain.update_history("user", rephrased_user_msg)
301+
299302
response = cls.code_generator_brain.connect_gpt(
300303
cls.code_generator_brain.context_window, max_tokens=700, functions=functions
301304
)
@@ -307,10 +310,12 @@ def chat(
307310
thought_process,
308311
) = cls.code_generator_brain.parse_openai_response(response)
309312

313+
# write down the task program for offline processing
310314
with open("temp_for_debug.json", "w") as f:
311315
out = {'function_code': function_code,
312316
'query': rephrased_user_msg}
313317
json.dump(out, f, indent=4)
318+
314319
# handle_function_codes gives the answer with function outputs
315320
amadeus_answer = cls.core_loop(
316321
rephrased_user_msg, text, function_code, thought_process
@@ -321,16 +326,19 @@ def chat(
321326
original_user_msg, amadeus_answer.function_code, code_output
322327
)
323328

324-
# if there is an error or the function code is empty, we want to make sure we prevent ChatGPT to learn to output nothing from few-shot learning
325-
# is this used anymore?
329+
330+
# Could be used for in context feedback learning. Costly
326331
if amadeus_answer.has_error:
327332
cls.code_generator_brain.context_window[-1][
328333
"content"
329334
] += "\n While executing the code above, there was error so it is not correct answer\n"
330335

331-
elif amadeus_answer.has_error:
332-
cls.code_generator_brain.context_window.pop()
333-
cls.code_generator_brain.history.pop()
336+
337+
# if there is an error or the function code is empty, we want to make sure we prevent ChatGPT to learn to output nothing from few-shot learning
338+
#elif amadeus_answer.has_error:
339+
# cls.code_generator_brain.context_window.pop()
340+
# cls.code_generator_brain.history.pop()
341+
334342
else:
335343
# needs to manage memory of Amadeus for context window management and state restore etc.
336344
# we have it remember user's original question instead of the rephrased one for better
@@ -351,10 +359,13 @@ def execute_python_function(
351359
exec(function_code, globals())
352360
if "task_program" not in globals():
353361
return None
362+
363+
# TODO: to serialize and support different function arguments
354364
func_sigs = inspect.signature(task_program)
355365
if not func_sigs.parameters:
356366
result = task_program()
357367
else:
368+
# TODO: We don't do this anymore. But in the future, Is passing result buffer from each function sustainable?
358369
result_buffer = AnimalBehaviorAnalysis.result_buffer
359370
AmadeusLogger.info(f"result_buffer: {result_buffer}")
360371
if isinstance(result_buffer, tuple):
@@ -368,9 +379,11 @@ def execute_python_function(
368379
@classmethod
369380
def contribute(cls, program_name):
370381
"""
382+
Deprecated
371383
Takes the program from the task program registry and write it into contribution folder
372384
TODO: split the task program into implementation and api
373-
"""
385+
"""
386+
374387
AmadeusLogger.info(f"contributing {program_name}")
375388
task_program = AnimalBehaviorAnalysis.task_programs[program_name]
376389
# removing add_symbol or add_task_program line
@@ -390,6 +403,7 @@ def update_behavior_modules_str(cls):
390403
Called during loading behavior modules from disk or when task program is updated
391404
"""
392405
modules_str = []
406+
# context_window_dict is where integration modules are stored in current AMADEUS class
393407
for name, task_program in cls.context_window_dict.items():
394408
modules_str.append(task_program)
395409
modules_str = modules_str[-cls.load_module_top_k :]

0 commit comments

Comments
 (0)