@@ -20,40 +20,6 @@ def default_logits_adaptor(input_tensor: torch.tensor, samples: modALinput):
2020 return input_tensor
2121
2222
23- def mc_dropout_multi (classifier : BaseEstimator , X : modALinput , query_strategies : list = ["bald" , "mean_st" , "max_entropy" , "max_var" ],
24- n_instances : int = 1 , random_tie_break : bool = False , dropout_layer_indexes : list = [],
25- num_cycles : int = 50 , sample_per_forward_pass : int = 1000 ,
26- logits_adaptor : Callable [[
27- torch .tensor , modALinput ], torch .tensor ] = default_logits_adaptor ,
28- ** mc_dropout_kwargs ) -> np .ndarray :
29- """
30- Multi metric dropout query strategy. Returns the specified metrics for given input data.
31- Selection of query strategies are:
32- - bald: BALD query strategy
33- - mean_st: Mean Standard deviation
34- - max_entropy: maximum entropy
35- - max_var: maximum variation
36- By default all query strategies are selected
37-
38- Function returns dictionary of metrics with their name as key.
39- The indices of the n-best samples (n_instances) is not used in this function.
40- """
41- predictions = get_predictions (
42- classifier , X , dropout_layer_indexes , num_cycles , sample_per_forward_pass , logits_adaptor )
43-
44- metrics_dict = {}
45- if "bald" in query_strategies :
46- metrics_dict ["bald" ] = _bald_divergence (predictions )
47- if "mean_st" in query_strategies :
48- metrics_dict ["mean_st" ] = _mean_standard_deviation (predictions )
49- if "max_entropy" in query_strategies :
50- metrics_dict ["max_entropy" ] = _entropy (predictions )
51- if "max_var" in query_strategies :
52- metrics_dict ["max_var" ] = _variation_ratios (predictions )
53-
54- return None , metrics_dict
55-
56-
5723def mc_dropout_bald (classifier : BaseEstimator , X : modALinput , n_instances : int = 1 ,
5824 random_tie_break : bool = False , dropout_layer_indexes : list = [],
5925 num_cycles : int = 50 , sample_per_forward_pass : int = 1000 ,
0 commit comments