Skip to content

Commit cbb449a

Browse files
committed
Resolve mergeconflict
1 parent 67a66a8 commit cbb449a

File tree

1 file changed

+103
-77
lines changed

1 file changed

+103
-77
lines changed

pymc3/step_methods/nuts.py

Lines changed: 103 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from .quadpotential import quad_potential
2-
from .arraystep import ArrayStepShared, SamplerHist, Competence
2+
from .arraystep import ArrayStepShared, ArrayStep, SamplerHist, Competence
33
from ..model import modelcontext, Point
44
from ..vartypes import continuous_types
5-
from .hmc import leapfrog, Hamiltonian, energy, bern
5+
from numpy import exp, log, array
6+
from numpy.random import uniform
7+
from .hmc import leapfrog, Hamiltonian, bern, energy
68
from ..tuning import guess_scaling
7-
import numpy as np
8-
import numpy.random as nr
99
import theano
1010
from ..theanof import (make_shared_replacements, join_nonshared_inputs, CallableTensor,
1111
gradient, inputvars)
@@ -26,7 +26,7 @@ class NUTS(ArrayStepShared):
2626
default_blocked = True
2727

2828
def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state=None,
29-
max_energy=1000,
29+
Emax=1000,
3030
target_accept=0.8,
3131
gamma=0.05,
3232
k=0.75,
@@ -42,11 +42,10 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
4242
step_scale : float, default=.25
4343
Size of steps to take, automatically scaled down by 1/n**(1/4)
4444
is_cov : bool, default=False
45-
Treat C as a covariance matrix/vector if True, else treat it as a
46-
precision matrix/vector
45+
Treat C as a covariance matrix/vector if True, else treat it as a precision matrix/vector
4746
state
4847
state to start from
49-
max_energy : float, default 1000
48+
Emax : float, default 1000
5049
maximum energy
5150
target_accept : float (0,1) default .8
5251
target for avg accept probability between final branch and initial position
@@ -69,19 +68,27 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
6968
scaling = model.test_point
7069

7170
if isinstance(scaling, dict):
72-
scaling = guess_scaling(Point(scaling, model=model), model=model, vars=vars)
73-
self.step_size = step_scale / scaling.shape[0]**0.25
71+
scaling = guess_scaling(
72+
Point(scaling, model=model), model=model, vars=vars)
73+
74+
n = scaling.shape[0]
75+
76+
self.step_size = step_scale / n**(1 / 4.)
77+
7478
self.potential = quad_potential(scaling, is_cov, as_cov=False)
79+
7580
if state is None:
7681
state = SamplerHist()
7782
self.state = state
78-
self.max_energy = max_energy
83+
self.Emax = Emax
84+
7985
self.target_accept = target_accept
8086
self.gamma = gamma
8187
self.t0 = t0
8288
self.k = k
83-
self.h_bar = 0
84-
self.u = np.log(self.step_size * 10)
89+
90+
self.Hbar = 0
91+
self.u = log(self.step_size * 10)
8592
self.m = 1
8693

8794
shared = make_shared_replacements(vars, model)
@@ -90,80 +97,97 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
9097

9198
super(NUTS, self).__init__(vars, shared, **kwargs)
9299

93-
@staticmethod
94-
def competence(var):
95-
if var.dtype in continuous_types:
96-
return Competence.IDEAL
97-
return Competence.INCOMPATIBLE
100+
def astep(self, q0):
101+
# Hamiltonian(self.logp, self.dlogp, self.potential)
102+
H = self.leapfrog1_dE
103+
Emax = self.Emax
104+
e = self.step_size
98105

99-
def astep(self, initial_position):
100-
log_slice_var = np.log(nr.uniform())
101-
initial_momentum = self.potential.random()
102-
position = back_position = forward_position = initial_position
103-
back_momentum = forward_momentum = initial_momentum
104-
should_continue = True
105-
trials = 1
106-
depth = 0
107-
while should_continue:
108-
direction = nr.choice((-1, 1))
109-
step = np.array(direction * self.step_size)
110-
new_trials = 0
111-
metropolis_acceptance = 0
112-
steps = 0
113-
for _ in range(2 ** depth):
114-
if not should_continue:
115-
break
116-
if direction == 1:
117-
forward_position, forward_momentum, energy_change = self.leapfrog1_dE(
118-
forward_position, forward_momentum, step,
119-
initial_position, initial_momentum)
120-
else:
121-
back_position, back_momentum, energy_change = self.leapfrog1_dE(
122-
back_position, back_momentum, step, initial_position, initial_momentum)
123-
new_trials += int(log_slice_var + energy_change <= 0)
124-
if should_update_position(new_trials, trials):
125-
if direction == 1:
126-
position = forward_position
127-
else:
128-
position = back_position
129-
130-
should_continue = (self._energy_is_bounded(log_slice_var, energy_change) and
131-
no_u_turns(forward_position, forward_momentum,
132-
back_position, back_momentum))
133-
metropolis_acceptance += min(1., np.exp(-energy_change))
134-
steps += 1
135-
trials += new_trials
136-
depth += 1
137-
w = 1. / (self.m + self.t0)
138-
self.h_bar = (1 - w) * self.h_bar + w * (self.target_accept - metropolis_acceptance / steps)
139-
self.step_size = np.exp(self.u - (self.m ** 0.5 / self.gamma) * self.h_bar)
140-
self.m += 1
141-
return position
106+
p0 = self.potential.random()
107+
u = uniform()
108+
q = qn = qp = q0
109+
p = pn = pp = p0
142110

143-
def _energy_is_bounded(self, log_slice_var, energy_change):
144-
return log_slice_var + energy_change < self.max_energy
111+
n, s, j = 1, 1, 0
145112

113+
while s == 1:
114+
v = bern(.5) * 2 - 1
146115

147-
def no_u_turns(forward_position, forward_momentum, back_position, back_momentum):
148-
span = forward_position - back_position
149-
return span.dot(back_momentum) >= 0 and span.dot(forward_momentum) >= 0
116+
if v == -1:
117+
qn, pn, _, _, q1, n1, s1, a, na = buildtree(
118+
H, qn, pn, u, v, j, e, Emax, q0, p0)
119+
else:
120+
_, _, qp, pp, q1, n1, s1, a, na = buildtree(
121+
H, qp, pp, u, v, j, e, Emax, q0, p0)
150122

123+
if s1 == 1 and bern(min(1, n1 * 1. / n)):
124+
q = q1
151125

152-
def should_update_position(new_trials, trials):
153-
return bern(float(new_trials) / max(trials, 1.))
126+
n = n + n1
154127

128+
span = qp - qn
129+
s = s1 * (span.dot(pn) >= 0) * (span.dot(pp) >= 0)
130+
j = j + 1
131+
132+
p = -p
133+
134+
w = 1. / (self.m + self.t0)
135+
self.Hbar = (1 - w) * self.Hbar + w * \
136+
(self.target_accept - a * 1. / na)
137+
138+
self.step_size = exp(self.u - (self.m**.5 / self.gamma) * self.Hbar)
139+
self.m += 1
140+
141+
return q
142+
143+
@staticmethod
144+
def competence(var):
145+
if var.dtype in continuous_types:
146+
return Competence.IDEAL
147+
return Competence.INCOMPATIBLE
155148

156-
def leapfrog1_dE(logp, vars, shared, quad_potential, profile):
157-
"""Computes a theano function that computes one leapfrog step and the energy
158-
difference between the beginning and end of the trajectory.
159149

150+
def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
151+
if j == 0:
152+
leapfrog1_dE = H
153+
q1, p1, dE = leapfrog1_dE(q, p, array(v * e), q0, p0)
154+
155+
n1 = int(log(u) + dE <= 0)
156+
s1 = int(log(u) + dE < Emax)
157+
return q1, p1, q1, p1, q1, n1, s1, min(1, exp(-dE)), 1
158+
else:
159+
qn, pn, qp, pp, q1, n1, s1, a1, na1 = buildtree(
160+
H, q, p, u, v, j - 1, e, Emax, q0, p0)
161+
if s1 == 1:
162+
if v == -1:
163+
qn, pn, _, _, q11, n11, s11, a11, na11 = buildtree(
164+
H, qn, pn, u, v, j - 1, e, Emax, q0, p0)
165+
else:
166+
_, _, qp, pp, q11, n11, s11, a11, na11 = buildtree(
167+
H, qp, pp, u, v, j - 1, e, Emax, q0, p0)
168+
169+
if bern(n11 * 1. / (max(n1 + n11, 1))):
170+
q1 = q11
171+
172+
a1 = a1 + a11
173+
na1 = na1 + na11
174+
175+
span = qp - qn
176+
s1 = s11 * (span.dot(pn) >= 0) * (span.dot(pp) >= 0)
177+
n1 = n1 + n11
178+
return qn, pn, qp, pp, q1, n1, s1, a1, na1
179+
return
180+
181+
182+
def leapfrog1_dE(logp, vars, shared, pot, profile):
183+
"""Computes a theano function that computes one leapfrog step and the energy difference between the beginning and end of the trajectory.
160184
Parameters
161185
----------
162186
logp : TensorVariable
163187
vars : list of tensor variables
164188
shared : list of shared variables not to compute leapfrog over
165-
quad_potential : quadpotential
166-
profile : Boolean
189+
pot : quadpotential
190+
porifle : Boolean
167191
168192
Returns
169193
-------
@@ -175,7 +199,7 @@ def leapfrog1_dE(logp, vars, shared, quad_potential, profile):
175199
logp = CallableTensor(logp)
176200
dlogp = CallableTensor(dlogp)
177201

178-
hamiltonian = Hamiltonian(logp, dlogp, quad_potential)
202+
H = Hamiltonian(logp, dlogp, pot)
179203

180204
p = tt.dvector('p')
181205
p.tag.test_value = q.tag.test_value
@@ -188,9 +212,11 @@ def leapfrog1_dE(logp, vars, shared, quad_potential, profile):
188212
e = tt.dscalar('e')
189213
e.tag.test_value = 1
190214

191-
q1, p1 = leapfrog(hamiltonian, q, p, 1, e)
192-
energy_change = energy(hamiltonian, q1, p1) - energy(hamiltonian, q0, p0)
215+
q1, p1 = leapfrog(H, q, p, 1, e)
216+
E = energy(H, q1, p1)
217+
E0 = energy(H, q0, p0)
218+
dE = E - E0
193219

194-
f = theano.function([q, p, e, q0, p0], [q1, p1, energy_change], profile=profile)
220+
f = theano.function([q, p, e, q0, p0], [q1, p1, dE], profile=profile)
195221
f.trust_input = True
196222
return f

0 commit comments

Comments
 (0)