1818from scipy .ndimage .filters import maximum_filter
1919
2020from utils .analysis import calculate_distance
21- from utils .configloader import MODEL_ORIGIN , MODEL_NAME , MODEL_PATH , ALL_BODYPARTS , FLATTEN_MA , HANDLE_MISSING
21+ from utils .configloader import MODEL_ORIGIN , MODEL_NAME , MODEL_PATH , ALL_BODYPARTS , FLATTEN_MA , SPLIT_MA ,\
22+ HANDLE_MISSING , ANIMALS_NUMBER
2223
2324# suppressing unnecessary warnings
2425import warnings
6162 from utils .configloader import MODEL_PATH
6263
6364
65+ class SkeletonError (Exception ):
66+ """Custom expection to be raised when issues with the skeleton is not received"""
67+
68+
6469def load_deeplabcut ():
6570 """
6671 Loads TensorFlow with predefined in config DeepLabCut model
@@ -292,11 +297,13 @@ def load_dpk():
292297def load_dlc_live ():
293298 return DLCLive (MODEL_PATH )
294299
300+
295301def load_sleap ():
296302 model = load_model (MODEL_PATH )
297303 model .inference_model
298304 return model .inference_model
299305
306+
300307def flatten_maDLC_skeletons (skeletons ):
301308 """Flattens maDLC multi skeletons into one skeleton to simulate dlc output
302309 where animals are not identical e.g. for animals with different fur colors (SIMBA)"""
@@ -308,6 +315,22 @@ def flatten_maDLC_skeletons(skeletons):
308315 return [flat_skeletons ]
309316
310317
318+ def split_flat_skeleton (skeletons ):
319+ """Splits flat multi skeletons (e.g. from flatten_maDLCskeleton) into seperate skeleton to simulate output
320+ where animals are identity tracked (e.g. SLEAP)"""
321+ flat_skeletons = skeletons [0 ]
322+ split_skeletons = []
323+ bp_per_animal , remainder = divmod (len (flat_skeletons ), ANIMALS_NUMBER )
324+ if remainder > 0 :
325+ raise SkeletonError (f'The number of body parts ({ len (flat_skeletons )} ) cannot be split equally into { ANIMALS_NUMBER } animals.' )
326+ else :
327+ for animal in range (ANIMALS_NUMBER ):
328+ single_skeleton = list (flat_skeletons .keys ())[bp_per_animal * animal :bp_per_animal * animal + bp_per_animal ]
329+ split_skeletons .append ({x : flat_skeletons [x ] for x in flat_skeletons if x in single_skeleton })
330+
331+ return split_skeletons
332+
333+
311334def transform_2skeleton (pose ):
312335 """
313336 Transforms pose estimation into DLStream style "skeleton" posture.
@@ -346,6 +369,9 @@ def handle_missing_bp(animal_skeletons: list):
346369
347370 :param: animal_skeletons: list of skeletons returned by calculate skeleton
348371 :return animal_skeleton with handled missing values"""
372+
373+ #TODO: Handle missing instance in multiple animal approach to keep identity safe!
374+
349375 for skeleton in animal_skeletons :
350376 for bodypart , coordinates in skeleton .items ():
351377 np_coords = np .array ((coordinates ))
@@ -391,21 +417,28 @@ def calculate_skeletons(peaks: dict, animals_number: int) -> list:
391417 """
392418 if MODEL_ORIGIN == 'DLC' :
393419 animal_skeletons = calculate_dlstream_skeletons (peaks , animals_number )
420+ if SPLIT_MA :
421+ animal_skeletons = split_flat_skeleton (animal_skeletons )
394422
395423 elif MODEL_ORIGIN == 'MADLC' :
396424 animal_skeletons = calculate_ma_skeletons (peaks , animals_number )
397425 if FLATTEN_MA :
398426 animal_skeletons = flatten_maDLC_skeletons (animal_skeletons )
399427
400428 elif MODEL_ORIGIN == 'DLC-LIVE' or MODEL_ORIGIN == 'DEEPPOSEKIT' :
401- if animals_number != 1 :
402- raise ValueError ('Multiple animals are currently not supported by DLC-LIVE.'
403- ' If you are using differently colored animals, please refer to the bodyparts directly.' )
404429 animal_skeletons = calculate_skeletons_dlc_live (peaks )
430+ if animals_number != 1 and not SPLIT_MA :
431+ raise SkeletonError ('Multiple animals are currently not supported by DLC-LIVE.'
432+ ' If you are using differently colored animals, please refer to the bodyparts directly (as a flattened skeleton) or use SPLIT_MA in the advanced settings.' )
433+ if SPLIT_MA :
434+ animal_skeletons = split_flat_skeleton (animal_skeletons )
435+
405436 elif MODEL_ORIGIN == 'SLEAP' :
406437 animal_skeletons = calculate_sleap_skeletons (peaks , animals_number )
407438 if FLATTEN_MA :
408439 animal_skeletons = flatten_maDLC_skeletons (animal_skeletons )
440+ elif SPLIT_MA :
441+ animal_skeletons = split_flat_skeleton (animal_skeletons )
409442
410443 animal_skeletons = handle_missing_bp (animal_skeletons )
411444
0 commit comments