Skip to content

Commit 644058d

Browse files
author
Christopher Fonnesbeck
committed
Added tuning function for step functions
1 parent b15743d commit 644058d

File tree

5 files changed

+105
-51
lines changed

5 files changed

+105
-51
lines changed

pymc/sample.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,41 +8,42 @@
88

99
__all__ = ['sample', 'psample']
1010

11-
def sample(draws, step, start = None, trace = None, track_progress = True, model = None):
11+
def sample(draws, step, start=None, trace=None, track_progress=True,
12+
tune_interval=100, model=None):
1213
"""
13-
Draw a number of samples using the given step method.
14-
Multiple step methods supported via compound step method
14+
Draw a number of samples using the given step method.
15+
Multiple step methods supported via compound step method
1516
returns the amount of time taken.
16-
17+
1718
Parameters
1819
----------
19-
20+
2021
model : Model (optional if in `with` context)
21-
draws : int
22+
draws : int
2223
The number of samples to draw
2324
step : function
2425
A step function
25-
start : dict
26+
start : dict
2627
Starting point in parameter space (Defaults to trace.point(-1))
2728
trace : NpTrace
2829
A trace of past values (defaults to None)
2930
track : list of vars
3031
The variables to follow
31-
32+
3233
Examples
3334
--------
34-
35+
3536
>>> an example
36-
37+
3738
"""
3839
model = modelcontext(model)
3940
draws = int(draws)
40-
if start is None:
41+
if start is None:
4142
start = trace[-1]
4243
point = Point(start, model = model)
4344

4445
if not hasattr(trace, 'record'):
45-
if trace is None:
46+
if trace is None:
4647
trace = model.vars
4748
trace = NpTrace(list(trace))
4849

@@ -57,16 +58,63 @@ def sample(draws, step, start = None, trace = None, track_progress = True, model
5758
for i in xrange(draws):
5859
point = step.step(point)
5960
trace = trace.record(point)
61+
if i and not (i % tune_interval) and step.tune:
62+
step = tune(step, tune_interval)
6063
if track_progress:
6164
progress.update(i)
6265

6366
return trace
6467

68+
69+
def tune(step, tune_interval):
70+
"""
71+
Tunes the scaling parameter for the proposal distribution
72+
according to the acceptance rate over the last tune_interval:
73+
74+
Rate Variance adaptation
75+
---- -------------------
76+
<0.001 x 0.1
77+
<0.05 x 0.5
78+
<0.2 x 0.9
79+
>0.5 x 1.1
80+
>0.75 x 2
81+
>0.95 x 10
82+
83+
"""
84+
85+
# Calculate acceptance rate
86+
acc_rate = step.accepted / float(tune_interval)
87+
88+
# Switch statement
89+
if acc_rate<0.001:
90+
# reduce by 90 percent
91+
step.scaling *= 0.1
92+
elif acc_rate<0.05:
93+
# reduce by 50 percent
94+
step.scaling *= 0.5
95+
elif acc_rate<0.2:
96+
# reduce by ten percent
97+
step.scaling *= 0.9
98+
elif acc_rate>0.95:
99+
# increase by factor of ten
100+
step.scaling *= 10.0
101+
elif acc_rate>0.75:
102+
# increase by double
103+
step.scaling *= 2.0
104+
elif acc_rate>0.5:
105+
# increase by ten percent
106+
step.scaling *= 1.1
107+
108+
# Re-initialize rejection count
109+
step.accepted = 0
110+
111+
return step
112+
65113
def argsample(args):
66114
""" defined at top level so it can be pickled"""
67115
return sample(*args)
68-
69-
def psample(draws, step, start, mtrace = None, track = None, model = None, threads = None):
116+
117+
def psample(draws, step, start, mtrace=None, track=None, model=None, threads=None):
70118
"""draw a number of samples using the given step method. Multiple step methods supported via compound step method
71119
returns the amount of time taken"""
72120

@@ -78,7 +126,7 @@ def psample(draws, step, start, mtrace = None, track = None, model = None, threa
78126
if isinstance(start, dict) :
79127
start = threads * [start]
80128

81-
if track is None:
129+
if track is None:
82130
track = model.vars
83131

84132
if not mtrace:
@@ -87,7 +135,7 @@ def psample(draws, step, start, mtrace = None, track = None, model = None, threa
87135
p = mp.Pool(threads)
88136

89137
argset = zip([draws]*threads, [step]*threads, start, mtrace.traces, [False]*threads, [model] *threads)
90-
138+
91139
traces = p.map(argsample, argset)
92-
140+
93141
return MultiTrace(traces)

pymc/step_methods/arraystep.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
# TODO Add docstrings to ArrayStep
99
class ArrayStep(object):
10-
def __init__(self, vars, fs, allvars = False):
10+
def __init__(self, vars, fs, allvars = False, tune=False):
1111
self.ordering = ArrayOrdering(vars)
1212
self.fs = fs
1313
self.allvars = allvars
14+
self.tune = tune
1415

1516
def step(self, point):
1617
bij = DictToArrayBijection(self.ordering, point)
@@ -28,10 +29,10 @@ def metrop_select(mr, q, q0):
2829
# Compare acceptance ratio to uniform random number
2930
if isfinite(mr) and log(uniform()) < mr:
3031
# Accept proposed value
31-
return q
32+
return q, True
3233
else:
3334
# Reject proposed value
34-
return q0
35+
return q0, False
3536

3637

3738
class SamplerHist(object):

pymc/step_methods/compound.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
Created on Mar 7, 2011
33
44
@author: johnsalvatier
5-
'''
6-
from ..quickclass import *
5+
'''
6+
from ..quickclass import *
77

88
@quickclass(object)
99
def CompoundStep(methods):
10-
methods = list(methods)
10+
methods = list(methods)
11+
tune = False
1112
def step(point):
1213
for method in methods:
1314
point = method.step(point)

pymc/step_methods/hmc.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
Created on Mar 7, 2011
33
44
@author: johnsalvatier
5-
'''
5+
'''
66
from numpy import floor
77
from quadpotential import *
88
from arraystep import *
9-
from ..core import *
9+
from ..core import *
1010
import numpy as np
1111

1212
__all__ = ['HamiltonianMC']
1313

14-
#TODO:
14+
#TODO:
1515
#add constraint handling via page 37 of Radford's http://www.cs.utoronto.ca/~radford/ham-mcmc.abstract.html
1616

1717
def unif(step_size, elow = .85, ehigh = 1.15):
@@ -30,20 +30,20 @@ def __init__(self, vars, C, step_scale = .25, path_length = 2., is_cov = False,
3030
step_scale : float, default=.25
3131
Size of steps to take, automatically scaled down by 1/n**(1/4) (defaults to .25)
3232
path_length : float, default=2
33-
total length to travel
33+
total length to travel
3434
is_cov : bool, default=False
3535
Treat C as a covariance matrix/vector if True, else treat it as a precision matrix/vector
36-
step_rand : function float -> float, default=unif
37-
A function which takes the step size and returns an new one used to randomize the step size at each iteration.
38-
state
36+
step_rand : function float -> float, default=unif
37+
A function which takes the step size and returns an new one used to randomize the step size at each iteration.
38+
state
3939
State object
4040
model : Model
4141
"""
4242
model = modelcontext(model)
4343
n = C.shape[0]
44-
44+
4545
self.step_size = step_scale / n**(1/4.)
46-
46+
4747
self.potential = quad_potential(C, is_cov)
4848
self.path_length = path_length
4949
self.step_rand = step_rand
@@ -52,37 +52,37 @@ def __init__(self, vars, C, step_scale = .25, path_length = 2., is_cov = False,
5252
state = SamplerHist()
5353
self.state = state
5454

55-
ArrayStep.__init__(self,
55+
ArrayStep.__init__(self,
5656
vars, [model.logpc, model.dlogpc(vars)]
5757
)
5858

5959
def astep(self, q0, logp, dlogp):
60-
61-
60+
61+
6262
#randomize step size
63-
e = self.step_rand(self.step_size)
63+
e = self.step_rand(self.step_size)
6464
nstep = int(floor(self.path_length / self.step_size))
65-
66-
q = q0
65+
66+
q = q0
6767
p = p0 = self.potential.random()
68-
68+
6969
#use the leapfrog method
7070
p = p - (e/2) * -dlogp(q) # half momentum update
71-
72-
for i in range(nstep):
71+
72+
for i in range(nstep):
7373
#alternate full variable and momentum updates
7474
q = q + e * self.potential.velocity(p)
7575
if i != nstep - 1:
7676
p = p - e * -dlogp(q)
77-
77+
7878
p = p - (e/2) * -dlogp(q) # do a half step momentum update to finish off
79-
80-
p = -p
81-
79+
80+
p = -p
81+
8282
# - H(q*, p*) + H(q, p) = -H(q, p) + H(q0, p0) = -(- logp(q) + K(p)) + (-logp(q0) + K(p0))
8383
mr = (-logp(q0)) + self.potential.energy(p0) - ((-logp(q)) + self.potential.energy(p))
8484

8585
self.state.metrops.append(mr)
86-
87-
return metrop_select(mr, q, q0)
88-
86+
87+
return metrop_select(mr, q, q0)[0]
88+

pymc/step_methods/metropolis.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,17 @@ def __init__(self, vars, S, proposal_dist=quad_potential, scaling=1.,
5151
# quadpotential does not require n
5252
self.proposal_dist = proposal_dist(S)
5353
self.scaling = scaling
54-
self.tune = tune
55-
super(Metropolis,self).__init__(vars, [model.logpc])
54+
self.accepted = 0
55+
super(Metropolis,self).__init__(vars, [model.logpc], tune=tune)
5656

5757
def astep(self, q0, logp):
5858

5959
delta = self.proposal_dist() * self.scaling
6060

6161
q = q0 + delta
6262

63-
return metrop_select(logp(q) - logp(q0), q, q0)
63+
q_new, accepted = metrop_select(logp(q) - logp(q0), q, q0)
64+
65+
self.accepted += int(accepted)
66+
67+
return q_new

0 commit comments

Comments
 (0)