Skip to content

Commit 255f3ae

Browse files
refactor: improve acquisition and disagreement modules
1 parent c8b3fa7 commit 255f3ae

File tree

2 files changed

+139
-311
lines changed

2 files changed

+139
-311
lines changed

modAL/acquisition.py

Lines changed: 59 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
"""
22
Acquisition functions for Bayesian optimization.
33
"""
4+
from typing import Tuple
45

56
import numpy as np
6-
77
from scipy.stats import norm
88
from scipy.special import ndtr
99
from sklearn.exceptions import NotFittedError
10+
11+
from modAL.models import BayesianOptimizer
1012
from modAL.utils.selection import multi_argmax
13+
from modAL.utils.data import modALinput
1114

1215

1316
def PI(mean, std, max_val, tradeoff):
@@ -30,209 +33,129 @@ def UCB(mean, std, beta):
3033
"""
3134

3235

33-
def optimizer_PI(optimizer, X, tradeoff=0):
36+
def optimizer_PI(optimizer: BayesianOptimizer, X: modALinput, tradeoff: float = 0) -> np.ndarray:
3437
"""
3538
Probability of improvement acquisition function for Bayesian optimization.
3639
37-
:param optimizer:
38-
The BayesianEstimator object for which the utility is to be calculated.
39-
:type optimizer:
40-
modAL.models.BayesianEstimator object
41-
42-
:param X:
43-
The samples for which the probability of improvement is to be calculated.
44-
:type X:
45-
numpy.ndarray of shape (n_samples, n_features)
46-
47-
:param tradeoff:
48-
Value controlling the tradeoff parameter.
49-
:type tradeoff:
50-
float
40+
Args:
41+
optimizer: The :class:`~modAL.models.BayesianOptimizer` object for which the utility is to be calculated.
42+
X: The samples for which the probability of improvement is to be calculated.
43+
tradeoff: Value controlling the tradeoff parameter.
5144
52-
:returns:
53-
- **pi** *(numpy.ndarray of shape (n_samples, ))* --
45+
Returns:
5446
Probability of improvement utility score.
5547
"""
5648
try:
5749
mean, std = optimizer.predict(X, return_std=True)
5850
std = std.reshape(-1, 1)
5951
except NotFittedError:
60-
mean, std = np.zeros(shape=(len(X), 1)), np.ones(shape=(len(X), 1))
52+
mean, std = np.zeros(shape=(X.shape[0], 1)), np.ones(shape=(X.shape[0], 1))
6153

6254
return PI(mean, std, optimizer.y_max, tradeoff)
6355

6456

65-
def optimizer_EI(optimizer, X, tradeoff=0):
57+
def optimizer_EI(optimizer: BayesianOptimizer, X: modALinput, tradeoff: float = 0) -> np.ndarray:
6658
"""
6759
Expected improvement acquisition function for Bayesian optimization.
6860
69-
:param optimizer:
70-
The BayesianEstimator object for which the utility is to be calculated.
71-
:type optimizer:
72-
modAL.models.BayesianEstimator object
61+
Args:
62+
optimizer: The :class:`~modAL.models.BayesianOptimizer` object for which the utility is to be calculated.
63+
X: The samples for which the expected improvement is to be calculated.
64+
tradeoff: Value controlling the tradeoff parameter.
7365
74-
:param X:
75-
The samples for which the expected improvement is to be calculated.
76-
:type X:
77-
numpy.ndarray of shape (n_samples, n_features)
78-
79-
:param tradeoff:
80-
Value controlling the tradeoff parameter.
81-
:type tradeoff:
82-
float
83-
84-
:returns:
85-
- **ei** *(numpy.ndarray of shape (n_samples, ))* --
66+
Returns:
8667
Expected improvement utility score.
8768
"""
8869
try:
8970
mean, std = optimizer.predict(X, return_std=True)
9071
std = std.reshape(-1, 1)
9172
except NotFittedError:
92-
mean, std = np.zeros(shape=(len(X), 1)), np.ones(shape=(len(X), 1))
73+
mean, std = np.zeros(shape=(X.shape[0], 1)), np.ones(shape=(X.shape[0], 1))
9374

9475
return EI(mean, std, optimizer.y_max, tradeoff)
9576

9677

97-
def optimizer_UCB(optimizer, X, beta=1):
78+
def optimizer_UCB(optimizer: BayesianOptimizer, X: modALinput, beta: float = 1) -> np.ndarray:
9879
"""
9980
Upper confidence bound acquisition function for Bayesian optimization.
10081
101-
:param optimizer:
102-
The BayesianEstimator object for which the utility is to be calculated.
103-
:type optimizer:
104-
modAL.models.BayesianEstimator object
105-
106-
:param X:
107-
The samples for which the upper confidence bound is to be calculated.
108-
:type X:
109-
numpy.ndarray of shape (n_samples, n_features)
82+
Args:
83+
optimizer: The :class:`~modAL.models.BayesianOptimizer` object for which the utility is to be calculated.
84+
X: The samples for which the upper confidence bound is to be calculated.
85+
beta: Value controlling the beta parameter.
11086
111-
:param beta:
112-
Value controlling the beta parameter.
113-
:type beta:
114-
float
115-
116-
:returns:
117-
- **ucb** *(numpy.ndarray of shape (n_samples, ))* --
87+
Returns:
11888
Upper confidence bound utility score.
11989
"""
12090
try:
12191
mean, std = optimizer.predict(X, return_std=True)
12292
std = std.reshape(-1, 1)
12393
except NotFittedError:
124-
mean, std = np.zeros(shape=(len(X), 1)), np.ones(shape=(len(X), 1))
94+
mean, std = np.zeros(shape=(X.shape[0], 1)), np.ones(shape=(X.shape[0], 1))
12595

12696
return UCB(mean, std, beta)
12797

98+
12899
"""
129100
--------------------------------------------
130101
Query strategies using acquisition functions
131102
--------------------------------------------
132103
"""
133104

134-
def max_PI(optimizer, X, tradeoff=0, n_instances=1):
105+
106+
def max_PI(optimizer: BayesianOptimizer, X: modALinput, tradeoff: float = 0,
107+
n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
135108
"""
136109
Maximum PI query strategy. Selects the instance with highest probability of improvement.
137110
138-
:param optimizer:
139-
The BayesianEstimator object for which the utility is to be calculated.
140-
:type optimizer:
141-
modAL.models.BayesianEstimator object
142-
143-
:param X:
144-
The samples for which the probability of improvement is to be calculated.
145-
:type X:
146-
numpy.ndarray of shape (n_samples, n_features)
147-
148-
:param tradeoff:
149-
Value controlling the tradeoff parameter.
150-
:type tradeoff:
151-
float
152-
153-
:param n_instances:
154-
Number of samples to be queried.
155-
:type n_instances:
156-
int
157-
158-
:returns:
159-
- **query_idx** *(numpy.ndarray of shape (n_instances, ))* --
160-
The indices of the instances from X chosen to be labelled.
161-
- **X[query_idx]** *(numpy.ndarray of shape (n_instances, n_features))* --
162-
The instances from X chosen to be labelled.
111+
Args:
112+
optimizer: The :class:`~modAL.models.BayesianOptimizer` object for which the utility is to be calculated.
113+
X: The samples for which the probability of improvement is to be calculated.
114+
tradeoff: Value controlling the tradeoff parameter.
115+
n_instances: Number of samples to be queried.
116+
117+
Returns:
118+
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
163119
"""
164120
pi = optimizer_PI(optimizer, X, tradeoff=tradeoff)
165121
query_idx = multi_argmax(pi, n_instances=n_instances)
166122

167123
return query_idx, X[query_idx]
168124

169125

170-
def max_EI(optimizer, X, tradeoff=0, n_instances=1):
126+
def max_EI(optimizer: BayesianOptimizer, X: modALinput, tradeoff: float = 0,
127+
n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
171128
"""
172129
Maximum EI query strategy. Selects the instance with highest expected improvement.
173130
174-
:param optimizer:
175-
The BayesianEstimator object for which the utility is to be calculated.
176-
:type optimizer:
177-
modAL.models.BayesianEstimator object
178-
179-
:param X:
180-
The samples for which the expected improvement is to be calculated.
181-
:type X:
182-
numpy.ndarray of shape (n_samples, n_features)
183-
184-
:param tradeoff:
185-
Value controlling the tradeoff parameter.
186-
:type tradeoff:
187-
float
188-
189-
:param n_instances:
190-
Number of samples to be queried.
191-
:type n_instances:
192-
int
193-
194-
:returns:
195-
- **query_idx** *(numpy.ndarray of shape (n_instances, )*) --
196-
The indices of the instances from X chosen to be labelled.
197-
- **X[query_idx]** *(numpy.ndarray of shape (n_instances, n_features)*) --
198-
The instances from X chosen to be labelled.
131+
Args:
132+
optimizer: The :class:`~modAL.models.BayesianOptimizer` object for which the utility is to be calculated.
133+
X: The samples for which the expected improvement is to be calculated.
134+
tradeoff: Value controlling the tradeoff parameter.
135+
n_instances: Number of samples to be queried.
136+
137+
Returns:
138+
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
199139
"""
200140
ei = optimizer_EI(optimizer, X, tradeoff=tradeoff)
201141
query_idx = multi_argmax(ei, n_instances=n_instances)
202142

203143
return query_idx, X[query_idx]
204144

205145

206-
def max_UCB(optimizer, X, beta=1, n_instances=1):
146+
def max_UCB(optimizer: BayesianOptimizer, X: modALinput, beta: float = 1,
147+
n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
207148
"""
208-
Maximum UCB query strategy. Selects the instance with highest upper confidence
209-
bound.
210-
211-
:param optimizer:
212-
The BayesianEstimator object for which the utility is to be calculated.
213-
:type optimizer:
214-
modAL.models.BayesianEstimator object
215-
216-
:param X:
217-
The samples for which the probability of improvement is to be calculated.
218-
:type X:
219-
numpy.ndarray of shape (n_samples, n_features)
220-
221-
:param beta:
222-
Value controlling the beta parameter.
223-
:type beta:
224-
float
225-
226-
:param n_instances:
227-
Number of samples to be queried.
228-
:type n_instances:
229-
int
230-
231-
:returns:
232-
- **query_idx** *(numpy.ndarray of shape (n_instances, ))* --
233-
The indices of the instances from X chosen to be labelled.
234-
- **X[query_idx]** *(numpy.ndarray of shape (n_instances, n_features))* --
235-
The instances from X chosen to be labelled.
149+
Maximum UCB query strategy. Selects the instance with highest upper confidence bound.
150+
151+
Args:
152+
optimizer: The :class:`~modAL.models.BayesianOptimizer` object for which the utility is to be calculated.
153+
X: The samples for which the maximum upper confidence bound is to be calculated.
154+
beta: Value controlling the beta parameter.
155+
n_instances: Number of samples to be queried.
156+
157+
Returns:
158+
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
236159
"""
237160
ucb = optimizer_UCB(optimizer, X, beta=beta)
238161
query_idx = multi_argmax(ucb, n_instances=n_instances)

0 commit comments

Comments
 (0)