1010from sklearn .base import BaseEstimator
1111
1212from modAL .utils .data import modALinput
13- from modAL .utils .selection import multi_argmax
13+ from modAL .utils .selection import multi_argmax , shuffled_argmax
1414from modAL .models .base import BaseCommittee
1515
1616
@@ -103,80 +103,116 @@ def KL_max_disagreement(committee: BaseCommittee, X: modALinput, **predict_proba
103103
104104
105105def vote_entropy_sampling (committee : BaseCommittee , X : modALinput ,
106- n_instances : int = 1 ,** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
106+ n_instances : int = 1 , random_tie_break = False ,
107+ ** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
107108 """
108109 Vote entropy sampling strategy.
109110
110111 Args:
111112 committee: The committee for which the labels are to be queried.
112113 X: The pool of samples to query from.
113114 n_instances: Number of samples to be queried.
114- **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement measure function.
115+ random_tie_break: If True, shuffles utility scores to randomize the order. This
116+ can be used to break the tie when the highest utility score is not unique.
117+ **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
118+ measure function.
115119
116120 Returns:
117- The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
121+ The indices of the instances from X chosen to be labelled;
122+ the instances from X chosen to be labelled.
118123 """
119124 disagreement = vote_entropy (committee , X , ** disagreement_measure_kwargs )
120- query_idx = multi_argmax (disagreement , n_instances = n_instances )
125+
126+ if not random_tie_break :
127+ query_idx = multi_argmax (disagreement , n_instances = n_instances )
128+ else :
129+ query_idx = shuffled_argmax (disagreement , n_instances = n_instances )
121130
122131 return query_idx , X [query_idx ]
123132
124133
125134def consensus_entropy_sampling (committee : BaseCommittee , X : modALinput ,
126- n_instances : int = 1 ,** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
135+ n_instances : int = 1 , random_tie_break = False ,
136+ ** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
127137 """
128138 Consensus entropy sampling strategy.
129139
130140 Args:
131141 committee: The committee for which the labels are to be queried.
132142 X: The pool of samples to query from.
133143 n_instances: Number of samples to be queried.
134- **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement measure function.
144+ random_tie_break: If True, shuffles utility scores to randomize the order. This
145+ can be used to break the tie when the highest utility score is not unique.
146+ **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
147+ measure function.
135148
136149 Returns:
137- The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
150+ The indices of the instances from X chosen to be labelled;
151+ the instances from X chosen to be labelled.
138152 """
139153 disagreement = consensus_entropy (committee , X , ** disagreement_measure_kwargs )
140- query_idx = multi_argmax (disagreement , n_instances = n_instances )
154+
155+ if not random_tie_break :
156+ query_idx = multi_argmax (disagreement , n_instances = n_instances )
157+ else :
158+ query_idx = shuffled_argmax (disagreement , n_instances = n_instances )
141159
142160 return query_idx , X [query_idx ]
143161
144162
145163def max_disagreement_sampling (committee : BaseCommittee , X : modALinput ,
146- n_instances : int = 1 ,** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
164+ n_instances : int = 1 , random_tie_break = False ,
165+ ** disagreement_measure_kwargs ) -> Tuple [np .ndarray , modALinput ]:
147166 """
148167 Maximum disagreement sampling strategy.
149168
150169 Args:
151170 committee: The committee for which the labels are to be queried.
152171 X: The pool of samples to query from.
153172 n_instances: Number of samples to be queried.
154- **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement measure function.
173+ random_tie_break: If True, shuffles utility scores to randomize the order. This
174+ can be used to break the tie when the highest utility score is not unique.
175+ **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
176+ measure function.
155177
156178 Returns:
157- The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
179+ The indices of the instances from X chosen to be labelled;
180+ the instances from X chosen to be labelled.
158181 """
159182 disagreement = KL_max_disagreement (committee , X , ** disagreement_measure_kwargs )
160- query_idx = multi_argmax (disagreement , n_instances = n_instances )
183+
184+ if not random_tie_break :
185+ query_idx = multi_argmax (disagreement , n_instances = n_instances )
186+ else :
187+ query_idx = shuffled_argmax (disagreement , n_instances = n_instances )
161188
162189 return query_idx , X [query_idx ]
163190
164191
165192def max_std_sampling (regressor : BaseEstimator , X : modALinput ,
166- n_instances : int = 1 , ** predict_kwargs ) -> Tuple [np .ndarray , modALinput ]:
193+ n_instances : int = 1 , random_tie_break = False ,
194+ ** predict_kwargs ) -> Tuple [np .ndarray , modALinput ]:
167195 """
168196 Regressor standard deviation sampling strategy.
169197
170198 Args:
171199 regressor: The regressor for which the labels are to be queried.
172200 X: The pool of samples to query from.
173201 n_instances: Number of samples to be queried.
202+ random_tie_break: If True, shuffles utility scores to randomize the order. This
203+ can be used to break the tie when the highest utility score is not unique.
174204 **predict_kwargs: Keyword arguments to be passed to :meth:`predict` of the CommiteeRegressor.
175205
176206 Returns:
177- The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
207+ The indices of the instances from X chosen to be labelled;
208+ the instances from X chosen to be labelled.
178209 """
179210 _ , std = regressor .predict (X , return_std = True , ** predict_kwargs )
180211 std = std .reshape (X .shape [0 ], )
181- query_idx = multi_argmax (std , n_instances = n_instances )
212+
213+ if not random_tie_break :
214+ query_idx = multi_argmax (std , n_instances = n_instances )
215+ else :
216+ query_idx = shuffled_argmax (std , n_instances = n_instances )
217+
182218 return query_idx , X [query_idx ]
0 commit comments