
Commit aef5a90

committed: 3 different conda yaml and installation scripts
1 parent d160a91 commit aef5a90


11 files changed: +140 -80 lines changed


amadeusgpt/analysis_objects/event.py

Lines changed: 1 addition & 2 deletions
@@ -540,8 +540,7 @@ def fuse_subgraph_by_kvs(
         For example, if there are two conditions to be met in the masks we look for locations that have overlap as 2
         """
         # retrieve all events that satisfy the conditions (k=v)
-        events = graph.traverse_by_kvs(merge_kvs)
-
+        events = graph.traverse_by_kvs(merge_kvs)
         if not allow_more_than_2_overlap:
             assert (
                 Event.check_max_in_sum(events) <= number_of_overlap_for_fusion
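
The docstring kept in this hunk states the fusion rule: each condition contributes a binary mask, and fused locations are those where the summed masks reach the number of conditions. Below is a minimal numpy sketch of that idea (a standalone illustration only, not the project's fuse_subgraph_by_kvs implementation; the mask values are made up):

import numpy as np

# Two hypothetical per-frame condition masks (e.g. "close to the other animal", "facing the other animal").
cond_a = np.array([0, 1, 1, 1, 0, 1])
cond_b = np.array([0, 0, 1, 1, 0, 1])

number_of_overlap_for_fusion = 2        # both conditions must hold at the same frame
overlap_count = cond_a + cond_b         # per-frame count of satisfied conditions
fused_mask = overlap_count >= number_of_overlap_for_fusion

print(fused_mask)                       # [False False  True  True False  True]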

amadeusgpt/app_utils.py

Lines changed: 4 additions & 3 deletions
@@ -112,7 +112,7 @@ def __init__(self, amadeus_answer=None, json_entry=None):
             amadeus_answer = amadeus_answer.to_dict()
         self.data.update(amadeus_answer)
 
-    def render(self):
+    def render(self, debug = False):
         """
         We use the getter for better encapsulation
         overall structure of what to be rendered
@@ -180,9 +180,10 @@ def render(self):
             st.markdown(f"Error: {qa_message['error_message']}\n ")
             # Remind users we are fixing the error by self debuging
             st.markdown(f"Let me try to fix the error by self-debugging\n ")
-            for i in range(1):
-                sandbox.llms["self_debug"].speak(sandbox)
+            if not debug:
+                sandbox.llms["self_debug"].speak(sandbox)
                 qa_message = sandbox.code_execution(qa_message)
+                self.render(debug = True)
             # do not need to execute the block one more time
             if not self.rendered:
                 self.rendered = True

amadeusgpt/managers/animal_manager.py

Lines changed: 10 additions & 0 deletions
@@ -212,6 +212,16 @@ def get_keypoints(self) -> ndarray:
         """
         Get the keypoints of animals. The shape is of shape n_frames, n_individuals, n_kpts, n_dims
         """
+        use_superanimal = False
+        if use_superanimal:
+            import deeplabcut
+            from deeplabcut.modelzoo.video_inference import video_inference_superanimal
+            superanimal_name = 'superanimal_topviewmouse_hrnetw32'
+            video_inference_superanimal(videos = [self.config['video_info']['video_file_path']],
+                                        superanimal_name = superanimal_name,
+                                        video_adapt = False,
+                                        dest_folder = 'temp_pose')
+
         ret = np.stack([animal.get_keypoints() for animal in self.animals], axis=1)
         return ret
 
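
For reference, the gated branch above reduces to the standalone call below. This is a sketch under the assumptions already present in the diff (DeepLabCut installed, the superanimal_topviewmouse_hrnetw32 model, output written to temp_pose); the video path is hypothetical.

from deeplabcut.modelzoo.video_inference import video_inference_superanimal

video_inference_superanimal(
    videos=["/path/to/topview_mouse_video.mp4"],             # hypothetical video file
    superanimal_name="superanimal_topviewmouse_hrnetw32",    # model name used in the diff
    video_adapt=False,                                       # no video adaptation, as in the diff
    dest_folder="temp_pose",                                 # where pose predictions are written
)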

amadeusgpt/managers/event_manager.py

Lines changed: 4 additions & 2 deletions
@@ -440,10 +440,12 @@ def get_composite_events(
             graphs.append(animal_animal_subgraph)
         # we then fuse events from different task programs that involve animal-object interactions
         for object_name in self.object_manager.get_object_names():
-
+            # if we strictly require the object needs to match, then EPM head dipping example won't work.
+            # so we cannot require objects to match. This causes some ambiguity.
             animal_object_subgraph = EventGraph.fuse_subgraph_by_kvs(
                 graph,
-                {"sender_animal_name": animal_name, "object_names": object_name},
+                {"sender_animal_name": animal_name,},
+                #"object_names": object_name},
                 number_of_overlap_for_fusion=2,
             )
             graphs.append(animal_object_subgraph)
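
The new comments give the reasoning: requiring object_names to match drops fused events whose underlying task programs tagged different objects (the EPM head-dipping case). Below is a tiny standalone sketch of that filtering trade-off, using plain dicts and a hypothetical matches() helper rather than the EventGraph API:

# Hypothetical event records; only the keys mirror the kvs used above.
events = [
    {"sender_animal_name": "mouse_0", "object_names": "ROI0"},
    {"sender_animal_name": "mouse_0", "object_names": "ROI1"},
]

def matches(event, kvs):
    # keep an event only if every requested key/value pair is satisfied
    return all(event.get(k) == v for k, v in kvs.items())

strict = [e for e in events if matches(e, {"sender_animal_name": "mouse_0", "object_names": "ROI0"})]
relaxed = [e for e in events if matches(e, {"sender_animal_name": "mouse_0"})]

print(len(strict), len(relaxed))  # 1 2 -> the strict filter loses the ROI1 event needed for head dips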

amadeusgpt/programs/sandbox.py

Lines changed: 78 additions & 71 deletions
@@ -389,15 +389,14 @@ def render_qa_message(self, qa_message):
         n_animals = behavior_analysis.animal_manager.get_n_individuals()
         bodypart_names = behavior_analysis.animal_manager.get_keypoint_names()
         visual_manager = behavior_analysis.visual_manager
-        plots = []
-
+        plots = []
         if isinstance(function_rets, tuple):
             # could be plotting tuple
             if isinstance(function_rets[0], plt.Figure):
                 # this is for "return fig, ax"
                 plots.append(function_rets)
 
-            else:
+            else:
                 for e in function_rets:
                     if isinstance(e, list) and len(e) > 0 and isinstance(e[0], BaseEvent):
                         # here we need to understand what we do with the events
@@ -521,8 +520,8 @@ def render_temp_message(query, sandbox):
     print("after code execution")
     print(len(qa_message["function_rets"]))
     events = qa_message["function_rets"]
-    for event in events:
-        print (event)
+
+    sandbox.render_qa_message(qa_message)
 
     if qa_message["function_rets"] is not None:
         st.markdown(qa_message["function_rets"])
@@ -543,72 +542,80 @@ def render_temp_message(query, sandbox):
 # testing qa message
 from amadeusgpt.analysis_objects.object import ROIObject
 from amadeusgpt.main import create_amadeus
-
-config = Config("amadeusgpt/configs/mabe_template.yaml")
+import pickle
+config = Config("amadeusgpt/configs/EPM_template.yaml")
 
 amadeus = create_amadeus(config)
 sandbox = amadeus.sandbox
-
-def get_chases_events(config: Config):
-    '''
-    Parameters:
-    ----------
-    config: Config
-    '''
-    # create_analysis returns an instance of AnimalBehaviorAnalysis
-    analysis = create_analysis(config)
-
-    # Get events where the closest distance between animals is less than 40 pixels
-    closest_distance_events = analysis.get_animals_animals_events(
-        cross_animal_query_list=['closest_distance<40'],
-        bodypart_names=None,
-        otheranimal_bodypart_names=None,
-        min_window=1,
-        max_window=100000
-    )
-
-    # Get events where the angle between animals is less than 30 degrees
-    angle_events = analysis.get_animals_animals_events(
-        cross_animal_query_list=['relative_angle<30'],
-        bodypart_names=None,
-        otheranimal_bodypart_names=None,
-        min_window=1,
-        max_window=100000
-    )
-
-    print ('angle_events')
-    for event in angle_events:
-        print (event)
-    return angle_events
-
-    # Get events where the animal's speed is greater than 0.2
-    speed_events = analysis.get_animals_state_events(
-        query='speed>0.2',
-        bodypart_names=None,
-        min_window=1,
-        max_window=100000
-    )
-
-    # Combine the closest distance and angle events using logical AND
-    distance_angle_events = analysis.get_composite_events(
-        events_A=closest_distance_events,
-        events_B=angle_events,
-        composition_type='logical_and',
-        max_interval_between_sequential_events=0,
-        min_window=1,
-        max_window=100000
-    )
-
-    # Combine the result with the speed events using logical AND
-    chases_events = analysis.get_composite_events(
-        events_A=distance_angle_events,
-        events_B=speed_events,
-        composition_type='logical_and',
-        max_interval_between_sequential_events=0,
-        min_window=1,
-        max_window=100000
-    )
-
-    return chases_events
-
-get_chases_events(config)
+analysis = sandbox.exec_namespace["behavior_analysis"]
+with open("temp_roi_objects.pickle", "rb") as f:
+    roi_objects = pickle.load(f)
+
+for name, roi_object in roi_objects.items():
+    analysis.object_manager.add_roi_object(ROIObject(name, roi_object["Path"]))
+
+render_temp_message("random query", sandbox)
+
+# def get_head_dips_events(config: Config):
+#     """
+#     Identify and count the number of head_dips events.
+
+#     Parameters:
+#     ----------
+#     config: Config
+
+#     Returns:
+#     -------
+#     head_dips_events: List[BaseEvent]
+#         List of events where head_dips behavior occurs.
+#     num_bouts: int
+#         Number of bouts for head_dips behavior.
+#     """
+#     # Create an instance of AnimalBehaviorAnalysis
+#     analysis = create_analysis(config)
+
+#     # Get events where mouse_center and neck are inside ROI0
+#     mouse_center_neck_in_ROI0_events = analysis.get_animals_object_events(
+#         object_name='ROI0',
+#         query='overlap == True',
+#         bodypart_names=['mouse_center', 'neck'],
+#         min_window=1,
+#         max_window=100000,
+#         negate=False
+#     )
+#     # print ("mouse center neck in ROI0")
+#     # print (len(mouse_center_neck_in_ROI0_events))
+#     # for event in mouse_center_neck_in_ROI0_events:
+#     #     print (event)
+
+#     # Get events where head_midpoint is outside ROI1
+#     head_midpoint_outside_ROI1_events = analysis.get_animals_object_events(
+#         object_name='ROI1',
+#         query='overlap == True',
+#         bodypart_names=['head_midpoint'],
+#         min_window=1,
+#         max_window=100000,
+#         negate=True
+#     )
+#     # print ('mouse head not in ROI1')
+#     # print (len(head_midpoint_outside_ROI1_events))
+#     # for event in head_midpoint_outside_ROI1_events:
+#     #     print (event)
+
+#     # Combine the events to define head_dips behavior
+#     head_dips_events = analysis.get_composite_events(
+#         events_A=mouse_center_neck_in_ROI0_events,
+#         events_B=head_midpoint_outside_ROI1_events,
+#         composition_type='logical_and',
+#         max_interval_between_sequential_events=0,
+#         min_window=1,
+#         max_window=100000
+#     )
+#     print ('head dips events', len(head_dips_events))
+
+#     # Count the number of bouts for head_dips behavior
+#     num_bouts = len(head_dips_events)
+
+#     return head_dips_events, num_bouts
+
+# get_head_dips_events(config)

conda/amadesuGPT-cpu.yml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# This environment can be used to install amadeusGPT
+name: amadeusgpt-cpu
+channels:
+  - defaults
+dependencies:
+  - python<3.10
+  - pytables==3.8.0
+  - hdf5
+  - pip
+  - jupyter

conda/amadesuGPT-gpu.yml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# This environment can be used to install amadeusGPT
+name: amadeusgpt-gpu
+channels:
+  - defaults
+dependencies:
+  - python==3.10
+  - pytables==3.8.0
+  - hdf5
+  - pip
+  - jupyter
Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 # This environment can be used to install amadeusGPT
-
-name: amadeusgpt
+name: amadeusgpt-minimal
 channels:
   - defaults
 dependencies:

install_cpu.sh

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+#!/bin/bash
+source /Users/shaokaiye/miniforge3/bin/activate
+conda env create -f conda/amadesuGPT-cpu.yml
+conda activate amadeusgpt-cpu
+conda install pytorch cpuonly -c pytorch
+
+pip install -e .[streamlit]

install_cpu_minimal.sh

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+#!/bin/bash
+source /Users/shaokaiye/miniforge3/bin/activate
+conda env create -f conda/amadesuGPT-minimal.yml
+conda activate amadeusgpt-minimal
+
+pip install -e .[streamlit]
