diff --git a/loss.py b/loss.py
index dd7bc59..9237d82 100644
--- a/loss.py
+++ b/loss.py
@@ -66,7 +66,19 @@ def get_at_indices(tensor, indices):
     return tf.gather_nd(tensor, tf.stack((counter, indices), -1))
 
 
-def batch_hard(dists, pids, margin, batch_precision_at_k=None):
+def apply_margin(x, margin):
+    if isinstance(margin, numbers.Real):
+        return tf.maximum(x + margin, 0.0)
+    elif margin == 'soft':
+        return tf.nn.softplus(x)
+    elif margin.lower() == 'none':
+        return x
+    else:
+        raise NotImplementedError(
+            'The margin {} is not implemented in apply_margin'.format(margin))
+
+
+def _generic_batchloss(dists, pids, margin, batch_precision_at_k=None, variant='hard'):
     """Computes the batch-hard loss from arxiv.org/abs/1703.07737.
 
     Args:
@@ -87,25 +99,83 @@ def batch_hard(dists, pids, margin, batch_precision_at_k=None):
         positive_mask = tf.logical_xor(same_identity_mask,
                                        tf.eye(tf.shape(pids)[0], dtype=tf.bool))
 
-        furthest_positive = tf.reduce_max(dists*tf.cast(positive_mask, tf.float32), axis=1)
-        closest_negative = tf.map_fn(lambda x: tf.reduce_min(tf.boolean_mask(x[0], x[1])),
-                                     (dists, negative_mask), tf.float32)
-        # Another way of achieving the same, though more hacky:
-        # closest_negative = tf.reduce_min(dists + 1e5*tf.cast(same_identity_mask, tf.float32), axis=1)
-
-        diff = furthest_positive - closest_negative
-        if isinstance(margin, numbers.Real):
-            diff = tf.maximum(diff + margin, 0.0)
-        elif margin == 'soft':
-            diff = tf.nn.softplus(diff)
-        elif margin.lower() == 'none':
-            pass
-        else:
-            raise NotImplementedError(
-                'The margin {} is not implemented in batch_hard'.format(margin))
+        if variant == 'sample':
+            # -inf gives that index a probability of zero.
+            neg_infs = -tf.constant(float('inf'))*tf.ones_like(dists)
+            # Higher logits are more likely to be sampled.
+            pos_logits = tf.where(positive_mask, dists, neg_infs)
+            pos_indices = tf.multinomial(pos_logits, num_samples=1)[:,0]
+            positive = get_at_indices(dists, pos_indices)
+
+            # Same for the negatives, but we need to turn the logits around,
+            # since we want to sample smaller distances with higher probability.
+            neg_logits = tf.where(negative_mask, -dists, neg_infs)
+            neg_indices = tf.multinomial(neg_logits, num_samples=1)[:,0]
+            negative = get_at_indices(dists, neg_indices)
+        elif variant == 'hard':
+            # Furthest one is worst positive.
+            positive = tf.reduce_max(dists*tf.cast(positive_mask, tf.float32), axis=1)
+            # Closest one is worst negative.
+            negative = tf.map_fn(lambda x: tf.reduce_min(tf.boolean_mask(x[0], x[1])),
+                                 (dists, negative_mask), tf.float32)
+            # negative = tf.reduce_min(dists + 1e5*tf.cast(same_identity_mask, tf.float32), axis=1)
+
+        losses = apply_margin(positive - negative, margin)
+
+        return return_with_extra_stats(losses, dists, batch_precision_at_k,
+                                       same_identity_mask,
+                                       positive_mask, negative_mask)
+
+def batch_hard(dists, pids, margin, batch_precision_at_k=None):
+    return _generic_batchloss(dists, pids, margin, batch_precision_at_k, variant='hard')
+
+def batch_sample(dists, pids, margin, batch_precision_at_k=None):
+    return _generic_batchloss(dists, pids, margin, batch_precision_at_k, variant='sample')
+
+
+def batch_all(dists, pids, margin, batch_precision_at_k=None):
+    with tf.name_scope("batch_all"):
+        same_identity_mask = tf.equal(tf.expand_dims(pids, axis=1),
+                                      tf.expand_dims(pids, axis=0))
+        negative_mask = tf.logical_not(same_identity_mask)
+        positive_mask = tf.logical_xor(same_identity_mask,
+                                       tf.eye(tf.shape(pids)[0], dtype=tf.bool))
+
+        # Unfortunately, foldl can only go over one tensor, unlike map_fn,
+        # so we cast the masks to float and stack everything into one tensor.
+        packed = tf.stack([dists,
+                           tf.cast(positive_mask, tf.float32),
+                           tf.cast(negative_mask, tf.float32)], axis=1)
+
+        def per_anchor(accum, row):
+            # `dists_` is a 1D array of distances (one row of `dists`).
+            # `poss_` is a 1D float mask marking positives (cast back to bool below).
+            # `negs_` is a 1D float mask marking negatives (cast back to bool below).
+            dists_, poss_, negs_ = row[0], row[1], row[2]
+
+            # Now construct a (P,N)-matrix of all-to-all (anchor-pos - anchor-neg).
+            diff = all_diffs(tf.boolean_mask(dists_, tf.cast(poss_, tf.bool)),
+                             tf.boolean_mask(dists_, tf.cast(negs_, tf.bool)))
+
+            losses = tf.reshape(apply_margin(diff, margin), [-1])
+            return tf.concat([accum, losses], axis=0)
+
+        # Some very advanced trickery in order to get the initialization tensor
+        # to be an empty 1D tensor with a dynamic shape, such that it is
+        # allowed to grow during the iteration.
+        init = tf.placeholder_with_default([], shape=[None])
+        losses = tf.foldl(per_anchor, packed, init)
+
+        return return_with_extra_stats(losses, dists, batch_precision_at_k,
+                                       same_identity_mask,
+                                       positive_mask, negative_mask)
+
+
+def return_with_extra_stats(to_return, dists, batch_precision_at_k,
+                            same_identity_mask, positive_mask, negative_mask):
     if batch_precision_at_k is None:
-        return diff
+        return to_return
 
     # For monitoring, compute the within-batch top-1 accuracy and the
     # within-batch precision-at-k, which is somewhat more expressive.
@@ -142,9 +212,12 @@ def batch_hard(dists, pids, margin, batch_precision_at_k=None):
         negative_dists = tf.boolean_mask(dists, negative_mask)
         positive_dists = tf.boolean_mask(dists, positive_mask)
 
-        return diff, top1, prec_at_k, topk_is_same, negative_dists, positive_dists
+        return to_return, top1, prec_at_k, topk_is_same, negative_dists, positive_dists
+
 
 LOSS_CHOICES = {
     'batch_hard': batch_hard,
+    'batch_sample': batch_sample,
+    'batch_all': batch_all,
 }
 
diff --git a/train.py b/train.py
index 9294494..eee3d9a 100755
--- a/train.py
+++ b/train.py
@@ -102,9 +102,13 @@
     help='Which metric to use for the distance between embeddings.')
 
 parser.add_argument(
-    '--loss', default='batch_hard', choices=loss.LOSS_CHOICES.keys(),
+    '--loss', default='batch_hard', choices=loss.LOSS_CHOICES,
     help='Enable the super-mega-advanced top-secret sampling stabilizer.')
 
+parser.add_argument(
+    '--loss_ignore_zero', default=False, const=True, nargs='?', type=common.positive_float,
+    help='Average only over non-zero loss values ("=/=0" in the paper). An optional float sets the non-zero threshold.')
+
 parser.add_argument(
     '--learning_rate', default=3e-4, type=common.positive_float,
     help='The initial value of the learning-rate, before it kicks in.')
@@ -141,6 +145,11 @@
     ' embeddings, losses and FIDs seen in each batch during training.'
     ' Everything can be re-constructed and analyzed that way.')
 
+parser.add_argument(
+    '--optim', default='AdamOptimizer(learning_rate)',
+    help='Which optimizer to use. This is actual TensorFlow code that will be'
+         ' eval\'d; use `learning_rate` to refer to the (possibly scheduled) learning-rate tensor.')
+
 def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k):
     """ Given a PID, select K FIDs of that specific PID. """
 
@@ -294,16 +303,24 @@ def main():
     losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
         dists, pids, args.margin, batch_precision_at_k=args.batch_k-1)
 
-    # Count the number of active entries, and compute the total batch loss.
-    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
-    loss_mean = tf.reduce_mean(losses)
+    # Count how many entries in the batch are (possibly approximately) non-zero.
+    if args.loss_ignore_zero is True:
+        nnz = tf.count_nonzero(losses, dtype=tf.float32)
+    else:
+        nnz = tf.reduce_sum(tf.to_float(tf.greater(losses, args.loss_ignore_zero or 1e-5)))
+
+    # Compute the total batch-loss by either averaging all, or averaging non-zeros only.
+    if args.loss_ignore_zero is False:
+        loss_mean = tf.reduce_mean(losses)
+    else:
+        loss_mean = tf.reduce_sum(losses) / (1e-33 + nnz)
 
     # Some logging for tensorboard.
     tf.summary.histogram('loss_distribution', losses)
     tf.summary.scalar('loss', loss_mean)
     tf.summary.scalar('batch_top1', train_top1)
     tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k-1), prec_at_k)
-    tf.summary.scalar('active_count', num_active)
+    tf.summary.scalar('active_count', nnz)
     tf.summary.histogram('embedding_dists', dists)
     tf.summary.histogram('embedding_pos_dists', pos_dists)
     tf.summary.histogram('embedding_neg_dists', neg_dists)
@@ -341,9 +358,7 @@ def main():
     else:
         learning_rate = args.learning_rate
     tf.summary.scalar('learning_rate', learning_rate)
-    optimizer = tf.train.AdamOptimizer(learning_rate)
-    # Feel free to try others!
-    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)
+    optimizer = eval("tf.train." + args.optim)
 
     # Update_ops are used to update batchnorm stats.
     with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
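
A minimal NumPy sketch of the two building blocks introduced above (illustrative only; apply_margin_np and the example numbers are made up for this note): the margin handling mirrored from apply_margin, and the -inf-logit trick that batch_sample relies on, where masked-out entries get probability zero and larger logits are drawn more often, which is how tf.multinomial treats its unnormalized log-probabilities.

import numbers
import numpy as np

def apply_margin_np(x, margin):
    # Mirrors apply_margin: hinge for a numeric margin, softplus for 'soft',
    # identity for 'none'.
    if isinstance(margin, numbers.Real):
        return np.maximum(x + margin, 0.0)
    elif margin == 'soft':
        return np.log1p(np.exp(x))
    elif margin.lower() == 'none':
        return x
    raise NotImplementedError(margin)

# The -inf trick used by the 'sample' variant: masked-out entries get
# exp(-inf) = 0 probability, so only valid candidates can be drawn, and
# larger logits (here: larger distances) are drawn more often.
dists = np.array([0.3, 1.2, 0.7])             # one row of the distance matrix
positive_mask = np.array([False, True, True])
logits = np.where(positive_mask, dists, -np.inf)
probs = np.exp(logits - logits.max())
probs /= probs.sum()                          # approx. [0.0, 0.62, 0.38]
sampled = np.random.choice(len(dists), p=probs)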
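
In the same spirit, a small sketch of what batch_all accumulates, assuming all_diffs(a, b) computes the matrix a[:, None] - b[None, :] as elsewhere in loss.py; batch_all_np, the toy batch, and the use of a plain hard margin are made up for illustration. For every anchor, all P*N (anchor-positive minus anchor-negative) differences contribute one loss term, and the flattened results are concatenated across anchors.

import numpy as np

def batch_all_np(dists, pids, margin=0.2):
    # dists: (B, B) pairwise distances, pids: (B,) identity labels.
    B = dists.shape[0]
    same = pids[:, None] == pids[None, :]
    positive_mask = same & ~np.eye(B, dtype=bool)
    negative_mask = ~same

    losses = []
    for a in range(B):
        ap = dists[a][positive_mask[a]]          # anchor-positive distances
        an = dists[a][negative_mask[a]]          # anchor-negative distances
        diff = ap[:, None] - an[None, :]         # all (P, N) combinations
        losses.append(np.maximum(diff + margin, 0.0).ravel())
    return np.concatenate(losses)

pids = np.array([0, 0, 1, 1])
dists = np.random.rand(4, 4)
dists = (dists + dists.T) / 2                    # make it symmetric
np.fill_diagonal(dists, 0.0)
print(batch_all_np(dists, pids).shape)           # (4 anchors * 1 pos * 2 neg,) = (8,)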
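
Finally, the averaging behaviour behind --loss_ignore_zero, again as a NumPy sketch with made-up numbers: by default the mean runs over all entries; with the flag given (or given a float threshold) the sum of all losses is divided by the number of entries counted as non-zero, so already-satisfied triplets no longer dilute the average.

import numpy as np

losses = np.array([0.0, 0.0, 0.4, 1.1])    # per-triplet losses in one batch

# Default (--loss_ignore_zero not given): plain mean over everything.
mean_all = losses.mean()                              # 0.375

# --loss_ignore_zero (no value): divide the sum by the exact non-zero count.
nnz = np.count_nonzero(losses)                        # 2
mean_nonzero = losses.sum() / (1e-33 + nnz)           # 0.75

# --loss_ignore_zero 0.5: count only entries above the given threshold,
# but still sum all of them, matching the train.py code above.
nnz_thresh = np.sum(losses > 0.5)                     # 1
mean_thresh = losses.sum() / (1e-33 + nnz_thresh)     # 1.5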