Skip to content

Commit dc34c18

Browse files
committed
updated experiments.py, triggers.py, configloader.py to match new release
1 parent 79fbf35 commit dc34c18

File tree

3 files changed

+37
-19
lines changed

3 files changed

+37
-19
lines changed

experiments/custom/experiments.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
)
3838

3939

40-
from utils.configloader import THRESHOLD, POOL_SIZE
40+
from utils.configloader import THRESHOLD, POOL_SIZE, TRIGGER
4141

4242

4343
""" experimental classification experiment using Simba trained classifiers in a pool which are converted using the pure-predict package"""
@@ -62,7 +62,7 @@ def __init__(self):
6262
prob_threshold=THRESHOLD, class_process_pool=self._process_pool, debug=False
6363
)
6464
self._event = None
65-
# is not fully utilized in this experiment but is usefull to keep for further adaptation
65+
# is not fully utilized in this experiment but is useful to keep for further adaptation
6666
self._current_trial = None
6767
self._max_reps = 999
6868
self._trial_count = {trial: 0 for trial in self._trials}
@@ -173,7 +173,7 @@ class SimbaBehaviorPoolExperiment:
173173

174174
def __init__(self):
175175
"""Classifier process and initiation of behavior trigger"""
176-
self.experiment_finished = False
176+
self._process_experiment = ExampleProtocolProcess()
177177
self._process_pool = SimbaProcessPool(POOL_SIZE)
178178
# pass classifier to trigger, so that check_skeleton is the only function that passes skeleton
179179
# initiate in experiment, so that process can be started with start_experiment
@@ -223,14 +223,18 @@ def check_skeleton(self, frame, skeleton):
223223
if self._current_trial == trial:
224224
self._current_trial = None
225225
self._trial_timers[trial].start()
226+
self._process_experiment.set_trial(self._current_trial)
227+
else:
228+
pass
229+
return result,response
226230

227231
@property
228232
def _trials(self):
229233
"""
230234
Defining the trials
231235
"""
232236
trials = {
233-
"SimBA1": dict(
237+
"DLStream_test": dict(
234238
trigger=self._behaviortrigger.check_skeleton, target_prob=None, count=0
235239
)
236240
}
@@ -249,6 +253,7 @@ def start_experiment(self):
249253
"""
250254
Start the experiment
251255
"""
256+
self._process_experiment.start()
252257
self._process_pool.start()
253258
if not self.experiment_finished:
254259
self._exp_timer.start()
@@ -258,6 +263,7 @@ def stop_experiment(self):
258263
Stop the experiment and reset the timer
259264
"""
260265
self.experiment_finished = True
266+
self._process_experiment.end()
261267
self._process_pool.end()
262268
print("Experiment completed!")
263269
self._exp_timer.reset()
@@ -288,11 +294,12 @@ class BsoidBehaviorPoolExperiment:
288294
def __init__(self):
289295
"""Classifier process and initiation of behavior trigger"""
290296
self.experiment_finished = False
297+
self._process_experiment = ExampleProtocolProcess()
291298
self._process_pool = BsoidProcessPool(POOL_SIZE)
292299
# pass classifier to trigger, so that check_skeleton is the only function that passes skeleton
293300
# initiate in experiment, so that process can be started with start_experiment
294301
self._behaviortrigger = BsoidClassBehaviorPoolTrigger(
295-
target_class=THRESHOLD, class_process_pool=self._process_pool
302+
target_class=TRIGGER, class_process_pool=self._process_pool, debug= False
296303
)
297304
self._event = None
298305
# is not fully utilized in this experiment but is usefull to keep for further adaptation
@@ -337,14 +344,19 @@ def check_skeleton(self, frame, skeleton):
337344
if self._current_trial == trial:
338345
self._current_trial = None
339346
self._trial_timers[trial].start()
347+
self._process_experiment.set_trial(self._current_trial)
348+
else:
349+
pass
350+
return result,response
340351

341352
@property
342353
def _trials(self):
343354
"""
344355
Defining the trials
356+
Target class is the cluster of interest and can be changed with every call of check_skeleton
345357
"""
346358
trials = {
347-
"BSOID1": dict(
359+
"DLStream_test": dict(
348360
trigger=self._behaviortrigger.check_skeleton, target_class=None, count=0
349361
)
350362
}
@@ -364,6 +376,7 @@ def start_experiment(self):
364376
Start the experiment
365377
"""
366378
self._process_pool.start()
379+
self._process_experiment.start()
367380
if not self.experiment_finished:
368381
self._exp_timer.start()
369382

@@ -372,6 +385,7 @@ def stop_experiment(self):
372385
Stop the experiment and reset the timer
373386
"""
374387
self.experiment_finished = True
388+
self._process_experiment.end()
375389
self._process_pool.end()
376390
print("Experiment completed!")
377391
self._exp_timer.reset()
@@ -388,6 +402,7 @@ def get_info(self):
388402
return info
389403

390404

405+
391406
"""Social or multiple animal experiments in combination with SLEAP or non-flattened maDLC pose estimation"""
392407

393408

experiments/custom/triggers.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
EllipseROI,
1414
RectangleROI,
1515
)
16-
from utils.configloader import RESOLUTION, TIME_WINDOW
16+
from utils.configloader import RESOLUTION, TIME_WINDOW, FRAMERATE
1717
from collections import deque
1818
from experiments.custom.featureextraction import (
1919
SimbaFeatureExtractor,
@@ -804,12 +804,12 @@ def __init__(self, target_class: int, path_to_sav: str, debug: bool = False):
804804
self._trigger = target_class
805805
self._last_result = [0]
806806
self._center = None
807-
self._debug = debug # not used in this trigger
807+
self._debug = debug
808808
self._skeleton = None
809809
self._classifier, self._time_window_len = self._init_classifier(
810810
path_to_sav
811811
) # initialize classifier
812-
self.feat_extractor = BsoidFeatureExtractor(self._time_window_len, fps=30)
812+
self.feat_extractor = BsoidFeatureExtractor(self._time_window_len, fps=FRAMERATE)
813813
self._time_window = deque(maxlen=self._time_window_len)
814814

815815
@staticmethod
@@ -818,7 +818,7 @@ def _init_classifier(path_to_sav):
818818

819819
"""Put your classifier of choice in here"""
820820
classifier = BsoidClassifier(path_to_clf=path_to_sav)
821-
win_len = classifier.ge()
821+
win_len = classifier.get_win_len()
822822
return classifier, win_len
823823

824824
def fill_time_window(self, skeleton):
@@ -847,11 +847,12 @@ def check_skeleton(self, skeleton, trigger: float = None):
847847
start_time = time.time()
848848
f_extract_output = self.feat_extractor.extract_features(self._time_window)
849849
end_time = time.time()
850-
print(
851-
"Feature extraction time: {:.2f} msec".format(
852-
(end_time - start_time) * 1000
850+
if self._debug:
851+
print(
852+
"Feature extraction time: {:.2f} msec".format(
853+
(end_time - start_time) * 1000
854+
)
853855
)
854-
)
855856
if f_extract_output is not None:
856857
self._last_result, _, _ = self._classifier.classify(f_extract_output)
857858
else:
@@ -906,7 +907,7 @@ def __init__(self, target_class: int, class_process_pool, debug: bool = False):
906907
self._last_result = [0]
907908
self._feature_id = 0
908909
self._center = None
909-
self._debug = debug # not used in this trigger
910+
self._debug = debug
910911
self._skeleton = None
911912
self._time_window_len = TIME_WINDOW
912913
self.feat_extractor = BsoidFeatureExtractor(self._time_window_len)
@@ -938,11 +939,12 @@ def check_skeleton(self, skeleton, target_class: int = None):
938939
start_time = time.time()
939940
f_extract_output = self.feat_extractor.extract_features(self._time_window)
940941
end_time = time.time()
941-
print(
942-
"Feature extraction time: {:.2f} msec".format(
943-
(end_time - start_time) * 1000
942+
if self._debug:
943+
print(
944+
"Feature extraction time: {:.2f} msec".format(
945+
(end_time - start_time) * 1000
946+
)
944947
)
945-
)
946948
if f_extract_output is not None:
947949
self._feature_id += 1
948950
self._process_pool.pass_features(

utils/configloader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def get_script_path():
7878
PATH_TO_CLASSIFIER = dsc_config["Classification"].get("PATH_TO_CLASSIFIER")
7979
PIXPERMM = dsc_config["Classification"].getfloat("PIXPERMM")
8080
THRESHOLD = dsc_config["Classification"].getfloat("THRESHOLD")
81+
TRIGGER = dsc_config["Classification"].getfloat("TRIGGER")
8182
POOL_SIZE = dsc_config["Classification"].getint("POOL_SIZE")
8283
TIME_WINDOW = dsc_config["Classification"].getint("TIME_WINDOW")
8384

0 commit comments

Comments
 (0)