Commit 315ba4c

Fixes the "wrong" computation of the AP score.
Given that sklearn changed the way AP scores are computed, this implements a custom version. The implementation follows the official Market-1501 computation of the AP.
1 parent b61ea61 commit 315ba4c

2 files changed, +40 -1 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -273,6 +273,9 @@ The evaluation code in this repository simply uses the scikit-learn code, and th
 Unfortunately, almost no paper mentions which code-base they used and how they computed `mAP` scores, so comparison is difficult.
 Other frameworks have [the same problem](https://github.com/Cysu/open-reid/issues/50), but we expect many not to be aware of this.
 
+To make the evaluation code independent of the sklearn version, we have implemented our own version of the average precision computation.
+This now follows the official Market-1501 code and results in directly comparable values.
+
 # Independent re-implementations
 
 These are the independent re-implementations of our paper that we are aware of,
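For context, here is a minimal sketch (not part of this commit, toy data only) of why the two conventions disagree: newer scikit-learn releases compute AP as the step-wise sum of (R_n - R_{n-1}) * P_n, whereas the Market-1501 protocol, like the function added to evaluate.py below, takes the trapezoidal area under the precision-recall curve.

```python
# Toy comparison of the two AP conventions; the labels and scores are invented.
import numpy as np
from sklearn.metrics import average_precision_score as sklearn_ap

y_true = np.array([1, 0, 1, 0, 0, 1])               # relevance of each gallery item
y_score = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4])  # similarity to the query

# scikit-learn: step-wise sum of precision at every recall increment.
print('sklearn AP:    ', sklearn_ap(y_true, y_score))  # ~0.722

# Market-1501 style: trapezoidal area under the precision-recall curve,
# mirroring the computation added to evaluate.py in this commit.
tp = np.cumsum(y_true[np.argsort(-y_score, kind='mergesort')])
recall = np.insert(tp / tp[-1], 0, 0.)
precision = np.insert(tp / np.arange(1, len(tp) + 1), 0, 1.)
print('Market-1501 AP:', np.sum(np.diff(recall) * (precision[1:] + precision[:-1]) / 2))  # ~0.678
```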

evaluate.py

Lines changed: 37 additions & 1 deletion
@@ -7,7 +7,6 @@
 import h5py
 import json
 import numpy as np
-from sklearn.metrics import average_precision_score
 import tensorflow as tf
 
 import common
@@ -52,6 +51,43 @@
     help='Batch size used during evaluation, adapt based on your memory usage.')
 
 
+def average_precision_score(y_true, y_score):
+    """Compute average precision (AP) from prediction scores.
+
+    This is a replacement for the scikit-learn version which, while likely more
+    correct, does not follow the same protocol as used in the default Market-1501
+    evaluation that first introduced this score to the person ReID field.
+
+    Args:
+        y_true (array): The binary labels for all data points.
+        y_score (array): The predicted scores for each sample for all data
+            points.
+
+    Raises:
+        ValueError: If the lengths of the labels and scores do not match.
+
+    Returns:
+        A float representing the average precision given the predictions.
+    """
+
+    if len(y_true) != len(y_score):
+        raise ValueError('The lengths of the labels and predictions must match, '
+                         'got lengths y_true:{} and y_score:{}'.format(
+                             len(y_true), len(y_score)))
+
+    y_true_sorted = y_true[np.argsort(-y_score, kind='mergesort')]
+
+    tp = np.cumsum(y_true_sorted)
+    total_true = np.sum(y_true_sorted)
+    recall = tp / total_true
+    recall = np.insert(recall, 0, 0.)
+    precision = tp / np.arange(1, len(tp) + 1)
+    precision = np.insert(precision, 0, 1.)
+    ap = np.sum(np.diff(recall) * ((precision[1:] + precision[:-1]) / 2))
+
+    return ap
+
+
 def main():
     # Verify that parameters are set correctly.
     args = parser.parse_args()
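As a hypothetical usage sketch (not code from this repository), the new function slots into a per-query mAP loop like the one below; `distances`, `query_pids`, and `gallery_pids` are made-up names, and the junk-image and same-camera filtering of the full Market-1501 protocol is left out.

```python
import numpy as np

def mean_average_precision(distances, query_pids, gallery_pids):
    """Toy mAP: mean of the per-query AP over the whole gallery."""
    gallery_pids = np.asarray(gallery_pids)
    aps = []
    for q_dist, q_pid in zip(np.asarray(distances), query_pids):
        y_true = gallery_pids == q_pid   # binary relevance of every gallery item
        y_score = -q_dist                # smaller distance -> higher score
        # Uses the average_precision_score defined in evaluate.py above.
        aps.append(average_precision_score(y_true, y_score))
    return np.mean(aps)
```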
