Skip to content

Commit 8ddd6ff

Browse files
committed
added deterministic traces
1 parent 9ef5018 commit 8ddd6ff

File tree

8 files changed

+86
-28
lines changed

8 files changed

+86
-28
lines changed

examples/stochastic_volatility.ipynb

Lines changed: 31 additions & 4 deletions
Large diffs are not rendered by default.

pymc/model.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313
from functools import wraps
1414

15-
__all__ = ['Model', 'compilef', 'gradient', 'hessian', 'withmodel']
15+
__all__ = ['Model', 'compilef', 'gradient', 'hessian', 'withmodel', 'Point']
1616

1717

1818

@@ -109,13 +109,33 @@ def Var(model, name, dist):
109109

110110
def TransformedVar(model, name, dist, trans):
111111
tvar = model.Var(trans.name + '_' + name, trans.apply(dist))
112-
return trans.backward(tvar), tvar
112+
113+
return named(trans.backward(tvar),name), tvar
113114

114115
def AddPotential(model, potential):
115116
model.factors.append(potential)
116117

117118
withmodel = withcontext(Model, 'model')
118119

120+
@withmodel
121+
def Point(model, *args,**kwargs):
122+
"""
123+
Build a point. Uses same args as dict() does.
124+
Filters out variables not in the model. All keys are strings.
125+
126+
Parameters
127+
----------
128+
model : Model (in context)
129+
*args, **kwargs
130+
arguments to build a dict
131+
"""
132+
133+
d = dict(*args, **kwargs)
134+
varnames = map(str, model.vars)
135+
return dict((str(k),np.array(v))
136+
for (k,v) in d.iteritems()
137+
if str(k) in varnames)
138+
119139

120140
def compilef(outs, mode = None):
121141
return PointFunc(
@@ -125,6 +145,9 @@ def compilef(outs, mode = None):
125145
mode = mode)
126146
)
127147

148+
def named(var, name):
149+
var.name = name
150+
return var
128151

129152
def as_iterargs(data):
130153
if isinstance(data, tuple):

pymc/point.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
import numpy as np
22

3-
def Point(d) :
4-
d = dict(d)
5-
return dict((str(k),np.array(v)) for (k,v) in d.iteritems())
6-
73
class PointFunc(object):
84
def __init__(self, f):
95
self.f = f

pymc/sample.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
from time import time
55
from core import *
66
import step_methods
7+
from progressbar import ProgressBar
78

89
__all__ = ['sample', 'psample']
910

1011
@withmodel
11-
def sample(model, draws, step, start = None, trace = None, vars = None):
12+
def sample(model, draws, step, start = None, trace = None, progress_bar = True):
1213
"""
1314
Draw a number of samples using the given step method.
1415
Multiple step methods supported via compound step method
@@ -25,33 +26,40 @@ def sample(model, draws, step, start = None, trace = None, vars = None):
2526
Starting point in parameter space (Defaults to trace.point(-1))
2627
trace : NpTrace
2728
A trace of past values (defaults to None)
28-
state :
29-
The current state of the sampler (defaults to None)
29+
track : list of vars
30+
The variables to follow
3031
3132
Examples
3233
--------
3334
3435
>>> an example
3536
3637
"""
38+
draws = int(draws)
3739
if start is None:
3840
start = trace[-1]
3941
point = Point(start)
4042

41-
if vars is None:
42-
vars = model.vars
43+
if not hasattr(trace, 'record'):
44+
if trace is None:
45+
trace = model.vars
46+
trace = NpTrace(list(trace))
4347

44-
if trace is None:
45-
trace = NpTrace(vars)
4648

4749
try:
4850
step = step_methods.CompoundStep(step)
4951
except TypeError:
5052
pass
5153

52-
for _ in xrange(int(draws)):
54+
55+
56+
progress = ProgressBar(draws)
57+
58+
for i in xrange(draws):
5359
point = step.step(point)
5460
trace = trace.record(point)
61+
if progress_bar:
62+
progress.animate(i)
5563

5664
return trace
5765

@@ -60,7 +68,7 @@ def argsample(args):
6068
return sample(*args)
6169

6270
@withmodel
63-
def psample(model, draws, step, start, mtrace = None, threads = None, vars = None):
71+
def psample(model, draws, step, start, mtrace = None, threads = None, track = None):
6472
"""draw a number of samples using the given step method. Multiple step methods supported via compound step method
6573
returns the amount of time taken"""
6674

@@ -70,11 +78,11 @@ def psample(model, draws, step, start, mtrace = None, threads = None, vars = Non
7078
if isinstance(start, dict) :
7179
start = threads * [start]
7280

73-
if vars is None:
74-
vars = model.vars
81+
if track is None:
82+
track = model.vars
7583

7684
if not mtrace:
77-
mtrace = MultiTrace(threads, vars)
85+
mtrace = MultiTrace(threads, track)
7886

7987
p = mp.Pool(threads)
8088

pymc/trace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __getitem__(self, key):
3636
return self.samples[str(key)].value
3737

3838
def point(self, index):
39-
return Point((k, v.value[index]) for (k,v) in self.samples.iteritems())
39+
return dict((k, v.value[index]) for (k,v) in self.samples.iteritems())
4040

4141
class ListArray(object):
4242
def __init__(self):

pymc/tuning/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from starting import find_MAP
2-
from scaling import approx_hess
1+
from starting import *
2+
from scaling import *

pymc/tuning/scaling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from ..core import *
99

10-
__all__ = ['approx_hess']
10+
__all__ = ['approx_hess', 'find_hessian', 'trace_cov']
1111

1212
@withmodel
1313
def approx_hess(model, start, vars=None):
@@ -24,7 +24,7 @@ def approx_hess(model, start, vars=None):
2424
if vars is None :
2525
vars = model.cont_vars
2626

27-
start = Point(start)
27+
start = Point(model, start)
2828

2929
bij = DictToArrayBijection(ArrayOrdering(vars), start)
3030
dlogp = bij.mapf(model.dlogpc(vars))
@@ -40,6 +40,10 @@ def grad_logp(point):
4040
'''
4141
return -nd.Jacobian(grad_logp)(bij.map(start))
4242

43+
@withmodel
44+
def find_hessian(model, point, vars = None):
45+
H = model.d2logpc(vars)
46+
return H(Point(model, point))
4347

4448
def trace_cov(trace, vars = None):
4549
"""

pymc/tuning/starting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def find_MAP(model, start = None, vars=None, fmin = optimize.fmin_bfgs, return_r
3737
if vars is None:
3838
vars = model.cont_vars
3939

40-
start = Point(start)
40+
start = Point(model, start)
4141
bij = DictToArrayBijection(ArrayOrdering(vars), start)
4242

4343
logp = bij.mapf(model.logpc)

0 commit comments

Comments
 (0)