Skip to content

Commit da1e0da

Browse files
author
Anna Grebneva
authored
Fixed calculaton of roc_auc_score metric (#3081)
1 parent 6e843e5 commit da1e0da

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/metrics/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ Supported representations: `ClassificationAnnotation`, `TextClassificationAnnota
4848
* `pixel_level`- evaluate metric on pixel level for anomaly segmentation (Optional, default False) .
4949
* `metthews_correlation_coef` - [Matthews correlation coefficient (MCC)](https://en.wikipedia.org/wiki/Matthews_correlation_coefficient) for binary classification. Metric is calculated as a percentage. Direction of metric's growth is higher-better. Supported representations: `ClassificationAnnotation`, `TextClassificationAnnotation`, `ClassificationPrediction`.
5050
* `roc_auc_score` - [ROC AUC score](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for binary classification. Metric is calculated as a percentage. Direction of metric's growth is higher-better. Supported representations: `ClassificationAnnotation`, `TextClassificationAnnotation`, `ClassificationPrediction` `ArgMaxClassificationPrediction`, `AnomalySegmentationAnnotation`, `AnomalySegmentationPrediction`.
51-
* `pixel_level`- evaluate metric on pixel level for anomaly segmentation (Optional, default False)
51+
* `pixel_level`- evaluate metric on pixel level for anomaly segmentation (Optional, default False).
52+
* `calculate_hot_label` - calculate one hot label for annotation and prediction before metric evaluation calculation for anomaly segmentation (Optional, default False).
5253
* `acer_score` - metric for the classification tasks. Can be obtained from the following formula: `ACER = (APCER + BPCER)/2 = ((fp / (tn + fp)) + (fn / (fn + tp)))/2`. For more details about metrics see the section 9.3: <https://arxiv.org/abs/2007.12342>. Metric is calculated as a percentage. Direction of metric's growth is higher-worse. Supported representations: `ClassificationAnnotation`, `TextClassificationAnnotation`, `ClassificationPrediction`.
5354
* `clip_accuracy` - classification video-level accuracy metric. Metric is calculated as a percentage. Direction of metric's growth is higher-better. Supported representations: `ClassificationAnnotation`, `ClassificationPrediction`.
5455
* `map` - mean average precision. Metric is calculated as a percentage. Direction of metric's growth is higher-better. Supported representations: `DetectionAnnotation`, `DetectionPrediction`.

tools/accuracy_checker/openvino/tools/accuracy_checker/metrics/classification.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -551,13 +551,18 @@ def parameters(cls):
551551
params.update({
552552
'pixel_level': BoolField(
553553
optional=True, default=False,
554-
description='calculate metic on pixel level, for anomaly segmentation only')
554+
description='calculate metric on pixel level, for anomaly segmentation only'),
555+
'calculate_hot_label': BoolField(
556+
optional=True, default=False,
557+
description='calculate one hot label for annotation and prediction before metric evaluation '
558+
'for anomaly segmentation')
555559
})
556560
return params
557561

558562
def configure(self):
559563
self.reset()
560564
self.pixel_level = self.get_value_from_config('pixel_level')
565+
self.calculate_hot_label = self.get_value_from_config('calculate_hot_label')
561566

562567
def update(self, annotation, prediction):
563568
if (
@@ -581,19 +586,12 @@ def update(self, annotation, prediction):
581586
def one_hot_labels(targets, results):
582587
max_v = int(max(np.max(targets) + 1, np.max(results) + 1))
583588
gt_bin = np.zeros((len(targets), max_v))
584-
pred_bin = np.zeros((len(targets), max_v))
589+
pred_bin = np.zeros((len(results), max_v))
585590
np.put_along_axis(gt_bin, np.expand_dims(np.array(targets).astype(int), 1), 1, axis=1)
586591
np.put_along_axis(pred_bin, np.expand_dims(np.array(results).astype(int), 1), 1, axis=1)
587592

588593
return gt_bin, pred_bin
589594

590-
def roc(self, y_true, y_score):
591-
per_class_area = []
592-
for i in range(y_true.shape[-1]):
593-
per_class_area.append(self.roc_curve_area(y_true[:, i], y_score[:, i]))
594-
average_area = self.roc_curve_area(y_true.ravel(), y_score.ravel())
595-
return average_area, per_class_area
596-
597595
def roc_curve_area(self, gt, pred):
598596
desc_score_indices = np.argsort(pred, kind="mergesort")[::-1]
599597
y_score = pred[desc_score_indices]
@@ -602,10 +600,14 @@ def roc_curve_area(self, gt, pred):
602600
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
603601
tps = np.cumsum(y_true)[threshold_idxs]
604602
fps = 1 + threshold_idxs - tps
603+
604+
tps = np.r_[0, tps]
605+
fps = np.r_[0, fps]
606+
605607
if max(fps) > 0:
606-
fps /= fps[-1]
608+
fps = fps / fps[-1]
607609
if max(tps) > 0:
608-
tps /= tps[-1]
610+
tps = tps / tps[-1]
609611
area = self.roc_auc_score(fps, tps)
610612
return area
611613

@@ -620,7 +622,7 @@ def roc_auc_score(fpr, tpr):
620622

621623
def evaluate(self, annotations, predictions):
622624
all_results = self.results if np.isscalar(self.results[-1]) else np.concatenate(self.results)
623-
all_targets = self.targets if np.isscalar(self.results[-1]) else np.concatenate(self.results)
625+
all_targets = self.targets if np.isscalar(self.targets[-1]) else np.concatenate(self.targets)
624626
roc_auc = self.auc_score(all_targets, all_results)
625627
return roc_auc
626628

@@ -629,8 +631,8 @@ def reset(self):
629631
self.results = []
630632

631633
def auc_score(self, targets, results):
632-
gt, dt = self.one_hot_labels(targets, results)
633-
avg_area, _ = self.roc(gt, dt)
634+
(gt, dt) = self.one_hot_labels(targets, results) if self.calculate_hot_label else (targets, results)
635+
avg_area = self.roc_curve_area(np.array(gt).ravel(), np.array(dt).ravel())
634636
return avg_area
635637

636638

0 commit comments

Comments
 (0)