@@ -44,9 +44,9 @@ def cdist(a, b, metric='euclidean'):
4444 """
4545 with tf .name_scope ("cdist" ):
4646 diffs = all_diffs (a , b )
47- if metric == 'euclidean ' :
47+ if metric == 'sqeuclidean ' :
4848 return tf .reduce_sum (tf .square (diffs ), axis = - 1 )
49- elif metric == 'sqeuclidean ' :
49+ elif metric == 'euclidean ' :
5050 return tf .sqrt (tf .reduce_sum (tf .square (diffs ), axis = - 1 ) + 1e-12 )
5151 elif metric == 'cityblock' :
5252 return tf .reduce_sum (tf .abs (diffs ), axis = - 1 )
@@ -82,10 +82,10 @@ def batch_hard(dists, pids, margin, batch_precision_at_k=None):
8282 """
8383 with tf .name_scope ("batch_hard" ):
8484 same_identity_mask = tf .equal (tf .expand_dims (pids , axis = 1 ),
85- tf .expand_dims (pids , axis = 0 ))
85+ tf .expand_dims (pids , axis = 0 ))
8686 negative_mask = tf .logical_not (same_identity_mask )
8787 positive_mask = tf .logical_xor (same_identity_mask ,
88- tf .eye (tf .shape (pids )[0 ], dtype = tf .bool ))
88+ tf .eye (tf .shape (pids )[0 ], dtype = tf .bool ))
8989
9090 furthest_positive = tf .reduce_max (dists * tf .cast (positive_mask , tf .float32 ), axis = 1 )
9191 closest_negative = tf .map_fn (lambda x : tf .reduce_min (tf .boolean_mask (x [0 ], x [1 ])),
0 commit comments