26
26
)
27
27
from ..config import BaseField , BoolField , NumberField
28
28
from .metric import FullDatasetEvaluationMetric
29
- from ..utils import UnsupportedPackage
29
+
30
+
31
+ def _auc (x , y ):
32
+ if x .shape [0 ] < 2 :
33
+ raise ValueError ('At least 2 points are needed to compute'
34
+ ' area under curve, but x.shape = {}' .format (x .shape ))
35
+ direction = 1
36
+ dx = np .diff (x )
37
+ if np .any (dx < 0 ):
38
+ if np .all (dx <= 0 ):
39
+ direction = - 1
40
+ else :
41
+ raise ValueError ("x is neither increasing nor decreasing "
42
+ ": {}." .format (x ))
43
+ area = direction * np .trapz (y , x )
44
+ return area
45
+
46
+
47
+ def _binary_clf_curve (y_true , y_score ):
48
+ pos_label = 1.
49
+
50
+ # make y_true a boolean vector
51
+ y_true = (y_true == pos_label )
52
+
53
+ # sort scores and corresponding truth values
54
+ desc_score_indices = np .argsort (y_score , kind = "mergesort" )[::- 1 ]
55
+ y_score = y_score [desc_score_indices ]
56
+ y_true = y_true [desc_score_indices ]
57
+ weight = 1.
58
+
59
+ distinct_value_indices = np .where (np .diff (y_score ))[0 ]
60
+ threshold_idxs = np .r_ [distinct_value_indices , y_true .size - 1 ]
61
+
62
+ # accumulate the true positives with decreasing threshold
63
+ tps = np .cumsum ((y_true * weight ), axis = None , dtype = np .float64 )[threshold_idxs ]
64
+ fps = 1 + threshold_idxs - tps
65
+ return fps , tps , y_score [threshold_idxs ]
66
+
67
+
68
def _precision_recall_curve(y_true, probas_pred):
    """Compute precision-recall pairs for decreasing score thresholds.

    Fallback for ``sklearn.metrics.precision_recall_curve`` with the
    positive label fixed to 1. Returns ``(precision, recall,
    thresholds)`` where precision/recall carry a final (1, 0) sentinel
    point and are ordered so that recall is decreasing.
    """
    fps, tps, thresholds = _binary_clf_curve(y_true, probas_pred)

    precision = tps / (tps + fps)
    # Replace any NaN (0/0 at a threshold) with zero precision.
    precision[np.isnan(precision)] = 0
    recall = tps / tps[-1]

    # Drop everything after full recall is first attained, then reverse
    # the arrays so recall comes out decreasing.
    cut = slice(tps.searchsorted(tps[-1]), None, -1)
    return np.r_[precision[cut], 1], np.r_[recall[cut], 0], thresholds[cut]
81
+
30
82
31
83
# Prefer sklearn's implementations of these metrics; when sklearn is not
# installed, fall back to the local NumPy-based equivalents defined above.
try:
    from sklearn.metrics import auc, precision_recall_curve
except ImportError as import_error:
    auc = _auc
    precision_recall_curve = _precision_recall_curve
88
+
36
89
37
90
PairDesc = namedtuple ('PairDesc' , 'image1 image2 same' )
38
91
92
+
39
93
def _average_binary_score (binary_metric , y_true , y_score ):
40
94
def binary_target (y ):
41
95
return not (len (np .unique (y )) > 2 ) or (y .ndim >= 2 and len (y [0 ]) > 1 )
@@ -287,6 +341,7 @@ def get_subset(container, subset_bounds):
287
341
288
342
return subset
289
343
344
+
290
345
class FaceRecognitionTAFAPairMetric (FullDatasetEvaluationMetric ):
291
346
__provider__ = 'face_recognition_tafa_pair_metric'
292
347
@@ -337,6 +392,7 @@ def evaluate(self, annotations, predictions):
337
392
338
393
return [(tp + tn ) / (tp + fp + tn + fn )]
339
394
395
+
340
396
class NormalizedEmbeddingAccuracy (FullDatasetEvaluationMetric ):
341
397
"""
342
398
Accuracy score calculated with normalized embedding dot products
@@ -407,6 +463,7 @@ def evaluate(self, annotations, predictions):
407
463
return 0
408
464
return tp / (tp + fp )
409
465
466
+
410
467
def regroup_pairs (annotations , predictions ):
411
468
image_indexes = {}
412
469
@@ -424,10 +481,12 @@ def regroup_pairs(annotations, predictions):
424
481
425
482
return pairs
426
483
484
+
427
485
def extract_embeddings (annotation , prediction , query ):
428
486
embeddings = [pred .embedding for pred , ann in zip (prediction , annotation ) if ann .query == query ]
429
487
return np .stack (embeddings ) if embeddings else embeddings
430
488
489
+
431
490
def get_gallery_query_pids (annotation ):
432
491
gallery_pids = np .asarray ([ann .person_id for ann in annotation if not ann .query ])
433
492
query_pids = np .asarray ([ann .person_id for ann in annotation if ann .query ])
@@ -574,10 +633,6 @@ def get_embedding_distances(annotation, prediction, train=False):
574
633
575
634
576
635
def binary_average_precision (y_true , y_score , interpolated_auc = True ):
577
- if isinstance (auc , UnsupportedPackage ):
578
- auc .raise_error ("reid metric" )
579
- if isinstance (precision_recall_curve , UnsupportedPackage ):
580
- precision_recall_curve .raise_error ("reid metric" )
581
636
def _average_precision (y_true_ , y_score_ ):
582
637
precision , recall , _ = precision_recall_curve (y_true_ , y_score_ )
583
638
if not interpolated_auc :
0 commit comments