Skip to content

Commit 1c7cd51

Browse files
Disable threshold decoding by default
It's possible that threshold decoding could cause a model to activate less at the 0.5 threshold because the center wasn't adjusted properly
1 parent 9cc135e commit 1c7cd51

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed

precise/params.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
- Interpretation of the network output to a confidence value
1919
"""
2020
from math import floor
21+
from typing import Optional
2122

2223
import attr
2324
import json
@@ -68,7 +69,7 @@ class ListenerParams:
6869
use_delta = attr.ib() # type: bool
6970
vectorizer = attr.ib() # type: int
7071
threshold_config = attr.ib() # type: tuple
71-
threshold_center = attr.ib() # type: float
72+
threshold_center = attr.ib() # type: Optional[float]
7273

7374
@property
7475
def buffer_samples(self):
@@ -140,7 +141,7 @@ class Vectorizer:
140141
pr = ListenerParams(
141142
buffer_t=1.5, window_t=0.1, hop_t=0.05, sample_rate=16000,
142143
sample_depth=2, n_fft=512, n_filt=20, n_mfcc=13, use_delta=False,
143-
threshold_config=((6, 4),), threshold_center=0.2, vectorizer=Vectorizer.mfccs
144+
threshold_config=(), threshold_center=None, vectorizer=Vectorizer.mfccs
144145
)
145146

146147
# Used to fill in old param files without new attributes

precise/pocketsphinx/__init__.py

Whitespace-only changes.

precise/pocketsphinx/scripts/__init__.py

Whitespace-only changes.

precise/threshold_decoder.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ class ThresholdDecoder:
2727
activations using a series of averages and standard deviations to
2828
calculate a cumulative probability distribution
2929
30+
Args:
31+
mu_stds: tuple of pairs of (mean, standard deviation) that model the positive network output
32+
center: proportion of activations that a threshold of 0.5 indicates. Pass as None to disable decoding
33+
resolution: precision of cumulative sum estimation. Increases memory usage
34+
min_z: Minimum z score to generate in distribution map
35+
max_z: Maximum z score to generate in distribution map
36+
3037
Background:
3138
We could simply take the output of the neural network as the confidence of a given
3239
prediction, but this typically jumps quickly between 0.01 and 0.99 even in cases where
@@ -36,14 +43,17 @@ class ThresholdDecoder:
3643
of 80% means that the network output is greater than roughly 80% of the dataset
3744
"""
3845
def __init__(self, mu_stds: Tuple[Tuple[float, float]], center=0.5, resolution=200, min_z=-4, max_z=4):
39-
self.min_out = int(min(mu + min_z * std for mu, std in mu_stds))
40-
self.max_out = int(max(mu + max_z * std for mu, std in mu_stds))
41-
self.out_range = self.max_out - self.min_out
42-
self.cd = np.cumsum(self._calc_pd(mu_stds, resolution))
46+
self.min_out = self.max_out = self.out_range = 0
47+
self.cd = np.array([])
4348
self.center = center
49+
if center is not None:
50+
self.min_out = int(min([mu + min_z * std for mu, std in mu_stds]))
51+
self.max_out = int(max([mu + max_z * std for mu, std in mu_stds]))
52+
self.out_range = self.max_out - self.min_out
53+
self.cd = np.cumsum(self._calc_pd(mu_stds, resolution))
4454

4555
def decode(self, raw_output: float) -> float:
46-
if raw_output == 1.0 or raw_output == 0.0:
56+
if self.center is None or raw_output == 1.0 or raw_output == 0.0:
4757
return raw_output
4858
if self.out_range == 0:
4959
cp = int(raw_output > self.min_out)
@@ -57,6 +67,8 @@ def decode(self, raw_output: float) -> float:
5767
return 0.5 + 0.5 * (cp - self.center) / (1 - self.center)
5868

5969
def encode(self, threshold: float) -> float:
70+
if self.center is None:
71+
return threshold
6072
threshold = 0.5 * threshold / self.center
6173
if threshold < 0.5:
6274
cp = threshold * self.center * 2

0 commit comments

Comments
 (0)