1
1
from .quadpotential import quad_potential
2
- from .arraystep import ArrayStepShared , ArrayStep , SamplerHist , Competence
2
+ from .arraystep import ArrayStepShared , SamplerHist , Competence
3
3
from ..model import modelcontext , Point
4
4
from ..vartypes import continuous_types
5
- from numpy import exp , log , array
6
- from numpy .random import uniform
7
- from .hmc import leapfrog , Hamiltonian , bern , energy
5
+ from .hmc import leapfrog , Hamiltonian , energy , bern
8
6
from ..tuning import guess_scaling
7
+ import numpy as np
8
+ import numpy .random as nr
9
9
import theano
10
10
from ..theanof import (make_shared_replacements , join_nonshared_inputs , CallableTensor ,
11
11
gradient , inputvars )
@@ -26,7 +26,7 @@ class NUTS(ArrayStepShared):
26
26
default_blocked = True
27
27
28
28
def __init__ (self , vars = None , scaling = None , step_scale = 0.25 , is_cov = False , state = None ,
29
- Emax = 1000 ,
29
+ max_energy = 1000 ,
30
30
target_accept = 0.8 ,
31
31
gamma = 0.05 ,
32
32
k = 0.75 ,
@@ -42,10 +42,11 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
42
42
step_scale : float, default=.25
43
43
Size of steps to take, automatically scaled down by 1/n**(1/4)
44
44
is_cov : bool, default=False
45
- Treat C as a covariance matrix/vector if True, else treat it as a precision matrix/vector
45
+ Treat C as a covariance matrix/vector if True, else treat it as a
46
+ precision matrix/vector
46
47
state
47
48
state to start from
48
- Emax : float, default 1000
49
+ max_energy : float, default 1000
49
50
maximum energy
50
51
target_accept : float (0,1) default .8
51
52
target for avg accept probability between final branch and initial position
@@ -68,27 +69,19 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
68
69
scaling = model .test_point
69
70
70
71
if isinstance (scaling , dict ):
71
- scaling = guess_scaling (
72
- Point (scaling , model = model ), model = model , vars = vars )
73
-
74
- n = scaling .shape [0 ]
75
-
76
- self .step_size = step_scale / n ** (1 / 4. )
77
-
72
+ scaling = guess_scaling (Point (scaling , model = model ), model = model , vars = vars )
73
+ self .step_size = step_scale / scaling .shape [0 ]** 0.25
78
74
self .potential = quad_potential (scaling , is_cov , as_cov = False )
79
-
80
75
if state is None :
81
76
state = SamplerHist ()
82
77
self .state = state
83
- self .Emax = Emax
84
-
78
+ self .max_energy = max_energy
85
79
self .target_accept = target_accept
86
80
self .gamma = gamma
87
81
self .t0 = t0
88
82
self .k = k
89
-
90
- self .Hbar = 0
91
- self .u = log (self .step_size * 10 )
83
+ self .h_bar = 0
84
+ self .u = np .log (self .step_size * 10 )
92
85
self .m = 1
93
86
94
87
shared = make_shared_replacements (vars , model )
@@ -97,97 +90,80 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
97
90
98
91
super (NUTS , self ).__init__ (vars , shared , ** kwargs )
99
92
100
- def astep (self , q0 ):
101
- # Hamiltonian(self.logp, self.dlogp, self.potential)
102
- H = self .leapfrog1_dE
103
- Emax = self .Emax
104
- e = self .step_size
105
-
106
- p0 = self .potential .random ()
107
- u = uniform ()
108
- q = qn = qp = q0
109
- p = pn = pp = p0
110
-
111
- n , s , j = 1 , 1 , 0
112
-
113
- while s == 1 :
114
- v = bern (.5 ) * 2 - 1
115
-
116
- if v == - 1 :
117
- qn , pn , _ , _ , q1 , n1 , s1 , a , na = buildtree (
118
- H , qn , pn , u , v , j , e , Emax , q0 , p0 )
119
- else :
120
- _ , _ , qp , pp , q1 , n1 , s1 , a , na = buildtree (
121
- H , qp , pp , u , v , j , e , Emax , q0 , p0 )
122
-
123
- if s1 == 1 and bern (min (1 , n1 * 1. / n )):
124
- q = q1
93
@staticmethod
def competence(var):
    """Report how well this step method suits ``var``.

    Parameters
    ----------
    var : variable exposing a ``dtype`` attribute

    Returns
    -------
    ``Competence.IDEAL`` when ``var.dtype`` is one of ``continuous_types``,
    ``Competence.INCOMPATIBLE`` otherwise.
    """
    if var.dtype in continuous_types:
        return Competence.IDEAL
    return Competence.INCOMPATIBLE
125
98
126
- n = n + n1
99
def astep(self, initial_position):
    """Advance the sampler by one step and return the selected position.

    A slice variable and an initial momentum are drawn, then the trajectory
    is extended in rounds: each round picks a random direction and takes up
    to ``2 ** depth`` single leapfrog steps from the corresponding end of
    the trajectory.  Extension stops as soon as the energy bound is
    violated or ``no_u_turns`` reports a U-turn.  Afterwards ``h_bar``,
    ``step_size`` and ``m`` are updated from the mean acceptance statistic
    of the last round.

    NOTE(review): block nesting was reconstructed from a whitespace-stripped
    diff; confirm the loop structure against the repository version.

    Parameters
    ----------
    initial_position : array
        Position the step starts from.

    Returns
    -------
    array
        The position chosen among the visited trajectory end points (or
        ``initial_position`` if no candidate was accepted).
    """
    log_slice_var = np.log(nr.uniform())
    initial_momentum = self.potential.random()

    # Both ends of the trajectory start at the initial state.
    position = back_position = forward_position = initial_position
    back_momentum = forward_momentum = initial_momentum

    should_continue = True
    trials = 1
    depth = 0
    while should_continue:
        direction = nr.choice((-1, 1))
        # Signed step size, wrapped in an array as the compiled leapfrog
        # function is called with it directly.
        step = np.array(direction * self.step_size)

        # Per-round statistics; the values of the final round feed the
        # step size update below.
        new_trials = 0
        metropolis_acceptance = 0
        steps = 0
        for _ in range(2 ** depth):
            if not should_continue:
                break

            if direction == 1:
                forward_position, forward_momentum, energy_change = self.leapfrog1_dE(
                    forward_position, forward_momentum, step,
                    initial_position, initial_momentum)
            else:
                back_position, back_momentum, energy_change = self.leapfrog1_dE(
                    back_position, back_momentum, step,
                    initial_position, initial_momentum)

            # Count the new state as a valid candidate when it lies inside
            # the slice.
            new_trials += int(log_slice_var + energy_change <= 0)
            if should_update_position(new_trials, trials):
                if direction == 1:
                    position = forward_position
                else:
                    position = back_position

            should_continue = (
                self._energy_is_bounded(log_slice_var, energy_change) and
                no_u_turns(forward_position, forward_momentum,
                           back_position, back_momentum))
            metropolis_acceptance += min(1., np.exp(-energy_change))
            steps += 1

        trials += new_trials
        depth += 1

    # Step size adaptation: steps >= 1 here because the first inner
    # iteration of every round always runs.
    w = 1. / (self.m + self.t0)
    self.h_bar = ((1 - w) * self.h_bar +
                  w * (self.target_accept - metropolis_acceptance / steps))
    self.step_size = np.exp(self.u - (self.m ** 0.5 / self.gamma) * self.h_bar)
    self.m += 1
    return position
127
142
128
- span = qp - qn
129
- s = s1 * (span .dot (pn ) >= 0 ) * (span .dot (pp ) >= 0 )
130
- j = j + 1
143
def _energy_is_bounded(self, log_slice_var, energy_change):
    """Return whether the slice-adjusted energy change is below ``max_energy``.

    Parameters
    ----------
    log_slice_var : float
        Log of the slice variable drawn for the current step.
    energy_change : float
        Energy difference reported by the leapfrog function.

    Returns
    -------
    True when ``log_slice_var + energy_change`` is strictly less than
    ``self.max_energy``.
    """
    return log_slice_var + energy_change < self.max_energy
131
145
132
- p = - p
133
146
134
- w = 1. / ( self . m + self . t0 )
135
- self . Hbar = ( 1 - w ) * self . Hbar + w * \
136
- ( self . target_accept - a * 1. / na )
147
def no_u_turns(forward_position, forward_momentum, back_position, back_momentum):
    """Return True while neither trajectory end has turned back on the span.

    The span is ``forward_position - back_position``; the check passes when
    its dot product with both the backward and the forward momentum is
    non-negative.

    Parameters
    ----------
    forward_position, forward_momentum : array
        State at the forward end of the trajectory.
    back_position, back_momentum : array
        State at the backward end of the trajectory.

    Returns
    -------
    bool
    """
    span = forward_position - back_position
    # bool() so callers get a plain Python bool rather than a NumPy scalar.
    return bool(span.dot(back_momentum) >= 0 and span.dot(forward_momentum) >= 0)
137
150
138
- self .step_size = exp (self .u - (self .m ** .5 / self .gamma ) * self .Hbar )
139
- self .m += 1
140
151
141
- return q
152
def should_update_position(new_trials, trials):
    """Randomly decide whether the newest candidate replaces the current one.

    The decision is delegated to ``bern`` with probability argument
    ``new_trials / max(trials, 1)``.

    Parameters
    ----------
    new_trials : int
        Number of valid candidates found in the current round.
    trials : int
        Number of valid candidates accumulated before the current round.

    Returns
    -------
    The result of ``bern`` for that probability.
    """
    # max(..., 1.) guards against division by zero.
    return bern(float(new_trials) / max(trials, 1.))
142
154
143
- @staticmethod
144
- def competence (var ):
145
- if var .dtype in continuous_types :
146
- return Competence .IDEAL
147
- return Competence .INCOMPATIBLE
148
155
156
+ def leapfrog1_dE (logp , vars , shared , quad_potential , profile ):
157
+ """Computes a theano function that computes one leapfrog step and the energy
158
+ difference between the beginning and end of the trajectory.
149
159
150
- def buildtree (H , q , p , u , v , j , e , Emax , q0 , p0 ):
151
- if j == 0 :
152
- leapfrog1_dE = H
153
- q1 , p1 , dE = leapfrog1_dE (q , p , array (v * e ), q0 , p0 )
154
-
155
- n1 = int (log (u ) + dE <= 0 )
156
- s1 = int (log (u ) + dE < Emax )
157
- return q1 , p1 , q1 , p1 , q1 , n1 , s1 , min (1 , exp (- dE )), 1
158
- else :
159
- qn , pn , qp , pp , q1 , n1 , s1 , a1 , na1 = buildtree (
160
- H , q , p , u , v , j - 1 , e , Emax , q0 , p0 )
161
- if s1 == 1 :
162
- if v == - 1 :
163
- qn , pn , _ , _ , q11 , n11 , s11 , a11 , na11 = buildtree (
164
- H , qn , pn , u , v , j - 1 , e , Emax , q0 , p0 )
165
- else :
166
- _ , _ , qp , pp , q11 , n11 , s11 , a11 , na11 = buildtree (
167
- H , qp , pp , u , v , j - 1 , e , Emax , q0 , p0 )
168
-
169
- if bern (n11 * 1. / (max (n1 + n11 , 1 ))):
170
- q1 = q11
171
-
172
- a1 = a1 + a11
173
- na1 = na1 + na11
174
-
175
- span = qp - qn
176
- s1 = s11 * (span .dot (pn ) >= 0 ) * (span .dot (pp ) >= 0 )
177
- n1 = n1 + n11
178
- return qn , pn , qp , pp , q1 , n1 , s1 , a1 , na1
179
- return
180
-
181
-
182
- def leapfrog1_dE (logp , vars , shared , pot , profile ):
183
- """Computes a theano function that computes one leapfrog step and the energy difference between the beginning and end of the trajectory.
184
160
Parameters
185
161
----------
186
162
logp : TensorVariable
187
163
vars : list of tensor variables
188
164
shared : list of shared variables not to compute leapfrog over
189
- pot : quadpotential
190
- porifle : Boolean
165
+ quad_potential : quadpotential
166
+ profile : Boolean
191
167
192
168
Returns
193
169
-------
@@ -199,7 +175,7 @@ def leapfrog1_dE(logp, vars, shared, pot, profile):
199
175
logp = CallableTensor (logp )
200
176
dlogp = CallableTensor (dlogp )
201
177
202
- H = Hamiltonian (logp , dlogp , pot )
178
+ hamiltonian = Hamiltonian (logp , dlogp , quad_potential )
203
179
204
180
p = tt .dvector ('p' )
205
181
p .tag .test_value = q .tag .test_value
@@ -212,11 +188,9 @@ def leapfrog1_dE(logp, vars, shared, pot, profile):
212
188
e = tt .dscalar ('e' )
213
189
e .tag .test_value = 1
214
190
215
- q1 , p1 = leapfrog (H , q , p , 1 , e )
216
- E = energy (H , q1 , p1 )
217
- E0 = energy (H , q0 , p0 )
218
- dE = E - E0
191
+ q1 , p1 = leapfrog (hamiltonian , q , p , 1 , e )
192
+ energy_change = energy (hamiltonian , q1 , p1 ) - energy (hamiltonian , q0 , p0 )
219
193
220
- f = theano .function ([q , p , e , q0 , p0 ], [q1 , p1 , dE ], profile = profile )
194
+ f = theano .function ([q , p , e , q0 , p0 ], [q1 , p1 , energy_change ], profile = profile )
221
195
f .trust_input = True
222
196
return f
0 commit comments