1111
1212from modAL .utils .data import modALinput
1313from modAL .utils .selection import multi_argmax
14+ from .models .base import BaseCommittee
1415
1516
16- def vote_entropy (committee , X : modALinput , ** predict_proba_kwargs ) -> np .ndarray :
17+ def vote_entropy (committee : BaseCommittee , X : modALinput , ** predict_proba_kwargs ) -> np .ndarray :
1718 """
1819 Calculates the vote entropy for the Committee. First it computes the predictions of X for each learner in the
1920 Committee, then calculates the probability distribution of the votes. The entropy of this distribution is the vote
@@ -47,7 +48,7 @@ def vote_entropy(committee, X: modALinput, **predict_proba_kwargs) -> np.ndarray
4748 return entr
4849
4950
50- def consensus_entropy (committee , X : modALinput , ** predict_proba_kwargs ) -> np .ndarray :
51+ def consensus_entropy (committee : BaseCommittee , X : modALinput , ** predict_proba_kwargs ) -> np .ndarray :
5152 """
5253 Calculates the consensus entropy for the Committee. First it computes the class probabilties of X for each learner
5354 in the Committee, then calculates the consensus probability distribution by averaging the individual class
@@ -71,7 +72,7 @@ def consensus_entropy(committee, X: modALinput, **predict_proba_kwargs) -> np.nd
7172 return entr
7273
7374
74- def KL_max_disagreement (committee , X : modALinput , ** predict_proba_kwargs ) -> np .ndarray :
75+ def KL_max_disagreement (committee : BaseCommittee , X : modALinput , ** predict_proba_kwargs ) -> np .ndarray :
7576 """
7677 Calculates the max disagreement for the Committee. First it computes the class probabilties of X for each learner in
7778 the Committee, then calculates the consensus probability distribution by averaging the individual class
@@ -101,7 +102,7 @@ def KL_max_disagreement(committee, X: modALinput, **predict_proba_kwargs) -> np.
101102 return np .max (learner_KL_div , axis = 1 )
102103
103104
104- def vote_entropy_sampling (committee , X : modALinput ,
105+ def vote_entropy_sampling (committee : BaseCommittee , X : modALinput ,
105106 n_instances : int = 1 ,** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
106107 """
107108 Vote entropy sampling strategy.
@@ -121,7 +122,7 @@ def vote_entropy_sampling(committee, X: modALinput,
121122 return query_idx , X [query_idx ]
122123
123124
124- def consensus_entropy_sampling (committee , X : modALinput ,
125+ def consensus_entropy_sampling (committee : BaseCommittee , X : modALinput ,
125126 n_instances : int = 1 ,** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
126127 """
127128 Consensus entropy sampling strategy.
@@ -141,7 +142,7 @@ def consensus_entropy_sampling(committee, X: modALinput,
141142 return query_idx , X [query_idx ]
142143
143144
144- def max_disagreement_sampling (committee , X : modALinput ,
145+ def max_disagreement_sampling (committee : BaseCommittee , X : modALinput ,
145146 n_instances : int = 1 ,** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
146147 """
147148 Maximum disagreement sampling strategy.
0 commit comments