Skip to content

Commit e212bbd

Browse files
committed
initial commit for simba implementation
1 parent 629e201 commit e212bbd

File tree

10 files changed

+790
-13
lines changed

10 files changed

+790
-13
lines changed

DeepLabStream.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from utils.configloader import RESOLUTION, FRAMERATE, OUT_DIR, MODEL, MULTI_CAM, STACK_FRAMES, \
2121
ANIMALS_NUMBER, STREAMS, STREAMING_SOURCE
22-
from utils.poser import load_deeplabcut, get_pose, find_local_peaks_new, calculate_skeletons
22+
from utils.poser import load_deeplabcut, get_pose, find_local_peaks_new, calculate_skeletons, transform_2skeleton
2323
from utils.plotter import plot_bodyparts, plot_metadata_frame
2424

2525

@@ -277,7 +277,9 @@ def get_pose_mp(input_q, output_q):
277277
if input_q.full():
278278
index, frame = input_q.get()
279279
scmap, locref, pose = get_pose(frame, config, sess, inputs, outputs)
280-
peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config)
280+
#TODO: REmove alterations to original
281+
#peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config)
282+
peaks = pose
281283
output_q.put((index, peaks))
282284

283285
@staticmethod
@@ -366,7 +368,9 @@ def get_analysed_frames(self) -> tuple:
366368

367369
# Getting the analysed data
368370
analysed_index, peaks = self._multiprocessing[camera]['output'].get()
369-
skeletons = calculate_skeletons(peaks, ANIMALS_NUMBER)
371+
#TODO: REMOVE IF USELESS
372+
skeletons = [transform_2skeleton(peaks)]
373+
#skeletons = calculate_skeletons(peaks, ANIMALS_NUMBER)
370374
print('', end='\r', flush=True) # this is the line you should not remove
371375
analysed_frame, depth_map, input_time = self.get_stored_frames(camera)
372376
analysis_time = time.time() - input_time

experiments/custom/classifier.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
"""
2+
DeepLabStream
3+
© J.Schweihoff, M. Loshakov
4+
University Bonn Medical Faculty, Germany
5+
https://github.com/SchwarzNeuroconLab/DeepLabStream
6+
Licensed under GNU General Public License v3.0
7+
"""
8+
9+
import multiprocessing as mp
10+
import time
11+
12+
from utils.configloader import PATH_TO_CLASSIFIER
13+
14+
15+
class Classifier:
16+
"""Empty base class for classification trigger. Loads pretrained classifier, extracts features from skeleton sequence
17+
and passes it to the classifier. Returns found motif and result if used as trigger."""
18+
19+
def __init__(self,win_len: int = 1):
20+
self._classifier = self.load_classifier(PATH_TO_CLASSIFIER)
21+
self._win_len = win_len
22+
self.last_result = None
23+
24+
@staticmethod
25+
def load_classifier(path_to_sav):
26+
"""Load saved classifier"""
27+
import pickle
28+
file = open(path_to_sav,'rb')
29+
classifier = pickle.load(file)
30+
file.close()
31+
return classifier
32+
33+
def classify(self,features):
34+
"""predicts motif from features"""
35+
prediction = self._classifier.predict(features)
36+
self.last_result = prediction
37+
return prediction
38+
39+
def get_last_result(self,skeleton_window: list):
40+
"""Returns predicted last prediction"""
41+
return self.last_result
42+
43+
def get_win_len(self):
44+
return self._win_len
45+
46+
47+
class SiMBAClassifier():
48+
"""SiMBA base class for simple behavior classification trigger. Loads pretrained classifier, gets passed features
49+
from FeatureExtractor. Returns probability of prediction that can be incorporated into triggers."""
50+
51+
def __init__(self):
52+
self._classifier = self.load_classifier(PATH_TO_CLASSIFIER)
53+
self.last_result = 0.0
54+
55+
@staticmethod
56+
def load_classifier(path_to_sav):
57+
"""Load saved classifier"""
58+
import pickle
59+
file = open(path_to_sav,'rb')
60+
classifier = pickle.load(file)
61+
file.close()
62+
return classifier
63+
64+
def classify(self,features):
65+
"""predicts motif probability from features"""
66+
prediction = self._classifier.predict_proba(features)
67+
probability = prediction.item(1)
68+
self.last_result = probability
69+
return probability
70+
71+
def get_last_result(self,skeleton_window: list):
72+
"""Returns predicted last prediction"""
73+
return self.last_result
74+
75+
76+
def example_classifier_run(input_classification_q: mp.Queue,output_classification_q: mp.Queue):
77+
classifier = Classifier() # initialize classifier
78+
while True:
79+
features = None
80+
if input_classification_q.full():
81+
features = input_classification_q.get()
82+
if features is not None:
83+
last_prob = classifier.classify(features)
84+
output_classification_q.put(last_prob)
85+
else:
86+
pass
87+
88+
89+
def simba_classifier_run(input_q: mp.Queue,output_q: mp.Queue):
90+
classifier = SiMBAClassifier() # initialize classifier
91+
while True:
92+
features = None
93+
if input_q.full():
94+
features = input_q.get()
95+
if features is not None:
96+
start_time = time.time()
97+
last_prob = classifier.classify(features)
98+
output_q.put((last_prob))
99+
end_time = time.time()
100+
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
101+
else:
102+
pass
103+
104+
105+
class ClassifierProcess:
106+
"""
107+
Class to help work with protocol function in multiprocessing
108+
Modified from stimulus_process.py
109+
"""
110+
111+
def __init__(self):
112+
"""
113+
Setting up the three queues and the process itself
114+
"""
115+
self.input_queue = mp.Queue(1)
116+
self.output_queue = mp.Queue(1)
117+
self._classification_process = None
118+
self._running = False
119+
self._classification_process = mp.Process(target=example_classifier_run,args=(self.input_queue,
120+
self.output_queue))
121+
122+
def start(self):
123+
"""
124+
Starting the process
125+
"""
126+
self._classification_process.start()
127+
128+
def end(self):
129+
"""
130+
Ending the process
131+
"""
132+
self.input_queue.close()
133+
self.output_queue.close()
134+
self._classification_process.terminate()
135+
136+
def get_status(self):
137+
"""
138+
Getting current status of the running protocol
139+
"""
140+
return self._running
141+
142+
def pass_features(self,features):
143+
"""
144+
Passing the features to the process
145+
"""
146+
if self.input_queue.empty():
147+
self.input_queue.put(features)
148+
self._running = True
149+
150+
def get_result(self):
151+
"""
152+
Getting result from the process
153+
"""
154+
if self.output_queue.full():
155+
self._running = False
156+
return self.output_queue.get()
157+
158+
159+
class SimbaClassifier_Process(ClassifierProcess):
160+
161+
def __init__(self):
162+
super().__init__()
163+
self.input_queue = mp.Queue(1)
164+
self.output_queue = mp.Queue(1)
165+
self._classification_process = mp.Process(target=simba_classifier_run,args=(self.input_queue,
166+
self.output_queue))
167+
168+
169+
"""Processing pool for classification"""
170+
171+
172+
def example_classifier_pool_run(input_q: mp.Queue,output_q: mp.Queue):
173+
classifier = Classifier() # initialize classifier
174+
while True:
175+
features = None
176+
feature_id = 0
177+
if input_q.full():
178+
features,feature_id = input_q.get()
179+
if features is not None:
180+
start_time = time.time()
181+
last_prob = classifier.classify(features)
182+
output_q.put((last_prob,feature_id))
183+
end_time = time.time()
184+
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
185+
else:
186+
pass
187+
188+
189+
def simba_classifier_pool_run(input_q: mp.Queue,output_q: mp.Queue):
190+
classifier = SiMBAClassifier() # initialize classifier
191+
while True:
192+
features = None
193+
feature_id = 0
194+
if input_q.full():
195+
features,feature_id = input_q.get()
196+
if features is not None:
197+
start_time = time.time()
198+
last_prob = classifier.classify(features)
199+
output_q.put((last_prob,feature_id))
200+
end_time = time.time()
201+
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
202+
else:
203+
pass
204+
205+
206+
class ClassifierProcessPool:
207+
"""
208+
Class to help work with protocol function in multiprocessing
209+
spawns a pool of processes that tackle the frame-by-frame issue.
210+
"""
211+
212+
def __init__(self,pool_size: int):
213+
"""
214+
Setting up the three queues and the process itself
215+
"""
216+
self._running = False
217+
self._pool_size = pool_size
218+
self._process_pool = self.initiate_pool(example_classifier_pool_run,pool_size)
219+
220+
@staticmethod
221+
def initiate_pool(process_func,pool_size: int):
222+
"""creates list of process dictionaries that are used to classify features
223+
:param process_func: function that will be passed to mp.Process object, should contain classification
224+
:param pool_size: number of processes created by function, should be enough to enable constistent feature classification without skipped frames
225+
:"""
226+
process_pool = []
227+
228+
for i in range(pool_size):
229+
input_queue = mp.Queue(1)
230+
output_queue = mp.Queue(1)
231+
classification_process = mp.Process(target=process_func,args=(input_queue,output_queue))
232+
process_pool.append(
233+
dict(process=classification_process,input=input_queue,output=output_queue,running=False))
234+
235+
return process_pool
236+
237+
def start(self):
238+
"""
239+
Starting all processes
240+
"""
241+
for process in self._process_pool:
242+
process['process'].start()
243+
244+
def end(self):
245+
"""
246+
Ending all processes
247+
"""
248+
for process in self._process_pool:
249+
process['input'].close()
250+
process['output'].close()
251+
process['process'].terminate()
252+
253+
def get_status(self):
254+
"""
255+
Getting current status of the running protocol
256+
"""
257+
return self._running
258+
259+
def pass_features(self,features: tuple,debug: bool = False):
260+
"""
261+
Passing the features to the process pool
262+
First checks if processes got their first input yet
263+
Checks which process is already done and then gives new input
264+
breaks for loop if an idle process was found
265+
:param features tuple: feature list from feature extractor and feature_id used to identify processing sequence
266+
:param debug bool: reporting of process + feature id to identify discrepancies in processing sequence
267+
"""
268+
for process in self._process_pool:
269+
if not process['running']:
270+
if process['input'].empty():
271+
process['input'].put(features)
272+
process['running'] = True
273+
if debug:
274+
print('First Input',process['process'].name,'ID: ' + str(features[1]))
275+
break
276+
277+
elif process['input'].empty() and process['output'].full():
278+
process['input'].put(features)
279+
if debug:
280+
print('Input',process['process'].name,'ID: ' + str(features[1]))
281+
break
282+
283+
def get_result(self,debug: bool = False):
284+
"""
285+
Getting result from the process pool
286+
takes result from first finished process in pool
287+
:param debug bool: reporting of process + feature id to identify discrepancies in processing sequence
288+
289+
"""
290+
result = (None,0)
291+
for process in self._process_pool:
292+
if process['output'].full():
293+
result = process['output'].get()
294+
if debug:
295+
print('Output',process['process'].name,'ID: ' + str(result[1]))
296+
break
297+
return result
298+
299+
300+
class SimbaProcessPool(ClassifierProcessPool):
301+
"""
302+
Class to help work with protocol function in multiprocessing
303+
spawns a pool of processes that tackle the frame-by-frame issue.
304+
"""
305+
306+
def __init__(self,pool_size: int):
307+
"""
308+
Setting up the three queues and the process itself
309+
"""
310+
super().__init__(pool_size)
311+
self._process_pool = super().initiate_pool(simba_classifier_pool_run,pool_size)

0 commit comments

Comments
 (0)