@@ -92,18 +92,40 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
92
92
self .m = 1
93
93
94
94
shared = make_shared_replacements (vars , model )
95
- self .leapfrog1_dE = leapfrog1_dE (
96
- model .logpt , vars , shared , self .potential , profile = profile )
95
+
96
+ def create_hamiltonian (vars , shared , model ):
97
+ dlogp = gradient (model .logpt , vars )
98
+ (logp , dlogp ), q = join_nonshared_inputs (
99
+ [model .logpt , dlogp ], vars , shared )
100
+ logp = CallableTensor (logp )
101
+ dlogp = CallableTensor (dlogp )
102
+
103
+ return Hamiltonian (logp , dlogp , self .potential ), q
104
+
105
+ def create_energy_func (q ):
106
+ p = tt .dvector ('p' )
107
+ p .tag .test_value = q .tag .test_value
108
+ E0 = energy (self .H , q , p )
109
+ E0_func = theano .function ([q , p ], E0 )
110
+ E0_func .trust_input = True
111
+
112
+ return E0_func
113
+
114
+ self .H , q = create_hamiltonian (vars , shared , model )
115
+ self .compute_energy = create_energy_func (q )
116
+
117
+ self .leapfrog1_dE = leapfrog1_dE (self .H , q , profile = profile )
97
118
98
119
super (NUTS , self ).__init__ (vars , shared , ** kwargs )
99
120
100
121
def astep (self , q0 ):
101
- # Hamiltonian(self.logp, self.dlogp, self.potential)
102
- H = self .leapfrog1_dE
122
+ leapfrog = self .leapfrog1_dE
103
123
Emax = self .Emax
104
124
e = self .step_size
105
125
106
126
p0 = self .potential .random ()
127
+ E0 = self .compute_energy (q0 , p0 )
128
+
107
129
u = uniform ()
108
130
q = qn = qp = q0
109
131
p = pn = pp = p0
@@ -115,10 +137,10 @@ def astep(self, q0):
115
137
116
138
if v == - 1 :
117
139
qn , pn , _ , _ , q1 , n1 , s1 , a , na = buildtree (
118
- H , qn , pn , u , v , j , e , Emax , q0 , p0 )
140
+ leapfrog , qn , pn , u , v , j , e , Emax , E0 )
119
141
else :
120
142
_ , _ , qp , pp , q1 , n1 , s1 , a , na = buildtree (
121
- H , qp , pp , u , v , j , e , Emax , q0 , p0 )
143
+ leapfrog , qp , pp , u , v , j , e , Emax , E0 )
122
144
123
145
if s1 == 1 and bern (min (1 , n1 * 1. / n )):
124
146
q = q1
@@ -147,24 +169,23 @@ def competence(var):
147
169
return Competence .INCOMPATIBLE
148
170
149
171
150
- def buildtree (H , q , p , u , v , j , e , Emax , q0 , p0 ):
172
+ def buildtree (leapfrog1_dE , q , p , u , v , j , e , Emax , E0 ):
151
173
if j == 0 :
152
- leapfrog1_dE = H
153
- q1 , p1 , dE = leapfrog1_dE (q , p , array (v * e ), q0 , p0 )
174
+ q1 , p1 , dE = leapfrog1_dE (q , p , array (v * e ), E0 )
154
175
155
176
n1 = int (log (u ) + dE <= 0 )
156
177
s1 = int (log (u ) + dE < Emax )
157
178
return q1 , p1 , q1 , p1 , q1 , n1 , s1 , min (1 , exp (- dE )), 1
158
179
else :
159
180
qn , pn , qp , pp , q1 , n1 , s1 , a1 , na1 = buildtree (
160
- H , q , p , u , v , j - 1 , e , Emax , q0 , p0 )
181
+ leapfrog1_dE , q , p , u , v , j - 1 , e , Emax , E0 )
161
182
if s1 == 1 :
162
183
if v == - 1 :
163
184
qn , pn , _ , _ , q11 , n11 , s11 , a11 , na11 = buildtree (
164
- H , qn , pn , u , v , j - 1 , e , Emax , q0 , p0 )
185
+ leapfrog1_dE , qn , pn , u , v , j - 1 , e , Emax , E0 )
165
186
else :
166
187
_ , _ , qp , pp , q11 , n11 , s11 , a11 , na11 = buildtree (
167
- H , qp , pp , u , v , j - 1 , e , Emax , q0 , p0 )
188
+ leapfrog1_dE , qp , pp , u , v , j - 1 , e , Emax , E0 )
168
189
169
190
if bern (n11 * 1. / (max (n1 + n11 , 1 ))):
170
191
q1 = q11
@@ -179,44 +200,33 @@ def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
179
200
return
180
201
181
202
182
- def leapfrog1_dE (logp , vars , shared , pot , profile ):
203
+ def leapfrog1_dE (H , q , profile ):
183
204
"""Computes a theano function that computes one leapfrog step and the energy difference between the beginning and end of the trajectory.
184
205
Parameters
185
206
----------
186
- logp : TensorVariable
187
- vars : list of tensor variables
188
- shared : list of shared variables not to compute leapfrog over
189
- pot : quadpotential
190
- porifle : Boolean
207
+ H : Hamiltonian
208
+ q : theano.tensor
209
+ profile : Boolean
191
210
192
211
Returns
193
212
-------
194
213
theano function which returns
195
- q_new, p_new, delta_E
214
+ q_new, p_new, dE
196
215
"""
197
- dlogp = gradient (logp , vars )
198
- (logp , dlogp ), q = join_nonshared_inputs ([logp , dlogp ], vars , shared )
199
- logp = CallableTensor (logp )
200
- dlogp = CallableTensor (dlogp )
201
-
202
- H = Hamiltonian (logp , dlogp , pot )
203
-
204
216
p = tt .dvector ('p' )
205
217
p .tag .test_value = q .tag .test_value
206
218
207
- q0 = tt .dvector ('q0' )
208
- q0 .tag .test_value = q .tag .test_value
209
- p0 = tt .dvector ('p0' )
210
- p0 .tag .test_value = p .tag .test_value
211
-
212
219
e = tt .dscalar ('e' )
213
220
e .tag .test_value = 1
214
221
215
222
q1 , p1 = leapfrog (H , q , p , 1 , e )
216
223
E = energy (H , q1 , p1 )
217
- E0 = energy (H , q0 , p0 )
224
+
225
+ E0 = tt .dscalar ('E0' )
226
+ E0 .tag .test_value = 1
227
+
218
228
dE = E - E0
219
229
220
- f = theano .function ([q , p , e , q0 , p0 ], [q1 , p1 , dE ], profile = profile )
230
+ f = theano .function ([q , p , e , E0 ], [q1 , p1 , dE ], profile = profile )
221
231
f .trust_input = True
222
232
return f
0 commit comments