Skip to content

Commit 30b0c50

Browse files
ColCarroll authored and springcoil committed
Implement iterative NUTS algorithm (#1381)
1 parent 85b0843 commit 30b0c50

File tree

2 files changed

+84
-112
lines changed

2 files changed

+84
-112
lines changed

pymc3/step_methods/nuts.py

Lines changed: 77 additions & 103 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
from .quadpotential import quad_potential
2-
from .arraystep import ArrayStepShared, ArrayStep, SamplerHist, Competence
2+
from .arraystep import ArrayStepShared, SamplerHist, Competence
33
from ..model import modelcontext, Point
44
from ..vartypes import continuous_types
5-
from numpy import exp, log, array
6-
from numpy.random import uniform
7-
from .hmc import leapfrog, Hamiltonian, bern, energy
5+
from .hmc import leapfrog, Hamiltonian, energy, bern
86
from ..tuning import guess_scaling
7+
import numpy as np
8+
import numpy.random as nr
99
import theano
1010
from ..theanof import (make_shared_replacements, join_nonshared_inputs, CallableTensor,
1111
gradient, inputvars)
@@ -26,7 +26,7 @@ class NUTS(ArrayStepShared):
2626
default_blocked = True
2727

2828
def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state=None,
29-
Emax=1000,
29+
max_energy=1000,
3030
target_accept=0.8,
3131
gamma=0.05,
3232
k=0.75,
@@ -42,10 +42,11 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
4242
step_scale : float, default=.25
4343
Size of steps to take, automatically scaled down by 1/n**(1/4)
4444
is_cov : bool, default=False
45-
Treat C as a covariance matrix/vector if True, else treat it as a precision matrix/vector
45+
Treat C as a covariance matrix/vector if True, else treat it as a
46+
precision matrix/vector
4647
state
4748
state to start from
48-
Emax : float, default 1000
49+
max_energy : float, default 1000
4950
maximum energy
5051
target_accept : float (0,1) default .8
5152
target for avg accept probability between final branch and initial position
@@ -68,27 +69,19 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
6869
scaling = model.test_point
6970

7071
if isinstance(scaling, dict):
71-
scaling = guess_scaling(
72-
Point(scaling, model=model), model=model, vars=vars)
73-
74-
n = scaling.shape[0]
75-
76-
self.step_size = step_scale / n**(1 / 4.)
77-
72+
scaling = guess_scaling(Point(scaling, model=model), model=model, vars=vars)
73+
self.step_size = step_scale / scaling.shape[0]**0.25
7874
self.potential = quad_potential(scaling, is_cov, as_cov=False)
79-
8075
if state is None:
8176
state = SamplerHist()
8277
self.state = state
83-
self.Emax = Emax
84-
78+
self.max_energy = max_energy
8579
self.target_accept = target_accept
8680
self.gamma = gamma
8781
self.t0 = t0
8882
self.k = k
89-
90-
self.Hbar = 0
91-
self.u = log(self.step_size * 10)
83+
self.h_bar = 0
84+
self.u = np.log(self.step_size * 10)
9285
self.m = 1
9386

9487
shared = make_shared_replacements(vars, model)
@@ -97,97 +90,80 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
9790

9891
super(NUTS, self).__init__(vars, shared, **kwargs)
9992

100-
def astep(self, q0):
101-
# Hamiltonian(self.logp, self.dlogp, self.potential)
102-
H = self.leapfrog1_dE
103-
Emax = self.Emax
104-
e = self.step_size
105-
106-
p0 = self.potential.random()
107-
u = uniform()
108-
q = qn = qp = q0
109-
p = pn = pp = p0
110-
111-
n, s, j = 1, 1, 0
112-
113-
while s == 1:
114-
v = bern(.5) * 2 - 1
115-
116-
if v == -1:
117-
qn, pn, _, _, q1, n1, s1, a, na = buildtree(
118-
H, qn, pn, u, v, j, e, Emax, q0, p0)
119-
else:
120-
_, _, qp, pp, q1, n1, s1, a, na = buildtree(
121-
H, qp, pp, u, v, j, e, Emax, q0, p0)
122-
123-
if s1 == 1 and bern(min(1, n1 * 1. / n)):
124-
q = q1
93+
@staticmethod
94+
def competence(var):
95+
if var.dtype in continuous_types:
96+
return Competence.IDEAL
97+
return Competence.INCOMPATIBLE
12598

126-
n = n + n1
99+
def astep(self, initial_position):
100+
log_slice_var = np.log(nr.uniform())
101+
initial_momentum = self.potential.random()
102+
position = back_position = forward_position = initial_position
103+
back_momentum = forward_momentum = initial_momentum
104+
should_continue = True
105+
trials = 1
106+
depth = 0
107+
while should_continue:
108+
direction = nr.choice((-1, 1))
109+
step = np.array(direction * self.step_size)
110+
new_trials = 0
111+
metropolis_acceptance = 0
112+
steps = 0
113+
for _ in range(2 ** depth):
114+
if not should_continue:
115+
break
116+
if direction == 1:
117+
forward_position, forward_momentum, energy_change = self.leapfrog1_dE(
118+
forward_position, forward_momentum, step,
119+
initial_position, initial_momentum)
120+
else:
121+
back_position, back_momentum, energy_change = self.leapfrog1_dE(
122+
back_position, back_momentum, step, initial_position, initial_momentum)
123+
new_trials += int(log_slice_var + energy_change <= 0)
124+
if should_update_position(new_trials, trials):
125+
if direction == 1:
126+
position = forward_position
127+
else:
128+
position = back_position
129+
130+
should_continue = (self._energy_is_bounded(log_slice_var, energy_change) and
131+
no_u_turns(forward_position, forward_momentum,
132+
back_position, back_momentum))
133+
metropolis_acceptance += min(1., np.exp(-energy_change))
134+
steps += 1
135+
trials += new_trials
136+
depth += 1
137+
w = 1. / (self.m + self.t0)
138+
self.h_bar = (1 - w) * self.h_bar + w * (self.target_accept - metropolis_acceptance / steps)
139+
self.step_size = np.exp(self.u - (self.m ** 0.5 / self.gamma) * self.h_bar)
140+
self.m += 1
141+
return position
127142

128-
span = qp - qn
129-
s = s1 * (span.dot(pn) >= 0) * (span.dot(pp) >= 0)
130-
j = j + 1
143+
def _energy_is_bounded(self, log_slice_var, energy_change):
144+
return log_slice_var + energy_change < self.max_energy
131145

132-
p = -p
133146

134-
w = 1. / (self.m + self.t0)
135-
self.Hbar = (1 - w) * self.Hbar + w * \
136-
(self.target_accept - a * 1. / na)
147+
def no_u_turns(forward_position, forward_momentum, back_position, back_momentum):
148+
span = forward_position - back_position
149+
return span.dot(back_momentum) >= 0 and span.dot(forward_momentum) >= 0
137150

138-
self.step_size = exp(self.u - (self.m**.5 / self.gamma) * self.Hbar)
139-
self.m += 1
140151

141-
return q
152+
def should_update_position(new_trials, trials):
153+
return bern(float(new_trials) / max(trials, 1.))
142154

143-
@staticmethod
144-
def competence(var):
145-
if var.dtype in continuous_types:
146-
return Competence.IDEAL
147-
return Competence.INCOMPATIBLE
148155

156+
def leapfrog1_dE(logp, vars, shared, quad_potential, profile):
157+
"""Computes a theano function that computes one leapfrog step and the energy
158+
difference between the beginning and end of the trajectory.
149159
150-
def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
151-
if j == 0:
152-
leapfrog1_dE = H
153-
q1, p1, dE = leapfrog1_dE(q, p, array(v * e), q0, p0)
154-
155-
n1 = int(log(u) + dE <= 0)
156-
s1 = int(log(u) + dE < Emax)
157-
return q1, p1, q1, p1, q1, n1, s1, min(1, exp(-dE)), 1
158-
else:
159-
qn, pn, qp, pp, q1, n1, s1, a1, na1 = buildtree(
160-
H, q, p, u, v, j - 1, e, Emax, q0, p0)
161-
if s1 == 1:
162-
if v == -1:
163-
qn, pn, _, _, q11, n11, s11, a11, na11 = buildtree(
164-
H, qn, pn, u, v, j - 1, e, Emax, q0, p0)
165-
else:
166-
_, _, qp, pp, q11, n11, s11, a11, na11 = buildtree(
167-
H, qp, pp, u, v, j - 1, e, Emax, q0, p0)
168-
169-
if bern(n11 * 1. / (max(n1 + n11, 1))):
170-
q1 = q11
171-
172-
a1 = a1 + a11
173-
na1 = na1 + na11
174-
175-
span = qp - qn
176-
s1 = s11 * (span.dot(pn) >= 0) * (span.dot(pp) >= 0)
177-
n1 = n1 + n11
178-
return qn, pn, qp, pp, q1, n1, s1, a1, na1
179-
return
180-
181-
182-
def leapfrog1_dE(logp, vars, shared, pot, profile):
183-
"""Computes a theano function that computes one leapfrog step and the energy difference between the beginning and end of the trajectory.
184160
Parameters
185161
----------
186162
logp : TensorVariable
187163
vars : list of tensor variables
188164
shared : list of shared variables not to compute leapfrog over
189-
pot : quadpotential
190-
porifle : Boolean
165+
quad_potential : quadpotential
166+
profile : Boolean
191167
192168
Returns
193169
-------
@@ -199,7 +175,7 @@ def leapfrog1_dE(logp, vars, shared, pot, profile):
199175
logp = CallableTensor(logp)
200176
dlogp = CallableTensor(dlogp)
201177

202-
H = Hamiltonian(logp, dlogp, pot)
178+
hamiltonian = Hamiltonian(logp, dlogp, quad_potential)
203179

204180
p = tt.dvector('p')
205181
p.tag.test_value = q.tag.test_value
@@ -212,11 +188,9 @@ def leapfrog1_dE(logp, vars, shared, pot, profile):
212188
e = tt.dscalar('e')
213189
e.tag.test_value = 1
214190

215-
q1, p1 = leapfrog(H, q, p, 1, e)
216-
E = energy(H, q1, p1)
217-
E0 = energy(H, q0, p0)
218-
dE = E - E0
191+
q1, p1 = leapfrog(hamiltonian, q, p, 1, e)
192+
energy_change = energy(hamiltonian, q1, p1) - energy(hamiltonian, q0, p0)
219193

220-
f = theano.function([q, p, e, q0, p0], [q1, p1, dE], profile=profile)
194+
f = theano.function([q, p, e, q0, p0], [q1, p1, energy_change], profile=profile)
221195
f.trust_input = True
222196
return f

pymc3/tests/test_sampling.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,6 @@
1919
except:
2020
test_parallel = False
2121

22-
RSEED = 20090425
23-
2422

2523
def test_sample():
2624
model, start, step, _ = simple_init()
@@ -38,13 +36,13 @@ def test_iter_sample():
3836
assert i == len(trace) - 1, "Trace does not have correct length."
3937

4038

41-
def test_parallel_start():
42-
model, _, _, _ = simple_init()
43-
with model:
44-
tr = sample(5, njobs=2, start=[{'x': [10, 10]}, {
45-
'x': [-10, -10]}], random_seed=RSEED)
46-
assert tr.get_values('x', chains=0)[0][0] > 0
47-
assert tr.get_values('x', chains=1)[0][0] < 0
39+
class TestParallelStart(SeededTest):
40+
def test_parallel_start(self):
41+
model, _, _, _ = simple_init()
42+
with model:
43+
tr = sample(5, njobs=2, start=[{'x': [10, 10]}, {'x': [-10, -10]}])
44+
self.assertGreater(tr.get_values('x', chains=0)[0][0], 0)
45+
self.assertLess(tr.get_values('x', chains=1)[0][0], 0)
4846

4947

5048
def test_soft_update_all_present():

0 commit comments

Comments
 (0)