Skip to content

Commit 7486d69

Browse files
authored
AC: added alternatives to sklearn if it is not available (#1854)
1 parent f06bf05 commit 7486d69

File tree

1 file changed

+62
-7
lines changed
  • tools/accuracy_checker/accuracy_checker/metrics

1 file changed

+62
-7
lines changed

tools/accuracy_checker/accuracy_checker/metrics/reid.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,70 @@
2626
)
2727
from ..config import BaseField, BoolField, NumberField
2828
from .metric import FullDatasetEvaluationMetric
29-
from ..utils import UnsupportedPackage
29+
30+
31+
def _auc(x, y):
32+
if x.shape[0] < 2:
33+
raise ValueError('At least 2 points are needed to compute'
34+
' area under curve, but x.shape = {}'.format(x.shape))
35+
direction = 1
36+
dx = np.diff(x)
37+
if np.any(dx < 0):
38+
if np.all(dx <= 0):
39+
direction = -1
40+
else:
41+
raise ValueError("x is neither increasing nor decreasing "
42+
": {}.".format(x))
43+
area = direction * np.trapz(y, x)
44+
return area
45+
46+
47+
def _binary_clf_curve(y_true, y_score):
48+
pos_label = 1.
49+
50+
# make y_true a boolean vector
51+
y_true = (y_true == pos_label)
52+
53+
# sort scores and corresponding truth values
54+
desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
55+
y_score = y_score[desc_score_indices]
56+
y_true = y_true[desc_score_indices]
57+
weight = 1.
58+
59+
distinct_value_indices = np.where(np.diff(y_score))[0]
60+
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
61+
62+
# accumulate the true positives with decreasing threshold
63+
tps = np.cumsum((y_true * weight), axis=None, dtype=np.float64)[threshold_idxs]
64+
fps = 1 + threshold_idxs - tps
65+
return fps, tps, y_score[threshold_idxs]
66+
67+
68+
def _precision_recall_curve(y_true, probas_pred):
    """Compute precision-recall pairs for decreasing score thresholds.

    Pure-numpy stand-in for ``sklearn.metrics.precision_recall_curve``,
    used when sklearn is not installed.

    Args:
        y_true: 1-D array of ground-truth binary labels (positive class
            is assumed to be 1 — see ``_binary_clf_curve``).
        probas_pred: 1-D array of prediction scores.

    Returns:
        Tuple ``(precision, recall, thresholds)``. Precision and recall
        carry one extra trailing element (1 and 0 respectively) so the
        curve is closed; recall is reported in decreasing order.
    """
    fps, tps, thresholds = _binary_clf_curve(y_true, probas_pred)

    precision = tps / (tps + fps)
    precision[np.isnan(precision)] = 0
    recall = tps / tps[-1]

    # stop when full recall is attained, then reverse the outputs so
    # recall comes out decreasing
    last_ind = tps.searchsorted(tps[-1])
    reverse = slice(last_ind, None, -1)
    return np.r_[precision[reverse], 1], np.r_[recall[reverse], 0], thresholds[reverse]
81+
3082

3183
# Prefer the sklearn implementations; fall back to the local pure-numpy
# equivalents above when sklearn is not installed. The exception object
# is no longer needed (the old UnsupportedPackage path was removed), so
# it is not bound.
try:
    from sklearn.metrics import auc, precision_recall_curve
except ImportError:
    auc = _auc
    precision_recall_curve = _precision_recall_curve
88+
3689

3790
PairDesc = namedtuple('PairDesc', 'image1 image2 same')
3891

92+
3993
def _average_binary_score(binary_metric, y_true, y_score):
4094
def binary_target(y):
4195
return not (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1)
@@ -287,6 +341,7 @@ def get_subset(container, subset_bounds):
287341

288342
return subset
289343

344+
290345
class FaceRecognitionTAFAPairMetric(FullDatasetEvaluationMetric):
291346
__provider__ = 'face_recognition_tafa_pair_metric'
292347

@@ -337,6 +392,7 @@ def evaluate(self, annotations, predictions):
337392

338393
return [(tp+tn) / (tp+fp+tn+fn)]
339394

395+
340396
class NormalizedEmbeddingAccuracy(FullDatasetEvaluationMetric):
341397
"""
342398
Accuracy score calculated with normalized embedding dot products
@@ -407,6 +463,7 @@ def evaluate(self, annotations, predictions):
407463
return 0
408464
return tp/(tp+fp)
409465

466+
410467
def regroup_pairs(annotations, predictions):
411468
image_indexes = {}
412469

@@ -424,10 +481,12 @@ def regroup_pairs(annotations, predictions):
424481

425482
return pairs
426483

484+
427485
def extract_embeddings(annotation, prediction, query):
428486
embeddings = [pred.embedding for pred, ann in zip(prediction, annotation) if ann.query == query]
429487
return np.stack(embeddings) if embeddings else embeddings
430488

489+
431490
def get_gallery_query_pids(annotation):
432491
gallery_pids = np.asarray([ann.person_id for ann in annotation if not ann.query])
433492
query_pids = np.asarray([ann.person_id for ann in annotation if ann.query])
@@ -574,10 +633,6 @@ def get_embedding_distances(annotation, prediction, train=False):
574633

575634

576635
def binary_average_precision(y_true, y_score, interpolated_auc=True):
577-
if isinstance(auc, UnsupportedPackage):
578-
auc.raise_error("reid metric")
579-
if isinstance(precision_recall_curve, UnsupportedPackage):
580-
precision_recall_curve.raise_error("reid metric")
581636
def _average_precision(y_true_, y_score_):
582637
precision, recall, _ = precision_recall_curve(y_true_, y_score_)
583638
if not interpolated_auc:

0 commit comments

Comments
 (0)