Skip to content

Commit 082c982

Browse files
twieckispringcoil
authored andcommitted
Nuts speed up ~1.5x (#1522)
* PERF Make NUTS 1.5x faster by moving E0 calculation outside of buildtree inner loop. * STY Refactor creation of functions. * DOC Adapt doc-string. * STY Pep8. * STY Only return scalar * DOC Typo * MAINT Move theano.tensor to tt. Remove obsolete p0 and q0 args from build_tree.
1 parent 0d5f292 commit 082c982

File tree

1 file changed

+43
-33
lines changed

1 file changed

+43
-33
lines changed

pymc3/step_methods/nuts.py

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,40 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
9292
self.m = 1
9393

9494
shared = make_shared_replacements(vars, model)
95-
self.leapfrog1_dE = leapfrog1_dE(
96-
model.logpt, vars, shared, self.potential, profile=profile)
95+
96+
def create_hamiltonian(vars, shared, model):
97+
dlogp = gradient(model.logpt, vars)
98+
(logp, dlogp), q = join_nonshared_inputs(
99+
[model.logpt, dlogp], vars, shared)
100+
logp = CallableTensor(logp)
101+
dlogp = CallableTensor(dlogp)
102+
103+
return Hamiltonian(logp, dlogp, self.potential), q
104+
105+
def create_energy_func(q):
106+
p = tt.dvector('p')
107+
p.tag.test_value = q.tag.test_value
108+
E0 = energy(self.H, q, p)
109+
E0_func = theano.function([q, p], E0)
110+
E0_func.trust_input = True
111+
112+
return E0_func
113+
114+
self.H, q = create_hamiltonian(vars, shared, model)
115+
self.compute_energy = create_energy_func(q)
116+
117+
self.leapfrog1_dE = leapfrog1_dE(self.H, q, profile=profile)
97118

98119
super(NUTS, self).__init__(vars, shared, **kwargs)
99120

100121
def astep(self, q0):
101-
# Hamiltonian(self.logp, self.dlogp, self.potential)
102-
H = self.leapfrog1_dE
122+
leapfrog = self.leapfrog1_dE
103123
Emax = self.Emax
104124
e = self.step_size
105125

106126
p0 = self.potential.random()
127+
E0 = self.compute_energy(q0, p0)
128+
107129
u = uniform()
108130
q = qn = qp = q0
109131
p = pn = pp = p0
@@ -115,10 +137,10 @@ def astep(self, q0):
115137

116138
if v == -1:
117139
qn, pn, _, _, q1, n1, s1, a, na = buildtree(
118-
H, qn, pn, u, v, j, e, Emax, q0, p0)
140+
leapfrog, qn, pn, u, v, j, e, Emax, E0)
119141
else:
120142
_, _, qp, pp, q1, n1, s1, a, na = buildtree(
121-
H, qp, pp, u, v, j, e, Emax, q0, p0)
143+
leapfrog, qp, pp, u, v, j, e, Emax, E0)
122144

123145
if s1 == 1 and bern(min(1, n1 * 1. / n)):
124146
q = q1
@@ -147,24 +169,23 @@ def competence(var):
147169
return Competence.INCOMPATIBLE
148170

149171

150-
def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
172+
def buildtree(leapfrog1_dE, q, p, u, v, j, e, Emax, E0):
151173
if j == 0:
152-
leapfrog1_dE = H
153-
q1, p1, dE = leapfrog1_dE(q, p, array(v * e), q0, p0)
174+
q1, p1, dE = leapfrog1_dE(q, p, array(v * e), E0)
154175

155176
n1 = int(log(u) + dE <= 0)
156177
s1 = int(log(u) + dE < Emax)
157178
return q1, p1, q1, p1, q1, n1, s1, min(1, exp(-dE)), 1
158179
else:
159180
qn, pn, qp, pp, q1, n1, s1, a1, na1 = buildtree(
160-
H, q, p, u, v, j - 1, e, Emax, q0, p0)
181+
leapfrog1_dE, q, p, u, v, j - 1, e, Emax, E0)
161182
if s1 == 1:
162183
if v == -1:
163184
qn, pn, _, _, q11, n11, s11, a11, na11 = buildtree(
164-
H, qn, pn, u, v, j - 1, e, Emax, q0, p0)
185+
leapfrog1_dE, qn, pn, u, v, j - 1, e, Emax, E0)
165186
else:
166187
_, _, qp, pp, q11, n11, s11, a11, na11 = buildtree(
167-
H, qp, pp, u, v, j - 1, e, Emax, q0, p0)
188+
leapfrog1_dE, qp, pp, u, v, j - 1, e, Emax, E0)
168189

169190
if bern(n11 * 1. / (max(n1 + n11, 1))):
170191
q1 = q11
@@ -179,44 +200,33 @@ def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
179200
return
180201

181202

182-
def leapfrog1_dE(logp, vars, shared, pot, profile):
203+
def leapfrog1_dE(H, q, profile):
183204
"""Computes a theano function that computes one leapfrog step and the energy difference between the beginning and end of the trajectory.
184205
Parameters
185206
----------
186-
logp : TensorVariable
187-
vars : list of tensor variables
188-
shared : list of shared variables not to compute leapfrog over
189-
pot : quadpotential
190-
porifle : Boolean
207+
H : Hamiltonian
208+
q : theano.tensor
209+
profile : Boolean
191210
192211
Returns
193212
-------
194213
theano function which returns
195-
q_new, p_new, delta_E
214+
q_new, p_new, dE
196215
"""
197-
dlogp = gradient(logp, vars)
198-
(logp, dlogp), q = join_nonshared_inputs([logp, dlogp], vars, shared)
199-
logp = CallableTensor(logp)
200-
dlogp = CallableTensor(dlogp)
201-
202-
H = Hamiltonian(logp, dlogp, pot)
203-
204216
p = tt.dvector('p')
205217
p.tag.test_value = q.tag.test_value
206218

207-
q0 = tt.dvector('q0')
208-
q0.tag.test_value = q.tag.test_value
209-
p0 = tt.dvector('p0')
210-
p0.tag.test_value = p.tag.test_value
211-
212219
e = tt.dscalar('e')
213220
e.tag.test_value = 1
214221

215222
q1, p1 = leapfrog(H, q, p, 1, e)
216223
E = energy(H, q1, p1)
217-
E0 = energy(H, q0, p0)
224+
225+
E0 = tt.dscalar('E0')
226+
E0.tag.test_value = 1
227+
218228
dE = E - E0
219229

220-
f = theano.function([q, p, e, q0, p0], [q1, p1, dE], profile=profile)
230+
f = theano.function([q, p, e, E0], [q1, p1, dE], profile=profile)
221231
f.trust_input = True
222232
return f

0 commit comments

Comments
 (0)