Skip to content

Commit 0c8559f

Browse files
committed
updating SIMBA featureextraction.py and classifier.py
merging featureextraction.py and classifier.py into same pool
1 parent 2708ef2 commit 0c8559f

File tree

4 files changed

+1287
-9
lines changed

4 files changed

+1287
-9
lines changed

experiments/custom/classifier.py

Lines changed: 259 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import time
1111
import pickle
1212

13-
from utils.configloader import PATH_TO_CLASSIFIER
13+
from utils.configloader import PATH_TO_CLASSIFIER, TIME_WINDOW
14+
from experiments.custom.featureextraction import SimbaFeatureExtractor, SimbaFeatureExtractorStandard14bp, BsoidFeatureExtractor
1415

1516

1617
class Classifier:
@@ -82,6 +83,7 @@ class SiMBAClassifier:
8283
def __init__(self):
8384
self._classifier = self.load_classifier(PATH_TO_CLASSIFIER)
8485
self.last_result = 0.0
86+
self._pure = self._check_pure()
8587

8688
@staticmethod
8789
def load_classifier(path_to_sav):
@@ -92,10 +94,22 @@ def load_classifier(path_to_sav):
9294
file.close()
9395
return classifier
9496

97+
def _check_pure(self):
98+
if 'pure' in str(self._classifier):
99+
return True
100+
else:
101+
return False
102+
95103
def classify(self, features):
96104
"""predicts motif probability from features"""
97-
prediction = self._classifier.predict_proba(features)
98-
probability = prediction.item(1)
105+
if self._pure:
106+
# pure-predict needs a list instead of a numpy array
107+
prediction = self._classifier.predict_proba(list(features))
108+
# pure-predict returns a nested list
109+
probability = prediction[0][1]
110+
else:
111+
prediction = self._classifier.predict_proba(features)
112+
probability = prediction.item(1)
99113
self.last_result = probability
100114
return probability
101115

@@ -306,7 +320,6 @@ def simba_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
306320
else:
307321
pass
308322

309-
310323
def pure_simba_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
311324
classifier = PureSiMBAClassifier() # initialize classifier
312325
while True:
@@ -495,3 +508,245 @@ def __init__(self, pool_size: int):
495508
"""
496509
super().__init__(pool_size)
497510
self._process_pool = super().initiate_pool(bsoid_classifier_pool_run, pool_size)
511+
512+
513+
"""Feature Extraction and Classification in the same pool"""
514+
515+
516+
517+
def example_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
518+
feature_extractor = SimbaFeatureExtractor(TIME_WINDOW)
519+
classifier = Classifier() # initialize classifier
520+
while True:
521+
skel_time_window = None
522+
feature_id = 0
523+
if input_q.full():
524+
skel_time_window, feature_id = input_q.get()
525+
if skel_time_window is not None:
526+
start_time = time.time()
527+
features = feature_extractor.extract_features(skel_time_window)
528+
last_prob = classifier.classify(features)
529+
output_q.put((last_prob, feature_id))
530+
end_time = time.time()
531+
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
532+
else:
533+
pass
534+
535+
536+
def simba_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
537+
feature_extractor = SimbaFeatureExtractorStandard14bp(TIME_WINDOW)
538+
classifier = SiMBAClassifier() # initialize classifier
539+
while True:
540+
skel_time_window = None
541+
feature_id = 0
542+
if input_q.full():
543+
skel_time_window, feature_id = input_q.get()
544+
if skel_time_window is not None:
545+
start_time = time.time()
546+
features = feature_extractor.extract_features(skel_time_window)
547+
last_prob = classifier.classify(features)
548+
output_q.put((last_prob, feature_id))
549+
end_time = time.time()
550+
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
551+
else:
552+
pass
553+
554+
def pure_simba_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
555+
feature_extractor = SimbaFeatureExtractorStandard14bp(TIME_WINDOW)
556+
#feature_extractor = SimbaFeatureExtractor(TIME_WINDOW)
557+
classifier = SiMBAClassifier() # initialize classifier
558+
while True:
559+
skel_time_window = None
560+
feature_id = 0
561+
if input_q.full():
562+
skel_time_window, feature_id = input_q.get()
563+
if skel_time_window is not None:
564+
start_time = time.time()
565+
features = feature_extractor.extract_features(skel_time_window)
566+
end_time = time.time()
567+
print(
568+
"Feature extraction time: {:.2f} msec".format(
569+
(end_time - start_time) * 1000
570+
)
571+
)
572+
573+
last_prob = classifier.classify(features)
574+
output_q.put((last_prob, feature_id))
575+
end_time2 = time.time()
576+
print(
577+
"Classification time: {:.2f} msec".format(
578+
(end_time2 - end_time) * 1000
579+
)
580+
)
581+
else:
582+
pass
583+
584+
585+
def bsoid_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
586+
feature_extractor = BsoidFeatureExtractor(TIME_WINDOW)
587+
classifier = BsoidClassifier() # initialize classifier
588+
while True:
589+
skel_time_window = None
590+
feature_id = 0
591+
if input_q.full():
592+
skel_time_window, feature_id = input_q.get()
593+
if skel_time_window is not None:
594+
start_time = time.time()
595+
features = feature_extractor.extract_features(skel_time_window)
596+
last_prob = classifier.classify(features)
597+
output_q.put((last_prob, feature_id))
598+
end_time = time.time()
599+
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
600+
# print("Feature ID: "+ feature_id)
601+
else:
602+
pass
603+
604+
605+
class FeatureExtractionClassifierProcessPool:
606+
"""
607+
Class to help work with protocol function in multiprocessing
608+
spawns a pool of processes that tackle the frame-by-frame issue.
609+
"""
610+
611+
def __init__(self, pool_size: int):
612+
"""
613+
Setting up the three queues and the process itself
614+
"""
615+
self._running = False
616+
self._pool_size = pool_size
617+
self._process_pool = self.initiate_pool(example_feat_classifier_pool_run, pool_size)
618+
619+
@staticmethod
620+
def initiate_pool(process_func, pool_size: int):
621+
"""creates list of process dictionaries that are used to classify features
622+
:param process_func: function that will be passed to mp.Process object, should contain classification
623+
:param pool_size: number of processes created by function, should be enough to enable constistent feature classification without skipped frames
624+
:"""
625+
process_pool = []
626+
627+
for i in range(pool_size):
628+
input_queue = mp.Queue(1)
629+
output_queue = mp.Queue(1)
630+
classification_process = mp.Process(
631+
target=process_func, args=(input_queue, output_queue)
632+
)
633+
process_pool.append(
634+
dict(
635+
process=classification_process,
636+
input=input_queue,
637+
output=output_queue,
638+
running=False,
639+
)
640+
)
641+
642+
return process_pool
643+
644+
def start(self):
645+
"""
646+
Starting all processes
647+
"""
648+
for process in self._process_pool:
649+
process["process"].start()
650+
651+
def end(self):
652+
"""
653+
Ending all processes
654+
"""
655+
for process in self._process_pool:
656+
process["input"].close()
657+
process["output"].close()
658+
process["process"].terminate()
659+
660+
def get_status(self):
661+
"""
662+
Getting current status of the running protocol
663+
"""
664+
return self._running
665+
666+
def pass_time_window(self, skel_time_window: tuple, debug: bool = False):
667+
"""
668+
Passing the features to the process pool
669+
First checks if processes got their first input yet
670+
Checks which process is already done and then gives new input
671+
breaks for loop if an idle process was found
672+
:param features tuple: feature list from feature extractor and feature_id used to identify processing sequence
673+
:param debug bool: reporting of process + feature id to identify discrepancies in processing sequence
674+
"""
675+
for process in self._process_pool:
676+
if not process["running"]:
677+
if process["input"].empty():
678+
process["input"].put(skel_time_window)
679+
process["running"] = True
680+
if debug:
681+
print(
682+
"First Input",
683+
process["process"].name,
684+
"ID: " + str(skel_time_window[1]),
685+
)
686+
break
687+
688+
elif process["input"].empty() and process["output"].full():
689+
process["input"].put(skel_time_window)
690+
if debug:
691+
print("Input", process["process"].name, "ID: " + str(skel_time_window[1]))
692+
break
693+
694+
def get_result(self, debug: bool = False):
695+
"""
696+
Getting result from the process pool
697+
takes result from first finished process in pool
698+
:param debug bool: reporting of process + feature id to identify discrepancies in processing sequence
699+
700+
"""
701+
result = (None, 0)
702+
for process in self._process_pool:
703+
if process["output"].full():
704+
result = process["output"].get()
705+
if debug:
706+
print("Output", process["process"].name, "ID: " + str(result[1]))
707+
break
708+
return result
709+
710+
711+
class PureFeatSimbaProcessPool(FeatureExtractionClassifierProcessPool):
712+
"""
713+
Class to help work with protocol function in multiprocessing
714+
spawns a pool of processes that tackle the frame-by-frame issue.
715+
"""
716+
717+
def __init__(self, pool_size: int):
718+
"""
719+
Setting up the three queues and the process itself
720+
"""
721+
super().__init__(pool_size)
722+
self._process_pool = super().initiate_pool(
723+
pure_simba_feat_classifier_pool_run, pool_size
724+
)
725+
726+
727+
class FeatSimbaProcessPool(FeatureExtractionClassifierProcessPool):
728+
"""
729+
Class to help work with protocol function in multiprocessing
730+
spawns a pool of processes that tackle the frame-by-frame issue.
731+
"""
732+
733+
def __init__(self, pool_size: int):
734+
"""
735+
Setting up the three queues and the process itself
736+
"""
737+
super().__init__(pool_size)
738+
self._process_pool = super().initiate_pool(simba_feat_classifier_pool_run, pool_size)
739+
740+
741+
class FeatBsoidProcessPool(ClassifierProcessPool):
742+
"""
743+
Class to help work with protocol function in multiprocessing
744+
spawns a pool of processes that tackle the frame-by-frame issue.
745+
"""
746+
747+
def __init__(self, pool_size: int):
748+
"""
749+
Setting up the three queues and the process itself
750+
"""
751+
super().__init__(pool_size)
752+
self._process_pool = super().initiate_pool(bsoid_feat_classifier_pool_run, pool_size)

0 commit comments

Comments
 (0)