Skip to content

Commit 7571b73

Browse files
committed
ready for release (black-conform)
1 parent e90443a commit 7571b73

File tree

5 files changed

+534
-318
lines changed

5 files changed

+534
-318
lines changed

experiments/custom/classifier.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
import pickle
1212

1313
from utils.configloader import PATH_TO_CLASSIFIER, TIME_WINDOW
14-
from experiments.custom.featureextraction import SimbaFeatureExtractor, SimbaFeatureExtractorStandard14bp, BsoidFeatureExtractor
14+
from experiments.custom.featureextraction import (
15+
SimbaFeatureExtractor,
16+
SimbaFeatureExtractorStandard14bp,
17+
BsoidFeatureExtractor,
18+
)
1519

1620

1721
class Classifier:
@@ -65,7 +69,7 @@ def load_classifier(path_to_sav):
6569
return classifier
6670

6771
def _check_pure(self):
68-
if 'pure' in str(self._classifier):
72+
if "pure" in str(self._classifier):
6973
return True
7074
else:
7175
return False
@@ -147,10 +151,9 @@ def example_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
147151
pass
148152

149153

150-
151154
def simba_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
152155
feature_extractor = SimbaFeatureExtractorStandard14bp(TIME_WINDOW)
153-
#feature_extractor = SimbaFeatureExtractor(TIME_WINDOW)
156+
# feature_extractor = SimbaFeatureExtractor(TIME_WINDOW)
154157
classifier = SiMBAClassifier() # initialize classifier
155158
while True:
156159
skel_time_window = None
@@ -211,7 +214,9 @@ def __init__(self, pool_size: int):
211214
"""
212215
self._running = False
213216
self._pool_size = pool_size
214-
self._process_pool = self.initiate_pool(example_feat_classifier_pool_run, pool_size)
217+
self._process_pool = self.initiate_pool(
218+
example_feat_classifier_pool_run, pool_size
219+
)
215220

216221
@staticmethod
217222
def initiate_pool(process_func, pool_size: int):
@@ -285,7 +290,11 @@ def pass_time_window(self, skel_time_window: tuple, debug: bool = False):
285290
elif process["input"].empty() and process["output"].full():
286291
process["input"].put(skel_time_window)
287292
if debug:
288-
print("Input", process["process"].name, "ID: " + str(skel_time_window[1]))
293+
print(
294+
"Input",
295+
process["process"].name,
296+
"ID: " + str(skel_time_window[1]),
297+
)
289298
break
290299

291300
def get_result(self, debug: bool = False):
@@ -316,7 +325,9 @@ def __init__(self, pool_size: int):
316325
Setting up the three queues and the process itself
317326
"""
318327
super().__init__(pool_size)
319-
self._process_pool = super().initiate_pool(simba_feat_classifier_pool_run, pool_size)
328+
self._process_pool = super().initiate_pool(
329+
simba_feat_classifier_pool_run, pool_size
330+
)
320331

321332

322333
class FeatBsoidProcessPool(FeatureExtractionClassifierProcessPool):
@@ -330,8 +341,9 @@ def __init__(self, pool_size: int):
330341
Setting up the three queues and the process itself
331342
"""
332343
super().__init__(pool_size)
333-
self._process_pool = super().initiate_pool(bsoid_feat_classifier_pool_run, pool_size)
334-
344+
self._process_pool = super().initiate_pool(
345+
bsoid_feat_classifier_pool_run, pool_size
346+
)
335347

336348

337349
"""Simple process protocols """
@@ -651,4 +663,3 @@ def __init__(self, pool_size: int):
651663
"""
652664
super().__init__(pool_size)
653665
self._process_pool = super().initiate_pool(bsoid_classifier_pool_run, pool_size)
654-

experiments/custom/experiments.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,15 @@
3030
from utils.plotter import plot_triggers_response
3131
from utils.analysis import angle_between_vectors
3232
from experiments.custom.stimulation import show_visual_stim_img, laser_switch
33-
from experiments.custom.classifier import (
34-
FeatBsoidProcessPool,
35-
FeatSimbaProcessPool
36-
)
33+
from experiments.custom.classifier import FeatBsoidProcessPool, FeatSimbaProcessPool
3734

3835

3936
from utils.configloader import THRESHOLD, POOL_SIZE, TRIGGER
4037

4138

4239
""" experimental classification experiment using Simba trained classifiers in a pool which are converted using the pure-predict package"""
40+
41+
4342
class SimbaBehaviorPoolExperiment:
4443
"""
4544
Test experiment for Simba classification
@@ -52,8 +51,8 @@ def __init__(self):
5251
"""Classifier process and initiation of behavior trigger"""
5352
self.experiment_finished = False
5453
self._process_experiment = ExampleProtocolProcess()
55-
#this process has feature extraction and classification in one process
56-
#simplifies everything if the feature extraction script is within the parallel process
54+
# this process has feature extraction and classification in one process
55+
# simplifies everything if the feature extraction script is within the parallel process
5756
self._process_pool = FeatSimbaProcessPool(POOL_SIZE)
5857
# pass classifier to trigger, so that check_skeleton is the only function that passes skeleton
5958
# initiate in experiment, so that process can be started with start_experiment
@@ -178,7 +177,7 @@ def __init__(self):
178177
# pass classifier to trigger, so that check_skeleton is the only function that passes skeleton
179178
# initiate in experiment, so that process can be started with start_experiment
180179
self._behaviortrigger = BsoidClassBehaviorPoolTrigger(
181-
target_class=TRIGGER, class_process_pool=self._process_pool, debug= False
180+
target_class=TRIGGER, class_process_pool=self._process_pool, debug=False
182181
)
183182
self._event = None
184183
# is not fully utilized in this experiment but is usefull to keep for further adaptation
@@ -226,7 +225,7 @@ def check_skeleton(self, frame, skeleton):
226225
self._process_experiment.set_trial(self._current_trial)
227226
else:
228227
pass
229-
return result,response
228+
return result, response
230229

231230
@property
232231
def _trials(self):

0 commit comments

Comments
 (0)