Skip to content

Commit 6c55285

Browse files
committed
added bsoid classifier wrapper
1 parent e212bbd commit 6c55285

File tree

4 files changed

+556
-14
lines changed

4 files changed

+556
-14
lines changed

experiments/custom/classifier.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def get_win_len(self):
4444
return self._win_len
4545

4646

47-
class SiMBAClassifier():
47+
class SiMBAClassifier:
4848
"""SiMBA base class for simple behavior classification trigger. Loads pretrained classifier, gets passed features
49-
from FeatureExtractor. Returns probability of prediction that can be incorporated into triggers."""
49+
from SimbaFeatureExtractor. Returns probability of prediction that can be incorporated into triggers."""
5050

5151
def __init__(self):
5252
self._classifier = self.load_classifier(PATH_TO_CLASSIFIER)
@@ -68,11 +68,49 @@ def classify(self,features):
6868
self.last_result = probability
6969
return probability
7070

71-
def get_last_result(self,skeleton_window: list):
71+
def get_last_result(self):
7272
"""Returns predicted last prediction"""
7373
return self.last_result
7474

7575

76+
class BsoidClassifier:
77+
"""BSOID base class for multiple behavior classification trigger. Loads pretrained classifier, gets passed features
78+
from SimbaFeatureExtractor. Returns probability of prediction that can be incorporated into triggers."""
79+
80+
def __init__(self):
81+
self._classifier = self.load_classifier(PATH_TO_CLASSIFIER)
82+
self.last_result = 0.0
83+
84+
@staticmethod
85+
def load_classifier(path_to_sav):
86+
"""Load saved classifier"""
87+
import joblib
88+
file = open(path_to_sav,'rb')
89+
[_, _, _, clf, _, predictions] = joblib.load(file)
90+
file.close()
91+
return clf
92+
93+
def classify(self, features):
94+
"""predicts motif probability from features :param feats: list, multiple feats (original feature space)
95+
:param clf: Obj, MLP classifier
96+
:return nonfs_labels: list, label/100ms
97+
Adapted from BSOID; https://github.com/YttriLab/B-SOID
98+
"""
99+
labels_fslow = []
100+
for i in range(0, len(features)):
101+
labels = self._classifier.predict(features[i].T)
102+
labels_fslow.append(labels)
103+
self.last_result = labels_fslow
104+
105+
return labels_fslow
106+
107+
def get_last_result(self):
108+
"""Returns predicted last prediction"""
109+
return self.last_result
110+
111+
112+
"""process protocols"""
113+
76114
def example_classifier_run(input_classification_q: mp.Queue,output_classification_q: mp.Queue):
77115
classifier = Classifier() # initialize classifier
78116
while True:
@@ -102,6 +140,22 @@ def simba_classifier_run(input_q: mp.Queue,output_q: mp.Queue):
102140
pass
103141

104142

143+
def bsoid_classifier_run(input_q: mp.Queue,output_q: mp.Queue):
144+
classifier = BsoidClassifier() # initialize classifier
145+
while True:
146+
features = None
147+
if input_q.full():
148+
features = input_q.get()
149+
if features is not None:
150+
start_time = time.time()
151+
last_prob = classifier.classify(features)
152+
output_q.put((last_prob))
153+
end_time = time.time()
154+
print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
155+
else:
156+
pass
157+
158+
105159
class ClassifierProcess:
106160
"""
107161
Class to help work with protocol function in multiprocessing
@@ -166,6 +220,15 @@ def __init__(self):
166220
self.output_queue))
167221

168222

223+
class BsoidClassifier_Process(ClassifierProcess):
224+
225+
def __init__(self):
226+
super().__init__()
227+
self.input_queue = mp.Queue(1)
228+
self.output_queue = mp.Queue(1)
229+
self._classification_process = mp.Process(target=bsoid_classifier_run,args=(self.input_queue,
230+
self.output_queue))
231+
169232
"""Processing pool for classification"""
170233

171234

@@ -203,6 +266,22 @@ def simba_classifier_pool_run(input_q: mp.Queue,output_q: mp.Queue):
203266
pass
204267

205268

269+
def bsoid_classifier_pool_run(input_q: mp.Queue,output_q: mp.Queue):
270+
classifier = BsoidClassifier() # initialize classifier
271+
while True:
272+
features = None
273+
feature_id = 0
274+
if input_q.full():
275+
features,feature_id = input_q.get()
276+
if features is not None:
277+
start_time = time.time()
278+
last_prob = classifier.classify(features)
279+
output_q.put((last_prob,feature_id))
280+
end_time = time.time()
281+
print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
282+
else:
283+
pass
284+
206285
class ClassifierProcessPool:
207286
"""
208287
Class to help work with protocol function in multiprocessing
@@ -309,3 +388,17 @@ def __init__(self,pool_size: int):
309388
"""
310389
super().__init__(pool_size)
311390
self._process_pool = super().initiate_pool(simba_classifier_pool_run,pool_size)
391+
392+
393+
class BsoidProcessPool(ClassifierProcessPool):
394+
"""
395+
Class to help work with protocol function in multiprocessing
396+
spawns a pool of processes that tackle the frame-by-frame issue.
397+
"""
398+
399+
def __init__(self,pool_size: int):
400+
"""
401+
Setting up the three queues and the process itself
402+
"""
403+
super().__init__(pool_size)
404+
self._process_pool = super().initiate_pool(bsoid_classifier_pool_run,pool_size)

experiments/custom/experiments.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
from functools import partial
1212
from collections import Counter
1313
from experiments.custom.stimulus_process import ClassicProtocolProcess, SimpleProtocolProcess,Timer, ExampleProtocolProcess
14-
from experiments.custom.triggers import ScreenTrigger, RegionTrigger, OutsideTrigger, DirectionTrigger, SpeedTrigger, SimbaThresholdBehaviorTriggerPool
14+
from experiments.custom.triggers import ScreenTrigger, RegionTrigger, OutsideTrigger, DirectionTrigger, SpeedTrigger,\
15+
SimbaThresholdBehaviorTriggerPool, BsoidClassBehaviorTriggerPool
1516
from utils.plotter import plot_triggers_response
1617
from utils.analysis import angle_between_vectors
1718
from experiments.custom.stimulation import show_visual_stim_img,laser_switch
18-
from experiments.custom.classifier import SimbaClassifier_Process, SimbaProcessPool
19+
from experiments.custom.classifier import SimbaClassifier_Process, SimbaProcessPool, BsoidProcessPool
1920

2021

2122
from utils.configloader import THRESHOLD, POOL_SIZE
@@ -125,6 +126,115 @@ def get_info(self):
125126
info = self._behaviortrigger.get_last_prob()
126127
return info
127128

129+
130+
131+
""" experimental classification experiment using BSOID trained classifiers in a pool"""
132+
133+
class BsoidBehaviorExperimentPool:
134+
"""
135+
Test experiment for Simba classification
136+
Simple class to contain all of the experiment properties and includes classification
137+
Uses multiprocess to ensure the best possible performance and
138+
to showcase that it is possible to work with any type of equipment, even timer-dependant
139+
"""
140+
141+
def __init__(self):
142+
"""Classifier process and initiation of behavior trigger"""
143+
self.experiment_finished = False
144+
self._process_pool = BsoidProcessPool(POOL_SIZE)
145+
#pass classifier to trigger, so that check_skeleton is the only function that passes skeleton
146+
#initiate in experiment, so that process can be started with start_experiment
147+
self._behaviortrigger = BsoidClassBehaviorTriggerPool(target_class= THRESHOLD,
148+
class_process_pool = self._process_pool)
149+
self._event = None
150+
#is not fully utilized in this experiment but is usefull to keep for further adaptation
151+
self._current_trial = None
152+
self._trial_count = {trial: 0 for trial in self._trials}
153+
self._trial_timers = {trial: Timer(10) for trial in self._trials}
154+
self._exp_timer = Timer(600)
155+
156+
def check_skeleton(self, frame, skeleton):
157+
"""
158+
Checking each passed animal skeleton for a pre-defined set of conditions
159+
Outputting the visual representation, if exist
160+
Advancing trials according to inherent logic of an experiment
161+
:param frame: frame, on which animal skeleton was found
162+
:param skeleton: skeleton, consisting of multiple joints of an animal
163+
"""
164+
self.check_exp_timer() # checking if experiment is still on
165+
for trial in self._trial_count:
166+
# checking if any trial hit a predefined cap
167+
if self._trial_count[trial] >= 10:
168+
self.stop_experiment()
169+
170+
if not self.experiment_finished:
171+
for trial in self._trials:
172+
# check for all trials if condition is met
173+
#this passes the skeleton to the trigger, where the feature extraction is done and the extracted features
174+
#are passed to the classifier process
175+
result, response = self._trials[trial]['trigger'](skeleton, target_class = self._trials[trial]['target_class'])
176+
plot_triggers_response(frame, response)
177+
#if the trigger is reporting back that the behavior is found: do something
178+
#currently nothing is done, expect counting the occurances
179+
if result:
180+
if self._current_trial is None:
181+
if not self._trial_timers[trial].check_timer():
182+
self._current_trial = trial
183+
self._trial_timers[trial].reset()
184+
self._trial_count[trial] += 1
185+
print(trial, self._trial_count[trial])
186+
else:
187+
if self._current_trial == trial:
188+
self._current_trial = None
189+
self._trial_timers[trial].start()
190+
@property
191+
def _trials(self):
192+
"""
193+
Defining the trials
194+
"""
195+
trials = {'BSOID1': dict(trigger=self._behaviortrigger.check_skeleton,
196+
target_class = None,
197+
count=0)}
198+
return trials
199+
200+
def check_exp_timer(self):
201+
"""
202+
Checking the experiment timer
203+
"""
204+
if not self._exp_timer.check_timer():
205+
print("Experiment is finished")
206+
print("Time ran out.")
207+
self.stop_experiment()
208+
209+
def start_experiment(self):
210+
"""
211+
Start the experiment
212+
"""
213+
self._process_pool.start()
214+
if not self.experiment_finished:
215+
self._exp_timer.start()
216+
217+
def stop_experiment(self):
218+
"""
219+
Stop the experiment and reset the timer
220+
"""
221+
self.experiment_finished = True
222+
self._process_pool.end()
223+
print('Experiment completed!')
224+
self._exp_timer.reset()
225+
226+
def get_trial(self):
227+
"""
228+
Check which trial is going on right now
229+
"""
230+
return self._event
231+
232+
def get_info(self):
233+
""" returns optional info"""
234+
info = self._behaviortrigger.get_last_prob()
235+
return info
236+
237+
128238
class ExampleExperiment:
129239
"""
130240
Simple class to contain all of the experiment properties

0 commit comments

Comments
 (0)