Skip to content

Commit db6e777

Browse files
committed
updated classification with pure-predict
1 parent fae313d commit db6e777

File tree

8 files changed

+227
-13
lines changed

8 files changed

+227
-13
lines changed

DeepLabStream.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,10 @@ def get_pose_mp(input_q, output_q):
289289
start_time = time.time()
290290
if MODEL_ORIGIN == 'DLC':
291291
scmap, locref, pose = get_pose(frame, config, sess, inputs, outputs)
292-
peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config)
292+
#peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config)
293293
#Use the line below to use raw DLC output rather then DLStream optimization
294-
# peaks = pose
294+
#TODO: return to original
295+
peaks = pose
295296
if MODEL_ORIGIN == 'MADLC':
296297
peaks = get_ma_pose(frame, config, sess, inputs, outputs)
297298
analysis_time = time.time() - start_time

convert_classifier.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pickle
2+
import os
3+
from pure_sklearn.map import convert_estimator
4+
5+
6+
def load_classifier(path_to_sav):
7+
"""Load saved classifier"""
8+
file = open(path_to_sav,'rb')
9+
classifier = pickle.load(file)
10+
file.close()
11+
return classifier
12+
13+
def convert_classifier(path):
14+
# convert to pure python estimator
15+
print('Loading classifier...')
16+
clf = load_classifier(path)
17+
dir_path = os.path.dirname(path)
18+
filename = os.path.basename(path)
19+
filename, _ = filename.split('.')
20+
clf_pure_predict = convert_estimator(clf)
21+
with open(dir_path + '/'+ filename + "_pure.sav", "wb") as f:
22+
pickle.dump(clf_pure_predict, f)
23+
print(f'Converted Classifier {filename}')
24+
25+
if __name__ == '__main__':
26+
path_to_classifier = r"D:\SimBa\Jens_models\pursuit_prediction_11.sav"
27+
convert_classifier(path_to_classifier)

experiments/custom/classifier.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import multiprocessing as mp
1010
import time
11+
import pickle
1112

1213
from utils.configloader import PATH_TO_CLASSIFIER
1314

@@ -24,7 +25,7 @@ def __init__(self,win_len: int = 1):
2425
@staticmethod
2526
def load_classifier(path_to_sav):
2627
"""Load saved classifier"""
27-
import pickle
28+
#import pickle
2829
file = open(path_to_sav,'rb')
2930
classifier = pickle.load(file)
3031
file.close()
@@ -44,6 +45,37 @@ def get_win_len(self):
4445
return self._win_len
4546

4647

48+
49+
class PureSiMBAClassifier:
50+
"""SiMBA base class for simple behavior classification trigger. Loads pretrained classifier, gets passed features
51+
from SimbaFeatureExtractor. Returns probability of prediction that can be incorporated into triggers."""
52+
53+
def __init__(self):
54+
self._classifier = self.load_classifier(PATH_TO_CLASSIFIER)
55+
self.last_result = 0.0
56+
57+
@staticmethod
58+
def load_classifier(path_to_sav):
59+
"""Load saved classifier"""
60+
# load pickled pure-predict model
61+
with open(path_to_sav,"rb") as f:
62+
classifier = pickle.load(f)
63+
return classifier
64+
65+
def classify(self, features):
66+
"""predicts motif probability from features"""
67+
#pure-predict needs a list instead of a numpy array
68+
prediction = self._classifier.predict_proba(list(features))
69+
#pure-predict returns a nested list
70+
probability = prediction[0][1]
71+
self.last_result = probability
72+
return probability
73+
74+
def get_last_result(self):
75+
"""Returns predicted last prediction"""
76+
return self.last_result
77+
78+
4779
class SiMBAClassifier:
4880
"""SiMBA base class for simple behavior classification trigger. Loads pretrained classifier, gets passed features
4981
from SimbaFeatureExtractor. Returns probability of prediction that can be incorporated into triggers."""
@@ -55,7 +87,7 @@ def __init__(self):
5587
@staticmethod
5688
def load_classifier(path_to_sav):
5789
"""Load saved classifier"""
58-
import pickle
90+
#import pickle
5991
file = open(path_to_sav,'rb')
6092
classifier = pickle.load(file)
6193
file.close()
@@ -261,7 +293,24 @@ def simba_classifier_pool_run(input_q: mp.Queue,output_q: mp.Queue):
261293
last_prob = classifier.classify(features)
262294
output_q.put((last_prob,feature_id))
263295
end_time = time.time()
264-
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
296+
#print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
297+
else:
298+
pass
299+
300+
301+
def pure_simba_classifier_pool_run(input_q: mp.Queue,output_q: mp.Queue):
302+
classifier = PureSiMBAClassifier() # initialize classifier
303+
while True:
304+
features = None
305+
feature_id = 0
306+
if input_q.full():
307+
features,feature_id = input_q.get()
308+
if features is not None:
309+
start_time = time.time()
310+
last_prob = classifier.classify(features)
311+
output_q.put((last_prob,feature_id))
312+
end_time = time.time()
313+
print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
265314
else:
266315
pass
267316

@@ -278,8 +327,8 @@ def bsoid_classifier_pool_run(input_q: mp.Queue,output_q: mp.Queue):
278327
last_prob = classifier.classify(features)
279328
output_q.put((last_prob,feature_id))
280329
end_time = time.time()
281-
print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
282-
print("Feature ID: "+ feature_id)
330+
#print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
331+
#print("Feature ID: "+ feature_id)
283332

284333
else:
285334
pass
@@ -378,6 +427,21 @@ def get_result(self,debug: bool = False):
378427
return result
379428

380429

430+
class PureSimbaProcessPool(ClassifierProcessPool):
431+
"""
432+
Class to help work with protocol function in multiprocessing
433+
spawns a pool of processes that tackle the frame-by-frame issue.
434+
"""
435+
436+
def __init__(self,pool_size: int):
437+
"""
438+
Setting up the three queues and the process itself
439+
"""
440+
super().__init__(pool_size)
441+
self._process_pool = super().initiate_pool(pure_simba_classifier_pool_run, pool_size)
442+
443+
444+
381445
class SimbaProcessPool(ClassifierProcessPool):
382446
"""
383447
Class to help work with protocol function in multiprocessing

experiments/custom/experiments.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,128 @@
1717
from utils.plotter import plot_triggers_response
1818
from utils.analysis import angle_between_vectors
1919
from experiments.custom.stimulation import show_visual_stim_img,laser_switch
20-
from experiments.custom.classifier import SimbaProcessPool, BsoidProcessPool
20+
from experiments.custom.classifier import SimbaProcessPool, BsoidProcessPool, PureSimbaProcessPool
2121

2222

2323
from utils.configloader import THRESHOLD, POOL_SIZE
2424

2525

26+
""" experimental classification experiment using Simba trained classifiers in a pool which are converted using the pure-predict package"""
27+
28+
class PureSimbaBehaviorPoolExperiment:
29+
"""
30+
Test experiment for Simba classification with pure-predict conversion
31+
Simple class to contain all of the experiment properties and includes classification
32+
Uses multiprocess to ensure the best possible performance and
33+
to showcase that it is possible to work with any type of equipment, even timer-dependant
34+
"""
35+
36+
def __init__(self):
37+
"""Classifier process and initiation of behavior trigger"""
38+
self.experiment_finished = False
39+
self._process_experiment = ExampleProtocolProcess()
40+
self._process_pool = PureSimbaProcessPool(POOL_SIZE)
41+
#pass classifier to trigger, so that check_skeleton is the only function that passes skeleton
42+
#initiate in experiment, so that process can be started with start_experiment
43+
self._behaviortrigger = SimbaThresholdBehaviorPoolTrigger(prob_threshold= THRESHOLD,
44+
class_process_pool = self._process_pool,
45+
debug=False)
46+
self._event = None
47+
#is not fully utilized in this experiment but is usefull to keep for further adaptation
48+
self._current_trial = None
49+
self._max_reps = 999
50+
self._trial_count = {trial: 0 for trial in self._trials}
51+
self._trial_timers = {trial: Timer(0) for trial in self._trials}
52+
self._exp_timer = Timer(9999)
53+
54+
def check_skeleton(self, frame, skeleton):
55+
"""
56+
Checking each passed animal skeleton for a pre-defined set of conditions
57+
Outputting the visual representation, if exist
58+
Advancing trials according to inherent logic of an experiment
59+
:param frame: frame, on which animal skeleton was found
60+
:param skeleton: skeleton, consisting of multiple joints of an animal
61+
"""
62+
self.check_exp_timer() # checking if experiment is still on
63+
for trial in self._trial_count:
64+
# checking if any trial hit a predefined cap
65+
if self._trial_count[trial] >= self._max_reps:
66+
self.stop_experiment()
67+
68+
if not self.experiment_finished:
69+
for trial in self._trials:
70+
# check for all trials if condition is met
71+
#this passes the skeleton to the trigger, where the feature extraction is done and the extracted features
72+
#are passed to the classifier process
73+
result, response = self._trials[trial]['trigger'](skeleton, target_prob = self._trials[trial]['target_prob'])
74+
plot_triggers_response(frame, response)
75+
#if the trigger is reporting back that the behavior is found: do something
76+
#currently nothing is done, expect counting the occurances
77+
if result:
78+
if self._current_trial is None:
79+
if not self._trial_timers[trial].check_timer():
80+
self._current_trial = trial
81+
self._trial_timers[trial].reset()
82+
self._trial_count[trial] += 1
83+
print(trial, self._trial_count[trial])
84+
else:
85+
if self._current_trial == trial:
86+
self._current_trial = None
87+
self._trial_timers[trial].start()
88+
89+
self._process_experiment.set_trial(self._current_trial)
90+
91+
@property
92+
def _trials(self):
93+
"""
94+
Defining the trials
95+
"""
96+
trials = {'DLStream_test': dict(trigger=self._behaviortrigger.check_skeleton,
97+
target_prob = None,
98+
count=0)}
99+
return trials
100+
101+
def check_exp_timer(self):
102+
"""
103+
Checking the experiment timer
104+
"""
105+
if not self._exp_timer.check_timer():
106+
print("Experiment is finished")
107+
print("Time ran out.")
108+
self.stop_experiment()
109+
110+
def start_experiment(self):
111+
"""
112+
Start the experiment
113+
"""
114+
self._process_experiment.start()
115+
self._process_pool.start()
116+
if not self.experiment_finished:
117+
self._exp_timer.start()
118+
119+
def stop_experiment(self):
120+
"""
121+
Stop the experiment and reset the timer
122+
"""
123+
self.experiment_finished = True
124+
self._process_experiment.end()
125+
self._process_pool.end()
126+
print('Experiment completed!')
127+
self._exp_timer.reset()
128+
129+
def get_trial(self):
130+
"""
131+
Check which trial is going on right now
132+
"""
133+
return self._event
134+
135+
def get_info(self):
136+
""" returns optional info"""
137+
info = self._behaviortrigger.get_last_prob()
138+
return info
139+
140+
141+
26142
""" experimental classification experiment using Simba trained classifiers in a pool"""
27143
class SimbaBehaviorPoolExperiment:
28144
"""
@@ -39,7 +155,8 @@ def __init__(self):
39155
#pass classifier to trigger, so that check_skeleton is the only function that passes skeleton
40156
#initiate in experiment, so that process can be started with start_experiment
41157
self._behaviortrigger = SimbaThresholdBehaviorPoolTrigger(prob_threshold= THRESHOLD,
42-
class_process_pool = self._process_pool)
158+
class_process_pool = self._process_pool,
159+
debug=False)
43160
self._event = None
44161
#is not fully utilized in this experiment but is usefull to keep for further adaptation
45162
self._current_trial = None

experiments/custom/stimulus_process.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,11 @@ def example_protocol_run(condition_q: mp.Queue):
7070
current_trial = None
7171
#dmod_device = DigitalModDevice('Dev1/PFI0')
7272
while True:
73+
# if no protocol is selected, running default picture (background)
7374
if condition_q.full():
7475
current_trial = condition_q.get()
7576
if current_trial is not None:
77+
print('IM HEEEEERE!')
7678
show_visual_stim_img(type=current_trial, name='DlStream')
7779
#dmod_device.toggle()
7880
else:

experiments/custom/triggers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
Licensed under GNU General Public License v3.0
77
"""
88

9-
9+
from utils.poser import transform_2pose
1010
from utils.analysis import angle_between_vectors, calculate_distance, EllipseROI, RectangleROI
1111
from utils.configloader import RESOLUTION, TIME_WINDOW
1212
from collections import deque
@@ -589,7 +589,6 @@ def __init__(self,prob_threshold: float, class_process_pool, debug: bool = False
589589

590590
def fill_time_window(self,skeleton: dict):
591591
"""Transforms skeleton input into flat numpy array of coordinates to pass to feature extraction"""
592-
from utils.poser import transform_2pose
593592
flat_values = transform_2pose(skeleton).flatten()
594593
# this appends the new row to the deque time_window, which will drop the "oldest" entry due to a maximum
595594
# length of time_window_len

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
PySide2==5.14.1
2-
numba
2+
numba==0.51.1
33
gpiozero==1.5.1
44
pigpio==1.78
55
pyserial==3.5
@@ -9,5 +9,7 @@ opencv-python==3.4.5.20
99
opencv-contrib-python==4.4.0.46
1010
numpy>=1.14.5
1111
pandas==1.1.4
12+
scikit-learn== 0.24.1
1213
scikit-image==0.17.2
1314
scipy==1.4.1
15+
pure-predict==0.0.4

utils/poser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,9 @@ def calculate_skeletons(peaks: dict, animals_number: int) -> list:
479479
adaptive to chosen model origin
480480
"""
481481
if MODEL_ORIGIN == 'DLC':
482-
animal_skeletons = calculate_dlstream_skeletons(peaks, animals_number)
482+
animal_skeletons = calculate_skeletons_dlc_live(peaks)
483+
#TODO: alter back to original
484+
#animal_skeletons = calculate_dlstream_skeletons(peaks, animals_number)
483485
if animals_number != 1 and SPLIT_MA:
484486
animal_skeletons = split_flat_skeleton(animal_skeletons)
485487
else:

0 commit comments

Comments
 (0)