Skip to content

Commit 2a50057

Browse files
committed
fix: type annotations added for modAL.disagreement
1 parent 6df2fc1 commit 2a50057

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

modAL/disagreement.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
from modAL.utils.data import modALinput
1313
from modAL.utils.selection import multi_argmax
14+
from .models.base import BaseCommittee
1415

1516

16-
def vote_entropy(committee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
17+
def vote_entropy(committee: BaseCommittee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
1718
"""
1819
Calculates the vote entropy for the Committee. First it computes the predictions of X for each learner in the
1920
Committee, then calculates the probability distribution of the votes. The entropy of this distribution is the vote
@@ -47,7 +48,7 @@ def vote_entropy(committee, X: modALinput, **predict_proba_kwargs) -> np.ndarray
4748
return entr
4849

4950

50-
def consensus_entropy(committee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
51+
def consensus_entropy(committee: BaseCommittee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
5152
"""
5253
Calculates the consensus entropy for the Committee. First it computes the class probabilties of X for each learner
5354
in the Committee, then calculates the consensus probability distribution by averaging the individual class
@@ -71,7 +72,7 @@ def consensus_entropy(committee, X: modALinput, **predict_proba_kwargs) -> np.nd
7172
return entr
7273

7374

74-
def KL_max_disagreement(committee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
75+
def KL_max_disagreement(committee: BaseCommittee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
7576
"""
7677
Calculates the max disagreement for the Committee. First it computes the class probabilties of X for each learner in
7778
the Committee, then calculates the consensus probability distribution by averaging the individual class
@@ -101,7 +102,7 @@ def KL_max_disagreement(committee, X: modALinput, **predict_proba_kwargs) -> np.
101102
return np.max(learner_KL_div, axis=1)
102103

103104

104-
def vote_entropy_sampling(committee, X: modALinput,
105+
def vote_entropy_sampling(committee: BaseCommittee, X: modALinput,
105106
n_instances: int = 1,**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
106107
"""
107108
Vote entropy sampling strategy.
@@ -121,7 +122,7 @@ def vote_entropy_sampling(committee, X: modALinput,
121122
return query_idx, X[query_idx]
122123

123124

124-
def consensus_entropy_sampling(committee, X: modALinput,
125+
def consensus_entropy_sampling(committee: BaseCommittee, X: modALinput,
125126
n_instances: int = 1,**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
126127
"""
127128
Consensus entropy sampling strategy.
@@ -141,7 +142,7 @@ def consensus_entropy_sampling(committee, X: modALinput,
141142
return query_idx, X[query_idx]
142143

143144

144-
def max_disagreement_sampling(committee, X: modALinput,
145+
def max_disagreement_sampling(committee: BaseCommittee, X: modALinput,
145146
n_instances: int = 1,**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
146147
"""
147148
Maximum disagreement sampling strategy.

0 commit comments

Comments
 (0)