Skip to content

Commit c8b3fa7

Browse files
docs: refactor utils.selection and uncertainty
1 parent 3084406 commit c8b3fa7

File tree

2 files changed

+75
-171
lines changed

2 files changed

+75
-171
lines changed

modAL/uncertainty.py

Lines changed: 63 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,27 @@
11
"""
22
Uncertainty measures and uncertainty based sampling strategies for the active learning models.
33
"""
4+
from typing import Tuple
45

56
import numpy as np
67
from scipy.stats import entropy
78
from sklearn.exceptions import NotFittedError
9+
from sklearn.base import BaseEstimator
810

911
from modAL.utils.selection import multi_argmax
12+
from modAL.utils.data import modALinput
1013

1114

12-
def classifier_uncertainty(classifier, X, **predict_proba_kwargs):
15+
def classifier_uncertainty(classifier: BaseEstimator, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
1316
"""
1417
Classification uncertainty of the classifier for the provided samples.
1518
16-
:param classifier:
17-
The classifier for which the uncertainty is to be measured.
18-
:type classifier:
19-
sklearn classifier object, for instance sklearn.ensemble.RandomForestClassifier
19+
Args:
20+
classifier: The classifier for which the uncertainty is to be measured.
21+
X: The samples for which the uncertainty of classification is to be measured.
22+
**predict_proba_kwargs: Keyword arguments to be passed to the :meth:`predict_proba` method of the classifier.
2023
21-
:param X:
22-
The samples for which the uncertainty of classification is to be measured.
23-
:type X:
24-
numpy.ndarray of shape (n_samples, n_features)
25-
26-
:param predict_proba_kwargs:
27-
Keyword arguments to be passed for the predict_proba method of the classifier.
28-
:type predict_proba_kwargs:
29-
keyword arguments
30-
31-
:returns:
32-
- **uncertainty** *(numpy.ndarray of shape (n_samples, ))* --
24+
Returns:
3325
Classifier uncertainty, which is 1 - P(prediction is correct).
3426
"""
3527
# calculate uncertainty for each point provided
@@ -43,36 +35,23 @@ def classifier_uncertainty(classifier, X, **predict_proba_kwargs):
4335
return uncertainty
4436

4537

46-
def classifier_margin(classifier, X, **predict_proba_kwargs):
38+
def classifier_margin(classifier: BaseEstimator, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
4739
"""
48-
Classification margin uncertainty of the classifier for the provided samples.
49-
This uncertainty measure takes the first and second most likely predictions
50-
and takes the difference of their probabilities, which is the margin.
51-
52-
:param classifier:
53-
The classifier for which the uncertainty is to be measured
54-
:type classifier:
55-
sklearn classifier object, for instance sklearn.ensemble.RandomForestClassifier
56-
57-
:param X:
58-
The samples for which the uncertainty of classification is to be measured
59-
:type X:
60-
numpy.ndarray of shape (n_samples, n_features)
61-
62-
:param predict_proba_kwargs:
63-
Keyword arguments to be passed for the predict_proba method of the classifier
64-
:type predict_proba_kwargs:
65-
keyword arguments
66-
67-
:returns:
68-
- **margin** *(numpy.ndarray of shape (n_samples, ))* --
69-
Margin uncertainty, which is the difference of the probabilities of first
70-
and second most likely predictions.
40+
Classification margin uncertainty of the classifier for the provided samples. This uncertainty measure takes the
41+
first and second most likely predictions and takes the difference of their probabilities, which is the margin.
42+
43+
Args:
44+
classifier: The classifier for which the prediction margin is to be measured.
45+
X: The samples for which the prediction margin of classification is to be measured.
46+
**predict_proba_kwargs: Keyword arguments to be passed to the :meth:`predict_proba` method of the classifier.
47+
48+
Returns:
49+
Margin uncertainty, which is the difference of the probabilities of first and second most likely predictions.
7150
"""
7251
try:
7352
classwise_uncertainty = classifier.predict_proba(X, **predict_proba_kwargs)
7453
except NotFittedError:
75-
return np.zeros(shape=(len(X), ))
54+
return np.zeros(shape=(X.shape[0], ))
7655

7756
if classwise_uncertainty.shape[1] == 1:
7857
return np.zeros(shape=(classwise_uncertainty.shape[0],))
@@ -83,140 +62,80 @@ def classifier_margin(classifier, X, **predict_proba_kwargs):
8362
return margin
8463

8564

86-
def classifier_entropy(classifier, X, **predict_proba_kwargs):
65+
def classifier_entropy(classifier: BaseEstimator, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
8766
"""
8867
Entropy of the classifier's predictions for the provided samples.
8968
90-
:param classifier:
91-
The classifier for which the prediction entropy is to be measured.
92-
:type classifier:
93-
sklearn classifier object, for instance sklearn.ensemble.RandomForestClassifier
94-
95-
:param X:
96-
The samples for which the prediction entropy is to be measured.
97-
:type X:
98-
numpy.ndarray of shape (n_samples, n_features)
99-
100-
:param predict_proba_kwargs:
101-
Keyword arguments to be passed for the predict_proba method of the classifier.
102-
:type predict_proba_kwargs:
103-
keyword arguments
69+
Args:
70+
classifier: The classifier for which the prediction entropy is to be measured.
71+
X: The samples for which the prediction entropy is to be measured.
72+
**predict_proba_kwargs: Keyword arguments to be passed to the :meth:`predict_proba` method of the classifier.
10473
105-
:returns:
106-
- **entr** *(numpy.ndarray of shape (n_samples, ))* --
74+
Returns:
10775
Entropy of the class probabilities.
10876
"""
10977
try:
11078
classwise_uncertainty = classifier.predict_proba(X, **predict_proba_kwargs)
11179
except NotFittedError:
112-
return np.zeros(shape=(len(X), ))
80+
return np.zeros(shape=(X.shape[0], ))
11381

11482
return np.transpose(entropy(np.transpose(classwise_uncertainty)))
11583

11684

117-
def uncertainty_sampling(classifier, X, n_instances=1, **uncertainty_measure_kwargs):
85+
def uncertainty_sampling(classifier: BaseEstimator, X: modALinput,
86+
n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
11887
"""
11988
Uncertainty sampling query strategy. Selects the least sure instances for labelling.
12089
121-
:param classifier:
122-
The classifier for which the labels are to be queried.
123-
:type classifier:
124-
sklearn classifier object, for instance sklearn.ensemble.RandomForestClassifier
125-
126-
:param X:
127-
The pool of samples to query from.
128-
:type X:
129-
numpy.ndarray of shape (n_samples, n_features)
130-
131-
:param n_instances:
132-
Number of samples to be queried.
133-
:type n_instances:
134-
int
135-
136-
:param uncertainty_measure_kwargs:
137-
Keyword arguments to be passed for the uncertainty measure function.
138-
:type uncertainty_measure_kwargs:
139-
keyword arguments
140-
141-
:returns:
142-
- **query_idx** *(numpy.ndarray of shape (n_instances, ))* --
143-
The indices of the instances from X chosen to be labelled.
144-
- **X[query_idx]** *(numpy.ndarray of shape (n_instances, n_features))* --
145-
The instances from X chosen to be labelled.
90+
Args:
91+
classifier: The classifier for which the labels are to be queried.
92+
X: The pool of samples to query from.
93+
n_instances: Number of samples to be queried.
94+
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function.
95+
96+
Returns:
97+
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
14698
"""
14799
uncertainty = classifier_uncertainty(classifier, X, **uncertainty_measure_kwargs)
148100
query_idx = multi_argmax(uncertainty, n_instances=n_instances)
149101

150102
return query_idx, X[query_idx]
151103

152104

153-
def margin_sampling(classifier, X, n_instances=1, **uncertainty_measure_kwargs):
105+
def margin_sampling(classifier: BaseEstimator, X: modALinput,
106+
n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
154107
"""
155-
Margin sampling query strategy. Selects the instances where the difference between
156-
the first most likely and second most likely classes are the smallest.
157-
158-
:param classifier:
159-
The classifier for which the labels are to be queried.
160-
:type classifier:
161-
sklearn classifier object, for instance sklearn.ensemble.RandomForestClassifier
162-
163-
:param X:
164-
The pool of samples to query from.
165-
:type X:
166-
numpy.ndarray of shape (n_samples, n_features)
167-
168-
:param n_instances:
169-
Number of samples to be queried.
170-
:type n_instances:
171-
int
172-
173-
:param uncertainty_measure_kwargs:
174-
Keyword arguments to be passed for the uncertainty measure function.
175-
:type uncertainty_measure_kwargs:
176-
keyword arguments
177-
178-
:returns:
179-
- **query_idx** *(numpy.ndarray of shape (n_instances, ))* --
180-
The indices of the instances from X chosen to be labelled.
181-
- **X[query_idx]** *(numpy.ndarray of shape (n_instances, n_features))* --
182-
The instances from X chosen to be labelled.
108+
Margin sampling query strategy. Selects the instances where the difference between the first most likely and second
109+
most likely classes is the smallest.
110+
111+
Args:
112+
classifier: The classifier for which the labels are to be queried.
113+
X: The pool of samples to query from.
114+
n_instances: Number of samples to be queried.
115+
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function.
116+
117+
Returns:
118+
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
183119
"""
184120
margin = classifier_margin(classifier, X, **uncertainty_measure_kwargs)
185121
query_idx = multi_argmax(-margin, n_instances=n_instances)
186122

187123
return query_idx, X[query_idx]
188124

189125

190-
def entropy_sampling(classifier, X, n_instances=1, **uncertainty_measure_kwargs):
126+
def entropy_sampling(classifier: BaseEstimator, X: modALinput,
127+
n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
191128
"""
192-
Entropy sampling query strategy. Selects the instances where the class probabilities
193-
have the largest entropy.
194-
195-
:param classifier:
196-
The classifier for which the labels are to be queried.
197-
:type classifier:
198-
sklearn classifier object, for instance sklearn.ensemble.RandomForestClassifier
199-
200-
:param X:
201-
The pool of samples to query from.
202-
:type X:
203-
numpy.ndarray of shape (n_samples, n_features)
204-
205-
:param n_instances:
206-
Number of samples to be queried.
207-
:type n_instances:
208-
int
209-
210-
:param uncertainty_measure_kwargs:
211-
Keyword arguments to be passed for the uncertainty measure function.
212-
:type uncertainty_measure_kwargs:
213-
keyword arguments
214-
215-
:returns:
216-
- **query_idx** *(numpy.ndarray of shape (n_instances, ))* --
217-
The indices of the instances from X chosen to be labelled.
218-
- **X[query_idx]** *(numpy.ndarray of shape (n_instances, n_features))* --
219-
The instances from X chosen to be labelled.
129+
Entropy sampling query strategy. Selects the instances where the class probabilities have the largest entropy.
130+
131+
Args:
132+
classifier: The classifier for which the labels are to be queried.
133+
X: The pool of samples to query from.
134+
n_instances: Number of samples to be queried.
135+
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function.
136+
137+
Returns:
138+
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
220139
"""
221140
entropy = classifier_entropy(classifier, X, **uncertainty_measure_kwargs)
222141
query_idx = multi_argmax(entropy, n_instances=n_instances)

modAL/utils/selection.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,50 +5,35 @@
55
import numpy as np
66

77

8-
def multi_argmax(values, n_instances=1):
8+
def multi_argmax(values: np.ndarray, n_instances: int = 1) -> np.ndarray:
99
"""
1010
Selects the indices of the n_instances highest values.
1111
12-
:param values:
13-
Contains the values to be selected from.
14-
:type values:
15-
numpy.ndarray of shape = (n_samples, 1)
12+
Args:
13+
values: Contains the values to be selected from.
14+
n_instances: Specifies how many indices to return.
1615
17-
:param n_instances:
18-
Specifies how many indices to return.
19-
:type n_instances:
20-
int
21-
22-
:returns:
23-
- **max_idx** *(numpy.ndarray of shape = (n_samples, 1))* --
16+
Returns:
2417
Contains the indices of the n_instances largest values.
25-
2618
"""
27-
assert n_instances <= len(values), 'n_instances must be less or equal than the size of utility'
19+
assert n_instances <= values.shape[0], 'n_instances must be less or equal than the size of utility'
2820

2921
max_idx = np.argpartition(-values, n_instances-1, axis=0)[:n_instances]
3022
return max_idx
3123

3224

33-
def weighted_random(weights, n_instances=1):
25+
def weighted_random(weights: np.ndarray, n_instances: int = 1) -> np.ndarray:
3426
"""
3527
Returns n_instances indices based on the weights.
3628
37-
:param weights:
38-
Contains the weights of the sampling.
39-
:type weights:
40-
numpy.ndarray of shape = (n_samples, 1)
41-
42-
:param n_instances:
43-
Specifies how many indices to return.
44-
:type n_instances:
45-
int
29+
Args:
30+
weights: Contains the weights of the sampling.
31+
n_instances: Specifies how many indices to return.
4632
47-
:returns:
48-
- **random_idx** *(numpy.ndarray of shape = (n_instances, 1))* --
33+
Returns:
4934
n_instances random indices based on the weights.
5035
"""
51-
assert n_instances <= len(weights), 'n_instances must be less or equal than the size of utility'
36+
assert n_instances <= weights.shape[0], 'n_instances must be less or equal than the size of utility'
5237
weight_sum = np.sum(weights)
5338
assert weight_sum > 0, 'the sum of weights must be larger than zero'
5439

0 commit comments

Comments
 (0)