@@ -27,6 +27,13 @@ class ThresholdDecoder:
2727 activations using a series of averages and standard deviations to
2828 calculate a cumulative probability distribution
2929
30+ Args:
31+ mu_stds: tuple of pairs of (mean, standard deviation) that model the positive network output
32+ center: proportion of activations that a threshold of 0.5 indicates. Pass as None to disable decoding
33+ resolution: precision of cumulative sum estimation. Increases memory usage
34+ min_z: Minimum z score to generate in distribution map
35+ max_z: Maximum z score to generate in distribution map
36+
3037 Background:
3138 We could simply take the output of the neural network as the confidence of a given
3239 prediction, but this typically jumps quickly between 0.01 and 0.99 even in cases where
@@ -36,14 +43,17 @@ class ThresholdDecoder:
3643 of 80% means that the network output is greater than roughly 80% of the dataset
3744 """
3845 def __init__ (self , mu_stds : Tuple [Tuple [float , float ]], center = 0.5 , resolution = 200 , min_z = - 4 , max_z = 4 ):
39- self .min_out = int (min (mu + min_z * std for mu , std in mu_stds ))
40- self .max_out = int (max (mu + max_z * std for mu , std in mu_stds ))
41- self .out_range = self .max_out - self .min_out
42- self .cd = np .cumsum (self ._calc_pd (mu_stds , resolution ))
46+ self .min_out = self .max_out = self .out_range = 0
47+ self .cd = np .array ([])
4348 self .center = center
49+ if center is not None :
50+ self .min_out = int (min ([mu + min_z * std for mu , std in mu_stds ]))
51+ self .max_out = int (max ([mu + max_z * std for mu , std in mu_stds ]))
52+ self .out_range = self .max_out - self .min_out
53+ self .cd = np .cumsum (self ._calc_pd (mu_stds , resolution ))
4454
4555 def decode (self , raw_output : float ) -> float :
46- if raw_output == 1.0 or raw_output == 0.0 :
56+ if self . center is None or raw_output == 1.0 or raw_output == 0.0 :
4757 return raw_output
4858 if self .out_range == 0 :
4959 cp = int (raw_output > self .min_out )
@@ -57,6 +67,8 @@ def decode(self, raw_output: float) -> float:
5767 return 0.5 + 0.5 * (cp - self .center ) / (1 - self .center )
5868
5969 def encode (self , threshold : float ) -> float :
70+ if self .center is None :
71+ return threshold
6072 threshold = 0.5 * threshold / self .center
6173 if threshold < 0.5 :
6274 cp = threshold * self .center * 2
0 commit comments