1010import time
1111import pickle
1212
13- from utils .configloader import PATH_TO_CLASSIFIER
13+ from utils .configloader import PATH_TO_CLASSIFIER , TIME_WINDOW
14+ from experiments .custom .featureextraction import SimbaFeatureExtractor , SimbaFeatureExtractorStandard14bp , BsoidFeatureExtractor
1415
1516
1617class Classifier :
@@ -82,6 +83,7 @@ class SiMBAClassifier:
8283 def __init__ (self ):
8384 self ._classifier = self .load_classifier (PATH_TO_CLASSIFIER )
8485 self .last_result = 0.0
86+ self ._pure = self ._check_pure ()
8587
8688 @staticmethod
8789 def load_classifier (path_to_sav ):
@@ -92,10 +94,22 @@ def load_classifier(path_to_sav):
9294 file .close ()
9395 return classifier
9496
97+ def _check_pure (self ):
98+ if 'pure' in str (self ._classifier ):
99+ return True
100+ else :
101+ return False
102+
95103 def classify (self , features ):
96104 """predicts motif probability from features"""
97- prediction = self ._classifier .predict_proba (features )
98- probability = prediction .item (1 )
105+ if self ._pure :
106+ # pure-predict needs a list instead of a numpy array
107+ prediction = self ._classifier .predict_proba (list (features ))
108+ # pure-predict returns a nested list
109+ probability = prediction [0 ][1 ]
110+ else :
111+ prediction = self ._classifier .predict_proba (features )
112+ probability = prediction .item (1 )
99113 self .last_result = probability
100114 return probability
101115
@@ -306,7 +320,6 @@ def simba_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
306320 else :
307321 pass
308322
309-
310323def pure_simba_classifier_pool_run (input_q : mp .Queue , output_q : mp .Queue ):
311324 classifier = PureSiMBAClassifier () # initialize classifier
312325 while True :
@@ -495,3 +508,245 @@ def __init__(self, pool_size: int):
495508 """
496509 super ().__init__ (pool_size )
497510 self ._process_pool = super ().initiate_pool (bsoid_classifier_pool_run , pool_size )
511+
512+
513+ """Feature Extraction and Classification in the same pool"""
514+
515+
516+
517+ def example_feat_classifier_pool_run (input_q : mp .Queue , output_q : mp .Queue ):
518+ feature_extractor = SimbaFeatureExtractor (TIME_WINDOW )
519+ classifier = Classifier () # initialize classifier
520+ while True :
521+ skel_time_window = None
522+ feature_id = 0
523+ if input_q .full ():
524+ skel_time_window , feature_id = input_q .get ()
525+ if skel_time_window is not None :
526+ start_time = time .time ()
527+ features = feature_extractor .extract_features (skel_time_window )
528+ last_prob = classifier .classify (features )
529+ output_q .put ((last_prob , feature_id ))
530+ end_time = time .time ()
531+ # print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
532+ else :
533+ pass
534+
535+
536+ def simba_feat_classifier_pool_run (input_q : mp .Queue , output_q : mp .Queue ):
537+ feature_extractor = SimbaFeatureExtractorStandard14bp (TIME_WINDOW )
538+ classifier = SiMBAClassifier () # initialize classifier
539+ while True :
540+ skel_time_window = None
541+ feature_id = 0
542+ if input_q .full ():
543+ skel_time_window , feature_id = input_q .get ()
544+ if skel_time_window is not None :
545+ start_time = time .time ()
546+ features = feature_extractor .extract_features (skel_time_window )
547+ last_prob = classifier .classify (features )
548+ output_q .put ((last_prob , feature_id ))
549+ end_time = time .time ()
550+ # print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
551+ else :
552+ pass
553+
554+ def pure_simba_feat_classifier_pool_run (input_q : mp .Queue , output_q : mp .Queue ):
555+ feature_extractor = SimbaFeatureExtractorStandard14bp (TIME_WINDOW )
556+ #feature_extractor = SimbaFeatureExtractor(TIME_WINDOW)
557+ classifier = SiMBAClassifier () # initialize classifier
558+ while True :
559+ skel_time_window = None
560+ feature_id = 0
561+ if input_q .full ():
562+ skel_time_window , feature_id = input_q .get ()
563+ if skel_time_window is not None :
564+ start_time = time .time ()
565+ features = feature_extractor .extract_features (skel_time_window )
566+ end_time = time .time ()
567+ print (
568+ "Feature extraction time: {:.2f} msec" .format (
569+ (end_time - start_time ) * 1000
570+ )
571+ )
572+
573+ last_prob = classifier .classify (features )
574+ output_q .put ((last_prob , feature_id ))
575+ end_time2 = time .time ()
576+ print (
577+ "Classification time: {:.2f} msec" .format (
578+ (end_time2 - end_time ) * 1000
579+ )
580+ )
581+ else :
582+ pass
583+
584+
585+ def bsoid_feat_classifier_pool_run (input_q : mp .Queue , output_q : mp .Queue ):
586+ feature_extractor = BsoidFeatureExtractor (TIME_WINDOW )
587+ classifier = BsoidClassifier () # initialize classifier
588+ while True :
589+ skel_time_window = None
590+ feature_id = 0
591+ if input_q .full ():
592+ skel_time_window , feature_id = input_q .get ()
593+ if skel_time_window is not None :
594+ start_time = time .time ()
595+ features = feature_extractor .extract_features (skel_time_window )
596+ last_prob = classifier .classify (features )
597+ output_q .put ((last_prob , feature_id ))
598+ end_time = time .time ()
599+ # print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
600+ # print("Feature ID: "+ feature_id)
601+ else :
602+ pass
603+
604+
605+ class FeatureExtractionClassifierProcessPool :
606+ """
607+ Class to help work with protocol function in multiprocessing
608+ spawns a pool of processes that tackle the frame-by-frame issue.
609+ """
610+
611+ def __init__ (self , pool_size : int ):
612+ """
613+ Setting up the three queues and the process itself
614+ """
615+ self ._running = False
616+ self ._pool_size = pool_size
617+ self ._process_pool = self .initiate_pool (example_feat_classifier_pool_run , pool_size )
618+
619+ @staticmethod
620+ def initiate_pool (process_func , pool_size : int ):
621+ """creates list of process dictionaries that are used to classify features
622+ :param process_func: function that will be passed to mp.Process object, should contain classification
623+ :param pool_size: number of processes created by function, should be enough to enable constistent feature classification without skipped frames
624+ :"""
625+ process_pool = []
626+
627+ for i in range (pool_size ):
628+ input_queue = mp .Queue (1 )
629+ output_queue = mp .Queue (1 )
630+ classification_process = mp .Process (
631+ target = process_func , args = (input_queue , output_queue )
632+ )
633+ process_pool .append (
634+ dict (
635+ process = classification_process ,
636+ input = input_queue ,
637+ output = output_queue ,
638+ running = False ,
639+ )
640+ )
641+
642+ return process_pool
643+
644+ def start (self ):
645+ """
646+ Starting all processes
647+ """
648+ for process in self ._process_pool :
649+ process ["process" ].start ()
650+
651+ def end (self ):
652+ """
653+ Ending all processes
654+ """
655+ for process in self ._process_pool :
656+ process ["input" ].close ()
657+ process ["output" ].close ()
658+ process ["process" ].terminate ()
659+
660+ def get_status (self ):
661+ """
662+ Getting current status of the running protocol
663+ """
664+ return self ._running
665+
666+ def pass_time_window (self , skel_time_window : tuple , debug : bool = False ):
667+ """
668+ Passing the features to the process pool
669+ First checks if processes got their first input yet
670+ Checks which process is already done and then gives new input
671+ breaks for loop if an idle process was found
672+ :param features tuple: feature list from feature extractor and feature_id used to identify processing sequence
673+ :param debug bool: reporting of process + feature id to identify discrepancies in processing sequence
674+ """
675+ for process in self ._process_pool :
676+ if not process ["running" ]:
677+ if process ["input" ].empty ():
678+ process ["input" ].put (skel_time_window )
679+ process ["running" ] = True
680+ if debug :
681+ print (
682+ "First Input" ,
683+ process ["process" ].name ,
684+ "ID: " + str (skel_time_window [1 ]),
685+ )
686+ break
687+
688+ elif process ["input" ].empty () and process ["output" ].full ():
689+ process ["input" ].put (skel_time_window )
690+ if debug :
691+ print ("Input" , process ["process" ].name , "ID: " + str (skel_time_window [1 ]))
692+ break
693+
694+ def get_result (self , debug : bool = False ):
695+ """
696+ Getting result from the process pool
697+ takes result from first finished process in pool
698+ :param debug bool: reporting of process + feature id to identify discrepancies in processing sequence
699+
700+ """
701+ result = (None , 0 )
702+ for process in self ._process_pool :
703+ if process ["output" ].full ():
704+ result = process ["output" ].get ()
705+ if debug :
706+ print ("Output" , process ["process" ].name , "ID: " + str (result [1 ]))
707+ break
708+ return result
709+
710+
711+ class PureFeatSimbaProcessPool (FeatureExtractionClassifierProcessPool ):
712+ """
713+ Class to help work with protocol function in multiprocessing
714+ spawns a pool of processes that tackle the frame-by-frame issue.
715+ """
716+
717+ def __init__ (self , pool_size : int ):
718+ """
719+ Setting up the three queues and the process itself
720+ """
721+ super ().__init__ (pool_size )
722+ self ._process_pool = super ().initiate_pool (
723+ pure_simba_feat_classifier_pool_run , pool_size
724+ )
725+
726+
727+ class FeatSimbaProcessPool (FeatureExtractionClassifierProcessPool ):
728+ """
729+ Class to help work with protocol function in multiprocessing
730+ spawns a pool of processes that tackle the frame-by-frame issue.
731+ """
732+
733+ def __init__ (self , pool_size : int ):
734+ """
735+ Setting up the three queues and the process itself
736+ """
737+ super ().__init__ (pool_size )
738+ self ._process_pool = super ().initiate_pool (simba_feat_classifier_pool_run , pool_size )
739+
740+
741+ class FeatBsoidProcessPool (ClassifierProcessPool ):
742+ """
743+ Class to help work with protocol function in multiprocessing
744+ spawns a pool of processes that tackle the frame-by-frame issue.
745+ """
746+
747+ def __init__ (self , pool_size : int ):
748+ """
749+ Setting up the three queues and the process itself
750+ """
751+ super ().__init__ (pool_size )
752+ self ._process_pool = super ().initiate_pool (bsoid_feat_classifier_pool_run , pool_size )
0 commit comments