77import h5py
88import json
99import numpy as np
10+ from sklearn .metrics import average_precision_score
1011import tensorflow as tf
1112
1213import common
5051 '--batch_size' , default = 256 , type = common .positive_int ,
5152 help = 'Batch size used during evaluation, adapt based on your memory usage.' )
5253
# Flag to switch the AP computation from scikit-learn's implementation to a
# re-implementation of the official Market-1501 evaluation script. The two
# give slightly different scores; this flag makes results comparable to
# numbers reported with the official script.
# Fix: the original multi-line help string was missing spaces at the literal
# boundaries, rendering as "...slightly differentscores." in --help output.
parser.add_argument(
    '--use_market_ap', action='store_true', default=False,
    help='When this flag is provided, the average precision is computed '
         'exactly as done by the Market-1501 evaluation script, rather '
         'than the default scikit-learn implementation that gives '
         'slightly different scores.')
60+
5361
54- def average_precision_score (y_true , y_score ):
62+ def average_precision_score_market (y_true , y_score ):
5563 """ Compute average precision (AP) from prediction scores.
5664
5765 This is a replacement for the scikit-learn version which, while likely more
@@ -75,6 +83,8 @@ def average_precision_score(y_true, y_score):
7583 'got lengths y_true:{} and y_score:{}' .format (
7684 len (y_true ), len (y_score )))
7785
86+ # Mergesort is used since it is a stable sorting algorithm. This is
87+ # important to compute consistent and correct scores.
7888 y_true_sorted = y_true [np .argsort (- y_score , kind = 'mergesort' )]
7989
8090 tp = np .cumsum (y_true_sorted )
@@ -119,6 +129,12 @@ def main():
119129
120130 batch_distances = loss .cdist (batch_embs , gallery_embs , metric = args .metric )
121131
132+ # Check if we should use Market-1501 specific average precision computation.
133+ if args .use_market_ap :
134+ average_precision = average_precision_score_market
135+ else :
136+ average_precision = average_precision_score
137+
122138 # Loop over the query embeddings and compute their APs and the CMC curve.
123139 aps = []
124140 cmc = np .zeros (len (gallery_pids ), dtype = np .int32 )
@@ -153,7 +169,7 @@ def main():
153169 # it won't change anything.
154170 scores = 1 / (1 + distances )
155171 for i in range (len (distances )):
156- ap = average_precision_score (pid_matches [i ], scores [i ])
172+ ap = average_precision (pid_matches [i ], scores [i ])
157173
158174 if np .isnan (ap ):
159175 print ()
0 commit comments