Commit 47fd12b

add: functions for calculating utility measures directly from classification probabilities
1 parent ee48499 commit 47fd12b

2 files changed: +66 -0 lines changed

modAL/uncertainty.py

Lines changed: 48 additions & 0 deletions
@@ -12,6 +12,54 @@
 from modAL.utils.data import modALinput


+def _proba_uncertainty(proba: np.ndarray) -> np.ndarray:
+    """
+    Calculates the uncertainty of the prediction probabilities.
+
+    Args:
+        proba: Prediction probabilities.
+
+    Returns:
+        Uncertainty of the prediction probabilities.
+    """
+
+    return 1 - np.max(proba, axis=1)
+
+
+def _proba_margin(proba: np.ndarray) -> np.ndarray:
+    """
+    Calculates the margin of the prediction probabilities.
+
+    Args:
+        proba: Prediction probabilities.
+
+    Returns:
+        Margin of the prediction probabilities.
+    """
+
+    if proba.shape[1] == 1:
+        return np.zeros(shape=len(proba))
+
+    part = np.partition(-proba, 1, axis=1)
+    margin = - part[:, 0] + part[:, 1]
+
+    return margin
+
+
+def _proba_entropy(proba: np.ndarray) -> np.ndarray:
+    """
+    Calculates the entropy of the prediction probabilities.
+
+    Args:
+        proba: Prediction probabilities.
+
+    Returns:
+        Entropy of the prediction probabilities.
+    """
+
+    return np.transpose(entropy(np.transpose(proba)))
+
+
 def classifier_uncertainty(classifier: BaseEstimator, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
     """
     Classification uncertainty of the classifier for the provided samples.
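Taken together, these helpers compute the three standard uncertainty-sampling utilities directly from a probability matrix, without a fitted estimator: 1 minus the row maximum (least confidence), the gap between the two largest values in each row (the margin helper finds them by applying np.partition to the negated probabilities), and the row-wise Shannon entropy. A minimal usage sketch follows; the sample data and print statements are my own illustration rather than part of the commit, and these are underscore-prefixed private helpers, so importing them directly is done at the caller's own risk:

    import numpy as np
    from modAL.uncertainty import _proba_uncertainty, _proba_margin, _proba_entropy

    # each row is a predicted class-probability distribution for one sample
    proba = np.array([[0.1, 0.2, 0.7],
                      [0.6, 0.3, 0.1]])

    print(_proba_uncertainty(proba))  # 1 - row maximum            -> [0.3, 0.4]
    print(_proba_margin(proba))       # largest minus second-largest -> [0.5, 0.3]
    print(_proba_entropy(proba))      # Shannon entropy of each row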

tests/core_tests.py

Lines changed: 18 additions & 0 deletions
@@ -411,6 +411,12 @@ def test_classifier_uncertainty(self):
         test_cases = (Test(p * np.ones(shape=(k, l)), (1 - p) * np.ones(shape=(k, )))
                       for k in range(1, 100) for l in range(1, 10) for p in np.linspace(0, 1, 11))
         for case in test_cases:
+            # testing _proba_uncertainty
+            np.testing.assert_almost_equal(
+                modAL.uncertainty._proba_uncertainty(case.input),
+                case.output
+            )
+
             # fitted estimator
             fitted_estimator = mock.MockEstimator(predict_proba_return=case.input)
             np.testing.assert_almost_equal(
@@ -432,6 +438,12 @@ def test_classifier_margin(self):
                               p * np.ones(shape=(l, ))*int(k!=1))
                          for k in range(1, 10) for l in range(1, 100) for p in np.linspace(0, 1, 11))
         for case in chain(test_cases_1, test_cases_2):
+            # _proba_margin
+            np.testing.assert_almost_equal(
+                modAL.uncertainty._proba_margin(case.input),
+                case.output
+            )
+
             # fitted estimator
             fitted_estimator = mock.MockEstimator(predict_proba_return=case.input)
             np.testing.assert_almost_equal(
@@ -453,6 +465,12 @@ def test_classifier_entropy(self):
         for sample_idx in range(n_samples):
             proba[sample_idx, np.random.choice(range(n_classes))] = 1.0

+        # _proba_entropy
+        np.testing.assert_almost_equal(
+            modAL.uncertainty._proba_entropy(proba),
+            np.zeros(shape=(n_samples,))
+        )
+
         # fitted estimator
         fitted_estimator = mock.MockEstimator(predict_proba_return=proba)
         np.testing.assert_equal(
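The added assertions run the new helpers on the same synthetic cases already used for the classifier-level functions, so each private helper is checked against the same expected values as its public counterpart. As a quick standalone check in the same spirit (this snippet is my own illustration, not part of the test suite): for a one-hot probability matrix, both the uncertainty and the entropy of every row must be zero.

    import numpy as np
    from modAL.uncertainty import _proba_uncertainty, _proba_entropy

    # one-hot rows: the model is fully confident about every sample
    proba = np.eye(4)

    np.testing.assert_almost_equal(_proba_uncertainty(proba), np.zeros(4))  # 1 - max(row) = 0
    np.testing.assert_almost_equal(_proba_entropy(proba), np.zeros(4))      # entropy of a one-hot row = 0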
