Skip to content

Commit 4cb7060

Browse files
author
Max Keller
committed
Remove mc_dropout_multi
1 parent eb63dc0 commit 4cb7060

File tree

1 file changed

+0
-34
lines changed

1 file changed

+0
-34
lines changed

modAL/dropout.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,40 +20,6 @@ def default_logits_adaptor(input_tensor: torch.tensor, samples: modALinput):
2020
return input_tensor
2121

2222

23-
def mc_dropout_multi(classifier: BaseEstimator, X: modALinput, query_strategies: list = ["bald", "mean_st", "max_entropy", "max_var"],
24-
n_instances: int = 1, random_tie_break: bool = False, dropout_layer_indexes: list = [],
25-
num_cycles: int = 50, sample_per_forward_pass: int = 1000,
26-
logits_adaptor: Callable[[
27-
torch.tensor, modALinput], torch.tensor] = default_logits_adaptor,
28-
**mc_dropout_kwargs) -> np.ndarray:
29-
"""
30-
Multi metric dropout query strategy. Returns the specified metrics for given input data.
31-
Selection of query strategies are:
32-
- bald: BALD query strategy
33-
- mean_st: Mean Standard deviation
34-
- max_entropy: maximum entropy
35-
- max_var: maximum variation
36-
By default all query strategies are selected
37-
38-
Function returns dictionary of metrics with their name as key.
39-
The indices of the n-best samples (n_instances) is not used in this function.
40-
"""
41-
predictions = get_predictions(
42-
classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)
43-
44-
metrics_dict = {}
45-
if "bald" in query_strategies:
46-
metrics_dict["bald"] = _bald_divergence(predictions)
47-
if "mean_st" in query_strategies:
48-
metrics_dict["mean_st"] = _mean_standard_deviation(predictions)
49-
if "max_entropy" in query_strategies:
50-
metrics_dict["max_entropy"] = _entropy(predictions)
51-
if "max_var" in query_strategies:
52-
metrics_dict["max_var"] = _variation_ratios(predictions)
53-
54-
return None, metrics_dict
55-
56-
5723
def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
5824
random_tie_break: bool = False, dropout_layer_indexes: list = [],
5925
num_cycles: int = 50, sample_per_forward_pass: int = 1000,

0 commit comments

Comments
 (0)