Skip to content

Commit 7a4d47c

Browse files
committed
move function to pymc3.math (2)
1 parent 0ebedb0 commit 7a4d47c

File tree

1 file changed

+3
-24
lines changed

1 file changed

+3
-24
lines changed

pymc3/step_methods/hmc/nuts.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,33 +21,12 @@
2121
from .base_hmc import BaseHMC, HMCStepData, DivergenceInfo
2222
from .integration import IntegrationError
2323
from pymc3.backends.report import SamplerWarning, WarningType
24+
from pymc3.math import logbern, log1mexp_numpy, logdiffexp_numpy
2425
from pymc3.theanof import floatX
2526
from pymc3.vartypes import continuous_types
2627

27-
__all__ = ["NUTS"]
28-
29-
30-
def logbern(log_p):
31-
if np.isnan(log_p):
32-
raise FloatingPointError("log_p can't be nan.")
33-
return np.log(nr.uniform()) < log_p
34-
3528

36-
def log1mexp_numpy(x):
37-
"""Return log(1 - exp(-x)).
38-
This function is numerically more stable than the naive approach.
39-
For details, see
40-
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
41-
"""
42-
return np.where(
43-
x < 0.683,
44-
np.log(-np.expm1(-x)),
45-
np.log1p(-np.exp(-x)))
46-
47-
48-
def logdiffexp_numpy(a, b):
49-
"""log(exp(a) - exp(b))"""
50-
return a + log1mexp_numpy(a - b)
29+
__all__ = ["NUTS"]
5130

5231

5332
class NUTS(BaseHMC):
@@ -422,7 +401,7 @@ def stats(self):
422401
if self.log_size > 0:
423402
# Remove contribution from initial state which is always a perfect
424403
# accept
425-
log_sum_weight = logdiffexp_numpy(self.log_size, 0.)
404+
log_sum_weight = logdiffexp_numpy(self.log_size, 0.0)
426405
self.mean_tree_accept = np.exp(self.log_accept_sum - log_sum_weight)
427406
return {
428407
"depth": self.depth,

0 commit comments

Comments
 (0)