@@ -135,6 +135,27 @@ def find_local_peaks_new(scoremap: np.ndarray, local_reference: np.ndarray, anim
135135 all_peaks [joint ].append ([tuple (coordinates .astype (int )), joint ])
136136 return all_peaks
137137
138+ def filter_pose_by_likelihood (pose , threshold : float = 0.1 ):
139+ """
140+ !!!FOR NOW THIS FUNCTION SETS THEM TO (0,0) due to missing float dtype in coordinate handling. Will be updated later on.
141+ filters pose estimation by likelihood threshold. Estimates below threshold are set to NaN and handled downstream
142+ of this function in calculate skeletons.
143+ :param pose: pose estimation (e.g., from DLC)
144+ :param threshold: likelihood threshold to filter by
145+ :return filtered_pose: pose estimation filtered by likelihood (may contain NaN)
146+ """
147+ #TODO: Update to NaN
148+
149+ filtered_pose = pose .copy ()
150+
151+ for num , bp in enumerate (filtered_pose ):
152+ if bp [2 ] < threshold :
153+ #set new threshold to "2" (number outside of normal range to signify filter
154+ #TODO: This should be np.NaN not 0,0
155+ filtered_pose [num ] = np .array ([0 , 0 , 2 ])
156+
157+ return filtered_pose
158+
138159
139160def calculate_dlstream_skeletons (peaks : dict , animals_number : int ) -> list :
140161 """
@@ -281,7 +302,7 @@ def transform_2skeleton(pose):
281302 for bp in pose :
282303 skeleton [ALL_BODYPARTS [counter ]] = tuple (np .array (bp [0 :2 ], dtype = int ))
283304 counter += 1
284- except KeyError :
305+ except ( KeyError , IndexError ) as e :
285306 skeleton = dict ()
286307 counter = 0
287308 for bp in pose :
@@ -296,6 +317,36 @@ def transform_2pose(skeleton):
296317 return pose
297318
298319
320+ def arrange_flatskeleton (skeleton , n_animals , n_bp_animal , switch_dict ):
321+ """changes sequence of bodypart sets (skeletons) in multi animal tracking with flat skeleton output (multiple animals in single skeleton) by switching position of pairs.
322+ E.g. in pose estimation with different fur colors. Note: When switching muliple animals the new position of the previous switches will be used.
323+ :param skeleton: flat skeleton of pose estimation in style {bp1: (x,y), bp2: (x2,y2) ...}
324+ :param n_animals: number of animals in total represented by skeleton
325+ :param n_bp_animal: number of bodyparts per animal in skeleton
326+ :param switch_dict: dictionary containing position of bodypart set (animal) in flat skeleton as key and bp set to exchange with as value.
327+ e.g. switch_dict = dict(1 = 2, 3 = 4)
328+ :return: skeleton with new order
329+ """
330+ flat_pose = transform_2pose (skeleton )
331+ ra_dict = {}
332+ #slicing the animals out
333+ for num_animal in range (n_animals ):
334+ ra_dict [num_animal ] = flat_pose [num_animal * n_bp_animal :num_animal * n_bp_animal + n_bp_animal ]
335+ #switching positions
336+ for orig_pos , switch_pos in switch_dict .items ():
337+ #extract old
338+ orig = ra_dict [orig_pos ]
339+ switch = ra_dict [switch_pos ]
340+ #set to new position
341+ ra_dict [orig_pos ] = switch
342+ ra_dict [switch_pos ] = orig
343+ #extracting pose
344+ arranged_pose = np .array ([* ra_dict .values ()]).reshape (flat_pose .shape )
345+ #transforming it to skeleton
346+ flat_skeleton = transform_2skeleton (arranged_pose )
347+ return flat_skeleton
348+
349+
299350def calculate_skeletons_dlc_live (pose ) -> list :
300351 """
301352 Creating skeletons from given pose
@@ -314,7 +365,12 @@ def calculate_skeletons(peaks: dict, animals_number: int) -> list:
314365 adaptive to chosen model origin
315366 """
316367 if MODEL_ORIGIN == 'DLC' :
317- animal_skeletons = calculate_dlstream_skeletons (peaks , animals_number )
368+ #TODO: remove alterations from SIMBA tests
369+ #animal_skeletons = calculate_dlstream_skeletons(peaks, animals_number)
370+ #peaks = filter_pose_by_likelihood(peaks, threshold= 0.6)
371+ flat_skeleton = transform_2skeleton (peaks )
372+ animal_skeletons = [arrange_flatskeleton (flat_skeleton ,2 ,7 ,{0 :1 })]
373+ #animal_skeletons = calculate_skeletons_dlc_live(peaks)
318374
319375 elif MODEL_ORIGIN == 'MADLC' :
320376 animal_skeletons = calculate_ma_skeletons (peaks , animals_number )
0 commit comments