55import numpy as np
66
77
def shuffled_argmax(values: np.ndarray, n_instances: int = 1) -> np.ndarray:
    """
    Selects the indices of the n_instances highest values, breaking ties randomly.

    The values are shuffled first and then sorted with a stable sort
    (mergesort). Because a stable sort keeps equal elements in their incoming
    order, the random shuffle decides which of several equally large values
    end up among the selected ones.

    Args:
        values: Contains the values to be selected from.
        n_instances: Specifies how many indices to return. If it exceeds the
            number of values, the indices of all values are returned.

    Returns:
        The indices of the n_instances largest values, ordered by ascending
        value.
    """
    n_values = len(values)

    # shuffling indices and corresponding values
    shuffled_idx = np.random.permutation(n_values)
    shuffled_values = values[shuffled_idx]

    # argsort is ascending, so the largest values sit at the END of the
    # ordering; since mergesort is stable, the shuffled order among equal
    # values is preserved
    ascending_idx = np.argsort(shuffled_values, kind='mergesort')

    # an explicit start index (rather than [-n_instances:]) keeps
    # n_instances == 0 returning an empty result, and the clamp keeps an
    # oversized n_instances from wrapping around
    start = max(n_values - n_instances, 0)
    sorted_query_idx = ascending_idx[start:]

    # inverting the shuffle
    query_idx = shuffled_idx[sorted_query_idx]
    return query_idx
37+
38+
839def multi_argmax (values : np .ndarray , n_instances : int = 1 ) -> np .ndarray :
940 """
1041 Selects the indices of the n_instances highest values.
@@ -14,7 +45,7 @@ def multi_argmax(values: np.ndarray, n_instances: int = 1) -> np.ndarray:
1445 n_instances: Specifies how many indices to return.
1546
1647 Returns:
17- Contains the indices of the n_instances largest values.
48+ The indices of the n_instances largest values.
1849 """
1950 assert n_instances <= values .shape [0 ], 'n_instances must be less or equal than the size of utility'
2051
0 commit comments