Skip to content

Commit 08ad93b

Browse files
authored
Add implementation of precision-recall gain AUC (#21370)
* Remove dead code in the confusion metrics * Add PRGAIN enum for AUC metric types * Add implementation of precision-recall gain AUC Based on https://research-information.bris.ac.uk/files/72164009/5867_precision_recall_gain_curves_pr_analysis_done_right.pdf
1 parent 0c0ec1a commit 08ad93b

File tree

3 files changed

+118
-21
lines changed

3 files changed

+118
-21
lines changed

keras/src/metrics/confusion_metrics.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,25 +1346,6 @@ def update_state(self, y_true, y_pred, sample_weight=None):
13461346
if not self._built:
13471347
self._build(y_pred.shape)
13481348

1349-
if self.multi_label or (self.label_weights is not None):
1350-
# y_true should have shape (number of examples, number of labels).
1351-
shapes = [(y_true, ("N", "L"))]
1352-
if self.multi_label:
1353-
# TP, TN, FP, and FN should all have shape
1354-
# (number of thresholds, number of labels).
1355-
shapes.extend(
1356-
[
1357-
(self.true_positives, ("T", "L")),
1358-
(self.true_negatives, ("T", "L")),
1359-
(self.false_positives, ("T", "L")),
1360-
(self.false_negatives, ("T", "L")),
1361-
]
1362-
)
1363-
if self.label_weights is not None:
1364-
# label_weights should be of length equal to the number of
1365-
# labels.
1366-
shapes.append((self.label_weights, ("L",)))
1367-
13681349
# Only forward label_weights to update_confusion_matrix_variables when
13691350
# multi_label is False. Otherwise the averaging of individual label AUCs
13701351
# is handled in AUC.result
@@ -1500,13 +1481,53 @@ def result(self):
15001481
)
15011482
x = fp_rate
15021483
y = recall
1503-
else: # curve == 'PR'.
1484+
elif self.curve == metrics_utils.AUCCurve.PR: # curve == 'PR'.
15041485
precision = ops.divide_no_nan(
15051486
self.true_positives,
15061487
ops.add(self.true_positives, self.false_positives),
15071488
)
15081489
x = recall
15091490
y = precision
1491+
else: # curve == 'PRGAIN'.
1492+
# Due to the hyperbolic transform, this formula is less robust than
1493+
# ROC and PR values. In particular
1494+
# 1) Both measures diverge when there are no negative values;
1495+
# 2) Both measures diverge when there are no true positives;
1496+
# 3) Recall gain becomes negative when the recall is lower than the
1497+
# label average (i.e. when more negative examples are
1498+
# classified positive than real positives).
1499+
#
1500+
# We ignore case 1 as it is easily understood that metrics would be
1501+
# badly defined then. For case 2 we set recall_gain to 0 and
1502+
# precision_gain to 1. For case 3 we set recall_gain to 0. These
1503+
# fixes will result in an overestimation of the AUC for estimators
1504+
# that are anti-correlated with the label (at some threshold).
1505+
1506+
# The scaling factor $\frac{P}{N}$ that is used for both gain
1507+
# values.
1508+
scaling_factor = ops.divide_no_nan(
1509+
ops.add(self.true_positives, self.false_negatives),
1510+
ops.add(self.true_negatives, self.false_positives),
1511+
)
1512+
1513+
recall_gain = 1.0 - scaling_factor * ops.divide_no_nan(
1514+
self.false_negatives, self.true_positives
1515+
)
1516+
precision_gain = 1.0 - scaling_factor * ops.divide_no_nan(
1517+
self.false_positives, self.true_positives
1518+
)
1519+
# Handle case 2.
1520+
recall_gain = ops.where(
1521+
ops.equal(self.true_positives, 0.0), 0.0, recall_gain
1522+
)
1523+
precision_gain = ops.where(
1524+
ops.equal(self.true_positives, 0.0), 1.0, precision_gain
1525+
)
1526+
# Handle case 3.
1527+
recall_gain = ops.maximum(recall_gain, 0.0)
1528+
1529+
x = recall_gain
1530+
y = precision_gain
15101531

15111532
# Find the rectangle heights based on `summation_method`.
15121533
if (

keras/src/metrics/confusion_metrics_test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,6 +1396,79 @@ def test_weighted_pr_interpolation_negative_weights(self):
13961396
# produce all zeros.
13971397
self.assertAllClose(result, 0.0, 1e-3)
13981398

1399+
def test_weighted_prgain_majoring(self):
1400+
auc_obj = metrics.AUC(
1401+
num_thresholds=self.num_thresholds,
1402+
curve="PRGAIN",
1403+
summation_method="majoring",
1404+
)
1405+
result = auc_obj(
1406+
self.y_true, self.y_pred, sample_weight=self.sample_weight
1407+
)
1408+
1409+
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1410+
# scaling_factor (P/N) = 7/3
1411+
# recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0]
1412+
# precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 1]
1413+
# heights = [max(0, 1), max(1, 1)] = [1, 1]
1414+
# widths = [(1 - 0), (0 - 0)] = [1, 0]
1415+
expected_result = 1 * 1 + 0 * 1
1416+
self.assertAllClose(result, expected_result, 1e-3)
1417+
1418+
def test_weighted_prgain_minoring(self):
1419+
auc_obj = metrics.AUC(
1420+
num_thresholds=self.num_thresholds,
1421+
curve="PRGAIN",
1422+
summation_method="minoring",
1423+
)
1424+
result = auc_obj(
1425+
self.y_true, self.y_pred, sample_weight=self.sample_weight
1426+
)
1427+
1428+
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1429+
# scaling_factor (P/N) = 7/3
1430+
# recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0]
1431+
# precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 1]
1432+
# heights = [min(0, 1), min(1, 1)] = [0, 1]
1433+
# widths = [(1 - 0), (0 - 0)] = [1, 0]
1434+
expected_result = 1 * 0 + 0 * 1
1435+
self.assertAllClose(result, expected_result, 1e-3)
1436+
1437+
def test_weighted_prgain_interpolation(self):
1438+
auc_obj = metrics.AUC(
1439+
num_thresholds=self.num_thresholds, curve="PRGAIN"
1440+
)
1441+
result = auc_obj(
1442+
self.y_true, self.y_pred, sample_weight=self.sample_weight
1443+
)
1444+
1445+
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1446+
# scaling_factor (P/N) = 7/3
1447+
# recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0]
1448+
# precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 1]
1449+
# heights = [(0+1)/2, (1+1)/2] = [0.5, 1]
1450+
# widths = [(1 - 0), (0 - 0)] = [1, 0]
1451+
expected_result = 1 * 0.5 + 0 * 1
1452+
self.assertAllClose(result, expected_result, 1e-3)
1453+
1454+
def test_prgain_interpolation(self):
1455+
auc_obj = metrics.AUC(
1456+
num_thresholds=self.num_thresholds, curve="PRGAIN"
1457+
)
1458+
1459+
y_true = np.array([0, 0, 0, 1, 0, 1, 0, 1, 1, 1])
1460+
y_pred = np.array([0.1, 0.2, 0.3, 0.3, 0.4, 0.4, 0.6, 0.6, 0.8, 0.9])
1461+
result = auc_obj(y_true, y_pred)
1462+
1463+
# tp = [5, 3, 0], fp = [5, 1, 0], fn = [0, 2, 5], tn = [0, 4, 4]
1464+
# scaling_factor (P/N) = 5/5 = 1
1465+
# recall_gain = 1 - [0/5, 2/3, 5/0] = [1, 1/3, -inf] -> [1, 1/3, 0]
1466+
# precision_gain = 1 - [5/5, 1/3, 0/0] = [0, 2/3, NaN] -> [0, 2/3, 1]
1467+
# heights = [(0+2/3)/2, (2/3+1)/2] = [0.333333, 0.833333]
1468+
# widths = [(1 - 1/3), (1/3 - 0)] = [0.666666, 0.333333]
1469+
expected_result = 0.666666 * 0.333333 + 0.333333 * 0.833333
1470+
self.assertAllClose(result, expected_result, 1e-3)
1471+
13991472
def test_invalid_num_thresholds(self):
14001473
with self.assertRaisesRegex(
14011474
ValueError, "Argument `num_thresholds` must be an integer > 1"

keras/src/metrics/metrics_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,20 @@ class AUCCurve(Enum):
4343

4444
ROC = "ROC"
4545
PR = "PR"
46+
PRGAIN = "PRGAIN"
4647

4748
@staticmethod
4849
def from_str(key):
4950
if key in ("pr", "PR"):
5051
return AUCCurve.PR
5152
elif key in ("roc", "ROC"):
5253
return AUCCurve.ROC
54+
elif key in ("prgain", "PRGAIN"):
55+
return AUCCurve.PRGAIN
5356
else:
5457
raise ValueError(
5558
f'Invalid AUC curve value: "{key}". '
56-
'Expected values are ["PR", "ROC"]'
59+
'Expected values are ["PR", "ROC", "PRGAIN"]'
5760
)
5861

5962

0 commit comments

Comments
 (0)