Skip to content

Commit ea44e88

Browse files
committed
new SocialInteractionExperiment
1 parent f64aa7b commit ea44e88

File tree

7 files changed

+257
-25
lines changed

7 files changed

+257
-25
lines changed

DeepLabStream.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919

2020
from utils.generic import VideoManager, WebCamManager, GenericManager
2121
from utils.configloader import RESOLUTION, FRAMERATE, OUT_DIR, MODEL_NAME, MULTI_CAM, STACK_FRAMES, \
22-
ANIMALS_NUMBER, STREAMS, STREAMING_SOURCE, MODEL_ORIGIN, CROP, CROP_X, CROP_Y
22+
ANIMALS_NUMBER, FLATTEN_MA, STREAMS, STREAMING_SOURCE, MODEL_ORIGIN, CROP, CROP_X, CROP_Y
2323
from utils.plotter import plot_bodyparts,plot_metadata_frame
24-
from utils.poser import load_deeplabcut,load_dpk,load_dlc_live,get_pose,calculate_skeletons, \
25-
find_local_peaks_new,get_ma_pose,load_sleap
24+
from utils.poser import load_deeplabcut,load_dpk,load_dlc_live,load_sleap, get_pose,calculate_skeletons, \
25+
find_local_peaks_new,get_ma_pose
2626

2727
def create_video_files(directory, devices, resolution, framerate, codec):
2828
"""
@@ -439,9 +439,7 @@ def get_analysed_frames(self) -> tuple:
439439
self._experiment_running = False
440440

441441
if self._experiment_running and not self._experiment.experiment_finished:
442-
#TODO: Update to work for multiple animal and single animal experiments
443-
#Shift responsibility to experiments
444-
if ANIMALS_NUMBER > 1:
442+
if ANIMALS_NUMBER > 1 and not FLATTEN_MA:
445443
self._experiment.check_skeleton(analysed_image,skeletons)
446444
else:
447445
for skeleton in skeletons:

experiments/custom/experiments.py

Lines changed: 145 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,41 @@
1010
import time
1111
from functools import partial
1212
from collections import Counter
13-
from experiments.custom.stimulus_process import ClassicProtocolProcess, SimpleProtocolProcess,Timer, ExampleProtocolProcess
14-
from experiments.custom.triggers import ScreenTrigger, RegionTrigger, OutsideTrigger, DirectionTrigger, SpeedTrigger
13+
from experiments.custom.stimulus_process import ClassicProtocolProcess, SimpleProtocolProcess,Timer\
14+
, ExampleProtocolProcess
15+
from experiments.custom.triggers import ScreenTrigger, RegionTrigger, OutsideTrigger, DirectionTrigger\
16+
, SpeedTrigger, SocialInteractionTrigger
1517
from utils.plotter import plot_triggers_response
1618
from utils.analysis import angle_between_vectors
1719
from experiments.custom.stimulation import show_visual_stim_img,laser_switch
1820

1921

20-
class ExampleSocialExperiment:
22+
"""Social or multiple animal experiments in combination with SLEAP or non-flattened maDLC pose estimation"""
23+
24+
class ExampleSocialInteractionExperiment:
2125
"""
26+
In this experiment the skeleton/instance of each animal will be considers for the trigger,
27+
any animal can trigger the stimulation (the first one to result in TRUE).
28+
2229
Simple class to contain all of the experiment properties
2330
Uses multiprocess to ensure the best possible performance and
2431
to showcase that it is possible to work with any type of equipment, even timer-dependent
2532
"""
2633
def __init__(self):
2734
self.experiment_finished = False
2835
self._process = ExampleProtocolProcess()
29-
self._green_point = (313, 552)
30-
self._radius = 80
36+
self._proximity_threshold = 30
37+
self._min_animals = 2
3138
self._event = None
3239
self._current_trial = None
40+
self._max_reps = 999
3341
self._trial_count = {trial: 0 for trial in self._trials}
3442
self._trial_timers = {trial: Timer(10) for trial in self._trials}
3543
self._exp_timer = Timer(600)
3644

3745
def check_skeleton(self, frame, skeletons):
3846
"""
39-
Checking each passed animal skeleton for a pre-defined set of conditions
47+
Checking passed animal skeletons for a pre-defined set of conditions
4048
Outputting the visual representation, if exist
4149
Advancing trials according to inherent logic of an experiment
4250
:param frame: frame, on which animal skeleton was found
@@ -45,40 +53,161 @@ def check_skeleton(self, frame, skeletons):
4553
self.check_exp_timer() # checking if experiment is still on
4654
for trial in self._trial_count:
4755
# checking if any trial hit a predefined cap
48-
if self._trial_count[trial] >= 10:
56+
if self._trial_count[trial] >= self._max_reps:
4957
self.stop_experiment()
5058

5159
if not self.experiment_finished:
5260
result, response = False, None
61+
#checking if enough animals were detected
62+
if len(skeletons) >= self._min_animals:
63+
for trial in self._trials:
64+
# check if social interaction trigger is true
65+
result, response = self._trials[trial]['trigger'](skeletons=skeletons)
66+
plot_triggers_response(frame, response)
67+
if result:
68+
if self._current_trial is None:
69+
if not self._trial_timers[trial].check_timer():
70+
self._current_trial = trial
71+
self._trial_timers[trial].reset()
72+
self._trial_count[trial] += 1
73+
print(trial, self._trial_count[trial])
74+
else:
75+
if self._current_trial == trial:
76+
self._current_trial = None
77+
self._trial_timers[trial].start()
78+
79+
self._process.set_trial(self._current_trial)
80+
else:
81+
pass
82+
return result, response
83+
84+
@property
85+
def _trials(self):
86+
"""
87+
Defining the trials
88+
"""
89+
identification_dict = dict(active={'animal': 1
90+
, 'bp': ['bp0']
91+
}
92+
,passive = {'animal': 0
93+
, 'bp': ['bp2']
94+
}
95+
)
96+
97+
interaction_trigger = SocialInteractionTrigger(threshold= self._proximity_threshold
98+
, identification_dict = identification_dict
99+
, interaction_type = 'proximity'
100+
, debug=True
101+
)
102+
103+
trials = {'DLStream_test': dict(trigger=interaction_trigger.check_skeleton,
104+
count=0)}
105+
return trials
106+
107+
def check_exp_timer(self):
108+
"""
109+
Checking the experiment timer
110+
"""
111+
if not self._exp_timer.check_timer():
112+
print("Experiment is finished")
113+
print("Time ran out.")
114+
self.stop_experiment()
115+
116+
def start_experiment(self):
117+
"""
118+
Start the experiment
119+
"""
120+
self._process.start()
121+
if not self.experiment_finished:
122+
self._exp_timer.start()
123+
124+
def stop_experiment(self):
125+
"""
126+
Stop the experiment and reset the timer
127+
"""
128+
self.experiment_finished = True
129+
print('Experiment completed!')
130+
self._exp_timer.reset()
131+
# don't forget to end the process!
132+
self._process.end()
133+
134+
def get_trial(self):
135+
"""
136+
Check which trial is going on right now
137+
"""
138+
return self._current_trial
139+
140+
141+
class ExampleMultipleAnimalExperiment:
142+
"""
143+
In this experiment the skeleton/instance of each animal will be considers for the trigger,
144+
any animal can trigger the stimulation (the first one to result in TRUE).
145+
146+
Simple class to contain all of the experiment properties
147+
Uses multiprocess to ensure the best possible performance and
148+
to showcase that it is possible to work with any type of equipment, even timer-dependent
149+
"""
150+
151+
def __init__(self):
152+
self.experiment_finished = False
153+
self._process = ExampleProtocolProcess()
154+
self._green_point = (550, 163)
155+
self._radius = 40
156+
self._dist_threshold = 80
157+
self._event = None
158+
self._current_trial = None
159+
self._max_reps = 10
160+
self._trial_count = {trial: 0 for trial in self._trials}
161+
self._trial_timers = {trial: Timer(10) for trial in self._trials}
162+
self._exp_timer = Timer(600)
163+
164+
def check_skeleton(self,frame,skeletons):
165+
"""
166+
Checking each passed animal skeleton for a pre-defined set of conditions
167+
Outputting the visual representation, if exist
168+
Advancing trials according to inherent logic of an experiment
169+
:param frame: frame, on which animal skeleton was found
170+
:param skeletons: skeletons, consisting of multiple joints of an animal
171+
"""
172+
self.check_exp_timer() # checking if experiment is still on
173+
for trial in self._trial_count:
174+
# checking if any trial hit a predefined cap
175+
if self._trial_count[trial] >= self._max_reps:
176+
self.stop_experiment()
177+
178+
if not self.experiment_finished:
179+
result,response = False,None
53180
for trial in self._trials:
54181
# check for all trials if condition is met
55182
result_list = []
56183
for skeleton in skeletons:
57-
result, response = self._trials[trial]['trigger'](skeleton=skeleton)
184+
# checking each skeleton for trigger success
185+
result,response = self._trials[trial]['trigger'](skeleton=skeleton)
186+
# if one of the triggers is true, break the loop and continue (the first True)
58187
if result:
59188
break
60-
plot_triggers_response(frame, response)
189+
plot_triggers_response(frame,response)
61190
if result:
62191
if self._current_trial is None:
63192
if not self._trial_timers[trial].check_timer():
64193
self._current_trial = trial
65194
self._trial_timers[trial].reset()
66195
self._trial_count[trial] += 1
67-
print(trial, self._trial_count[trial])
196+
print(trial,self._trial_count[trial])
68197
else:
69198
if self._current_trial == trial:
70199
self._current_trial = None
71200
self._trial_timers[trial].start()
72201

73202
self._process.set_trial(self._current_trial)
74-
return result, response
203+
return result,response
75204

76205
@property
77206
def _trials(self):
78207
"""
79208
Defining the trials
80209
"""
81-
green_roi = RegionTrigger('circle', self._green_point, self._radius * 2 + 7.5, 'bp1')
210+
green_roi = RegionTrigger('circle',self._green_point,self._radius * 2 + 7.5,'bp1')
82211
trials = {'Greenbar_whiteback': dict(trigger=green_roi.check_skeleton,
83212
count=0)}
84213
return trials
@@ -117,6 +246,10 @@ def get_trial(self):
117246
return self._current_trial
118247

119248

249+
"""Single animal or flattened multi animal pose estimation experiments (e.g. different fur color)
250+
or by use of the FLATTEN_MA parameter in advanced settings"""
251+
252+
120253
class ExampleExperiment:
121254
"""
122255
Simple class to contain all of the experiment properties

experiments/custom/stimulation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def show_visual_stim_img(type='background', name='vistim'):
5151
# Show image when called
5252
visual = {'background': dict(path=r"./experiments/src/whiteback_1920_1080.png"),
5353
'Greenbar_whiteback': dict(path=r"./experiments/src/greenbar_whiteback_1920_1080.png"),
54-
'Bluebar_whiteback': dict(path=r"./experiments/src/bluebar_whiteback_1920_1080.png")}
54+
'Bluebar_whiteback': dict(path=r"./experiments/src/bluebar_whiteback_1920_1080.png"),
55+
'DLStream_test': dict(path=r"./experiments/src/stuckinaloop.jpg")}
5556
# load image unchanged (-1), greyscale (0) or color (1)
5657
img = cv2.imread(visual[type]['path'], -1)
5758
converted_image = np.uint8(img)

experiments/custom/stimulus_process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def example_protocol_run(condition_q: mp.Queue):
7373
if condition_q.full():
7474
current_trial = condition_q.get()
7575
if current_trial is not None:
76-
show_visual_stim_img(type=current_trial, name='inside')
76+
show_visual_stim_img(type=current_trial, name='DlStream')
7777
#dmod_device.toggle()
7878
else:
79-
show_visual_stim_img(name='inside')
79+
show_visual_stim_img(name='DlStream')
8080
#dmod_device.turn_off()
8181

8282
if cv2.waitKey(1) & 0xFF == ord('q'):

experiments/custom/triggers.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,106 @@
1414

1515
import time
1616

17-
"""Single posture triggers"""
17+
"""Multiple Animal, Social Interaction Triggers
18+
can only be used with multiple animals/skeletons (e.g. with SLEAP)
19+
Note that currently"""
20+
21+
22+
class SocialInteractionTrigger:
23+
"""Trigger to check if one animal's body part is close to (Proximity trigger) far away from (distance trigger) another animal's body part
24+
This trigger can be easily adapted to incorporate more complicated social interaction matching.
25+
Note that this Trigger can be used without realtime identity tracking (e.g. with maDLC or DeepPoseKit) but is less effective
26+
because no distinction between active/passive interaction partners can be made."""
27+
28+
def __init__(self, threshold: float
29+
, identification_dict: dict
30+
, interaction_type: str = 'proximity', debug: bool = False):
31+
"""
32+
Initialising trigger with following parameters:
33+
:param float threshold: minimum distance between selected body parts for trigger to activate
34+
:param dict identification_dict: nested dictionary that specifies which animal/instance is taken as "active"/acting interaction partner and which is "passive"/recieving.
35+
Identification of animal ('name') has to match DLStream naming parameters for calculate skeleton (default is 'Animal1', 'Animal2', etc.).
36+
Idenification of relevant bodyparts ('bp') has to match Dlstream autonaming (default is 'bp1', 'bp2', etc.) or model specific bodypart naming,
37+
whichever applies. This can be a single body part (e.g. ['nose']) or a list of body parts (e.g. ['nose', 'paw_left', 'paw_right']
38+
39+
Example: identification_dict = dict(
40+
active = {'animal': 0,
41+
'bp': ['nose']]
42+
}
43+
,passive = {'animal': 1,
44+
'bp': ['center','tail_root']
45+
}
46+
)
47+
48+
:param str interaction_type: Type of interaction ('proximity' or 'distance'). Proximity is checking for distances lower than threshold,
49+
while distance is checking for distances higher than threshold. Default: 'proximity'
50+
:param debug: Not used in this trigger
51+
"""
52+
53+
self._threshold = threshold
54+
self._identification_dict = identification_dict
55+
#for easier use
56+
self._active_animal = self._identification_dict['active']['animal']
57+
self._active_bp = self._identification_dict['active']['bp']
58+
self._passive_animal = self._identification_dict['passive']['animal']
59+
self._passive_bp = self._identification_dict['passive']['bp']
60+
61+
self._interaction_type = interaction_type
62+
self._debug = debug
63+
64+
def check_skeleton(self, skeletons: dict):
65+
"""
66+
Checking skeletons for trigger
67+
:param skeletons: a skeleton dictionary, returned by calculate_skeletons() from poser file
68+
:return: response, a tuple of result (bool) and response body
69+
Response body is used for plotting and outputting results to trials dataframes
70+
"""
71+
72+
results = []
73+
for active_bp in self._identification_dict['active']['bp']:
74+
active_coords = skeletons[self._active_animal][active_bp]
75+
for passive_bp in self._identification_dict['passive']['bp']:
76+
passive_coords = skeletons[self._passive_animal][passive_bp]
77+
#calculate distance for all combinations
78+
distance = calculate_distance(active_coords, passive_coords)
79+
temp_result = False
80+
if distance >= self._threshold and self._interaction_type == 'distance':
81+
temp_result = True
82+
elif distance < self._threshold and self._interaction_type == 'proximity':
83+
temp_result = True
84+
else:
85+
pass
86+
results.append(temp_result)
87+
88+
result = any(results)
89+
90+
color = (0, 255, 0) if result else (0, 0, 255)
91+
92+
if self._debug:
93+
active_point_x, active_point_y = skeletons[self._active_animal][self._active_bp[0]]
94+
passive_point_x, passive_point_y= skeletons[self._passive_animal][self._passive_bp[0]]
95+
96+
response_body = {'plot': {'text': dict(text=f'{self._interaction_type}: {result}',
97+
org= (20 , 20),
98+
color= color),
99+
'line': dict(pt1=(int(active_point_x), int(active_point_y)),
100+
pt2=(int(passive_point_x), int(passive_point_y)),
101+
color=color),
102+
}
103+
}
104+
else:
105+
response_body = {'plot': {'text': dict(text=f'{self._interaction_type}: {result}',
106+
org=(20 , 20),
107+
color= color)
108+
}
109+
}
110+
response = (result, response_body)
111+
return response
112+
113+
114+
"""Single posture triggers
115+
can be used in single animal or multiple animal experiments
116+
"""
18117

19118
class HeaddirectionROITrigger:
20119
"""Trigger to check if animal is turning its head in a specific angle to a reference point (center of the region of interest)
@@ -390,6 +489,7 @@ def check_skeleton(self, skeleton: dict):
390489

391490
return response
392491

492+
393493
class SpeedTrigger:
394494
"""
395495
Trigger to check if animal is moving above a certain speed

0 commit comments

Comments
 (0)