We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 51ad7f4 commit bae2cacCopy full SHA for bae2cac
dlclive/pose.py
@@ -78,6 +78,22 @@ def argmax_pose_predict(scmap, offmat, stride):
78
pose.append(np.hstack((pos_f8[::-1], [scmap[maxloc][joint_idx]])))
79
return np.array(pose)
80
81
+def get_top_values(scmap, n_top=5):
82
+ batchsize, ny, nx, num_joints = scmap.shape
83
+ scmap_flat = scmap.reshape(batchsize, nx * ny, num_joints)
84
+ if n_top == 1:
85
+ scmap_top = np.argmax(scmap_flat, axis=1)[None]
86
+ else:
87
+ scmap_top = np.argpartition(scmap_flat, -n_top, axis=1)[:, -n_top:]
88
+ for ix in range(batchsize):
89
+ vals = scmap_flat[ix, scmap_top[ix], np.arange(num_joints)]
90
+ arg = np.argsort(-vals, axis=0)
91
+ scmap_top[ix] = scmap_top[ix, arg, np.arange(num_joints)]
92
+ scmap_top = scmap_top.swapaxes(0, 1)
93
+
94
+ Y, X = np.unravel_index(scmap_top, (ny, nx))
95
+ return Y, X
96
97
98
def multi_pose_predict(scmap, locref, stride, num_outputs):
99
Y, X = get_top_values(scmap[None], num_outputs)
0 commit comments