Skip to content

Commit 3a82cf9

Browse files
committed
fix: max_std_sampling type hint changed from CommitteeRegressor to BaseEstimator
1 parent bf60590 commit 3a82cf9

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

modAL/disagreement.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
import numpy as np
88
from scipy.stats import entropy
99
from sklearn.exceptions import NotFittedError
10+
from sklearn.base import BaseEstimator
1011

11-
from modAL.models import BaseCommittee, CommitteeRegressor
1212
from modAL.utils.data import modALinput
1313
from modAL.utils.selection import multi_argmax
1414

15+
from modAL.models import BaseCommittee
16+
1517

1618
def vote_entropy(committee: BaseCommittee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
1719
"""
@@ -161,13 +163,13 @@ def max_disagreement_sampling(committee: BaseCommittee, X: modALinput,
161163
return query_idx, X[query_idx]
162164

163165

164-
def max_std_sampling(regressor: CommitteeRegressor, X: modALinput,
166+
def max_std_sampling(regressor: BaseEstimator, X: modALinput,
165167
n_instances: int = 1, **predict_kwargs) -> Tuple[np.ndarray, modALinput]:
166168
"""
167169
Regressor standard deviation sampling strategy.
168170
169171
Args:
170-
regressor: The CommitteeRegressor for which the labels are to be queried.
172+
regressor: The regressor for which the labels are to be queried.
171173
X: The pool of samples to query from.
172174
n_instances: Number of samples to be queried.
173175
**predict_kwargs: Keyword arguments to be passed to :meth:`predict` of the CommiteeRegressor.

0 commit comments

Comments
 (0)