Skip to content

Commit 61ad80d

Browse files
committed
add: shuffled_argmax added
1 parent 8f856fd commit 61ad80d

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

modAL/utils/selection.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,37 @@
55
import numpy as np
66

77

8+
def shuffled_argmax(values: np.ndarray, n_instances: int = 1) -> np.ndarray:
    """
    Selects the indices of the n_instances highest values, breaking ties randomly.

    The values are shuffled first and then sorted with a stable algorithm
    (mergesort). Because a stable sort preserves the relative order of equal
    elements, the random order produced by the shuffle decides which of several
    equally high values are returned.

    Args:
        values: Contains the values to be selected from.
        n_instances: Specifies how many indices to return.

    Returns:
        The indices of the n_instances largest values, ordered from the
        highest value to the lowest. If n_instances exceeds the number of
        values, the indices of all values are returned.
    """
    values = np.asarray(values)

    # shuffling indices and corresponding values
    shuffled_idx = np.random.permutation(len(values))
    shuffled_values = values[shuffled_idx]

    # argsort is ascending, so the stable result is reversed to put the
    # largest values first; equal values stay in a (reversed, still random)
    # shuffled order, which keeps the tie-breaking random
    descending_idx = np.argsort(shuffled_values, kind='mergesort')[::-1]
    sorted_query_idx = descending_idx[:n_instances]

    # inverting the shuffle
    query_idx = shuffled_idx[sorted_query_idx]
    return query_idx
37+
38+
839
def multi_argmax(values: np.ndarray, n_instances: int = 1) -> np.ndarray:
940
"""
1041
Selects the indices of the n_instances highest values.
@@ -14,7 +45,7 @@ def multi_argmax(values: np.ndarray, n_instances: int = 1) -> np.ndarray:
1445
n_instances: Specifies how many indices to return.
1546
1647
Returns:
17-
Contains the indices of the n_instances largest values.
48+
The indices of the n_instances largest values.
1849
"""
1950
assert n_instances <= values.shape[0], 'n_instances must be less or equal than the size of utility'
2051

0 commit comments

Comments
 (0)