File tree Expand file tree Collapse file tree 1 file changed +3
-24
lines changed Expand file tree Collapse file tree 1 file changed +3
-24
lines changed Original file line number Diff line number Diff line change 21
21
from .base_hmc import BaseHMC , HMCStepData , DivergenceInfo
22
22
from .integration import IntegrationError
23
23
from pymc3 .backends .report import SamplerWarning , WarningType
24
+ from pymc3 .math import logbern , log1mexp_numpy , logdiffexp_numpy
24
25
from pymc3 .theanof import floatX
25
26
from pymc3 .vartypes import continuous_types
26
27
27
- __all__ = ["NUTS" ]
28
-
29
-
30
- def logbern (log_p ):
31
- if np .isnan (log_p ):
32
- raise FloatingPointError ("log_p can't be nan." )
33
- return np .log (nr .uniform ()) < log_p
34
-
35
28
36
- def log1mexp_numpy (x ):
37
- """Return log(1 - exp(-x)).
38
- This function is numerically more stable than the naive approach.
39
- For details, see
40
- https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
41
- """
42
- return np .where (
43
- x < 0.683 ,
44
- np .log (- np .expm1 (- x )),
45
- np .log1p (- np .exp (- x )))
46
-
47
-
48
- def logdiffexp_numpy (a , b ):
49
- """log(exp(a) - exp(b))"""
50
- return a + log1mexp_numpy (a - b )
29
+ __all__ = ["NUTS" ]
51
30
52
31
53
32
class NUTS (BaseHMC ):
@@ -422,7 +401,7 @@ def stats(self):
422
401
if self .log_size > 0 :
423
402
# Remove contribution from initial state which is always a perfect
424
403
# accept
425
- log_sum_weight = logdiffexp_numpy (self .log_size , 0. )
404
+ log_sum_weight = logdiffexp_numpy (self .log_size , 0.0 )
426
405
self .mean_tree_accept = np .exp (self .log_accept_sum - log_sum_weight )
427
406
return {
428
407
"depth" : self .depth ,
You can’t perform that action at this time.
0 commit comments