Skip to content

Commit bae2cac

Browse files
authored
closes #26
#26 Thanks @ehsainit
1 parent 51ad7f4 commit bae2cac

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

dlclive/pose.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,22 @@ def argmax_pose_predict(scmap, offmat, stride):
7878
pose.append(np.hstack((pos_f8[::-1], [scmap[maxloc][joint_idx]])))
7979
return np.array(pose)
8080

81+
def get_top_values(scmap, n_top=5):
82+
batchsize, ny, nx, num_joints = scmap.shape
83+
scmap_flat = scmap.reshape(batchsize, nx * ny, num_joints)
84+
if n_top == 1:
85+
scmap_top = np.argmax(scmap_flat, axis=1)[None]
86+
else:
87+
scmap_top = np.argpartition(scmap_flat, -n_top, axis=1)[:, -n_top:]
88+
for ix in range(batchsize):
89+
vals = scmap_flat[ix, scmap_top[ix], np.arange(num_joints)]
90+
arg = np.argsort(-vals, axis=0)
91+
scmap_top[ix] = scmap_top[ix, arg, np.arange(num_joints)]
92+
scmap_top = scmap_top.swapaxes(0, 1)
93+
94+
Y, X = np.unravel_index(scmap_top, (ny, nx))
95+
return Y, X
96+
8197

8298
def multi_pose_predict(scmap, locref, stride, num_outputs):
8399
Y, X = get_top_values(scmap[None], num_outputs)

0 commit comments

Comments
 (0)