Skip to content

Commit e94a646

Browse files
committed
added draft of filter function for pose estimation + rearranging of animals in flattened multi animal skeleton
1 parent b81b5c5 commit e94a646

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

utils/poser.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,27 @@ def find_local_peaks_new(scoremap: np.ndarray, local_reference: np.ndarray, anim
135135
all_peaks[joint].append([tuple(coordinates.astype(int)), joint])
136136
return all_peaks
137137

138+
def filter_pose_by_likelihood(pose, threshold: float = 0.1):
139+
"""
140+
!!!FOR NOW THIS FUNCTION SETS THEM TO (0,0) due to missing float dtype in coordinate handling. Will be updated later on.
141+
filters pose estimation by likelihood threshold. Estimates below threshold are set to NaN and handled downstream
142+
of this function in calculate skeletons.
143+
:param pose: pose estimation (e.g., from DLC)
144+
:param threshold: likelihood threshold to filter by
145+
:return filtered_pose: pose estimation filtered by likelihood (may contain NaN)
146+
"""
147+
#TODO: Update to NaN
148+
149+
filtered_pose = pose.copy()
150+
151+
for num, bp in enumerate(filtered_pose):
152+
if bp[2] < threshold:
153+
#set new threshold to "2" (number outside of normal range to signify filter
154+
#TODO: This should be np.NaN not 0,0
155+
filtered_pose[num] = np.array([0, 0, 2])
156+
157+
return filtered_pose
158+
138159

139160
def calculate_dlstream_skeletons(peaks: dict, animals_number: int) -> list:
140161
"""
@@ -281,7 +302,7 @@ def transform_2skeleton(pose):
281302
for bp in pose:
282303
skeleton[ALL_BODYPARTS[counter]] = tuple(np.array(bp[0:2], dtype=int))
283304
counter += 1
284-
except KeyError:
305+
except (KeyError, IndexError) as e:
285306
skeleton = dict()
286307
counter = 0
287308
for bp in pose:
@@ -296,6 +317,36 @@ def transform_2pose(skeleton):
296317
return pose
297318

298319

320+
def arrange_flatskeleton(skeleton, n_animals, n_bp_animal, switch_dict):
321+
"""changes sequence of bodypart sets (skeletons) in multi animal tracking with flat skeleton output (multiple animals in single skeleton) by switching position of pairs.
322+
E.g. in pose estimation with different fur colors. Note: When switching muliple animals the new position of the previous switches will be used.
323+
:param skeleton: flat skeleton of pose estimation in style {bp1: (x,y), bp2: (x2,y2) ...}
324+
:param n_animals: number of animals in total represented by skeleton
325+
:param n_bp_animal: number of bodyparts per animal in skeleton
326+
:param switch_dict: dictionary containing position of bodypart set (animal) in flat skeleton as key and bp set to exchange with as value.
327+
e.g. switch_dict = dict(1 = 2, 3 = 4)
328+
:return: skeleton with new order
329+
"""
330+
flat_pose = transform_2pose(skeleton)
331+
ra_dict = {}
332+
#slicing the animals out
333+
for num_animal in range(n_animals):
334+
ra_dict[num_animal] = flat_pose[num_animal*n_bp_animal:num_animal*n_bp_animal+n_bp_animal]
335+
#switching positions
336+
for orig_pos, switch_pos in switch_dict.items():
337+
#extract old
338+
orig = ra_dict[orig_pos]
339+
switch = ra_dict[switch_pos]
340+
#set to new position
341+
ra_dict[orig_pos] = switch
342+
ra_dict[switch_pos] = orig
343+
#extracting pose
344+
arranged_pose = np.array([*ra_dict.values()]).reshape(flat_pose.shape)
345+
#transforming it to skeleton
346+
flat_skeleton = transform_2skeleton(arranged_pose)
347+
return flat_skeleton
348+
349+
299350
def calculate_skeletons_dlc_live(pose) -> list:
300351
"""
301352
Creating skeletons from given pose
@@ -314,7 +365,12 @@ def calculate_skeletons(peaks: dict, animals_number: int) -> list:
314365
adaptive to chosen model origin
315366
"""
316367
if MODEL_ORIGIN == 'DLC':
317-
animal_skeletons = calculate_dlstream_skeletons(peaks, animals_number)
368+
#TODO: remove alterations from SIMBA tests
369+
#animal_skeletons = calculate_dlstream_skeletons(peaks, animals_number)
370+
#peaks = filter_pose_by_likelihood(peaks, threshold= 0.6)
371+
flat_skeleton = transform_2skeleton(peaks)
372+
animal_skeletons = [arrange_flatskeleton(flat_skeleton,2,7,{0:1})]
373+
#animal_skeletons = calculate_skeletons_dlc_live(peaks)
318374

319375
elif MODEL_ORIGIN == 'MADLC':
320376
animal_skeletons = calculate_ma_skeletons(peaks, animals_number)

0 commit comments

Comments
 (0)