Skip to content

Commit 3fd5db4

Browse files
committed
simplifying model parameters to be simple defaults
1 parent 1a2503f commit 3fd5db4

File tree

10 files changed

+39
-60
lines changed

10 files changed

+39
-60
lines changed

pymc/model.py

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313
from functools import wraps
1414

15-
__all__ = ['Model', 'compilef', 'gradient', 'hessian', 'withmodel', 'Point']
15+
__all__ = ['Model', 'compilef', 'gradient', 'hessian', 'modelcontext', 'Point']
1616

1717

1818

@@ -38,36 +38,10 @@ def get_context(cls):
3838
except IndexError:
3939
raise TypeError("No context on context stack")
4040

41-
def withcontext(contexttype, argname):
42-
"""
43-
Returns a decorator for wrapping functions so they look for an argument in a specific argument slot.
44-
If not found, the decorated function searches the for a context and inserts it in that slot.
45-
46-
Parameters
47-
----------
48-
contexttype : type
49-
The type of context to search for
50-
argname : string
51-
The name of the argument slot where the context should go
52-
53-
Returns
54-
-------
55-
decorator function
56-
57-
"""
58-
def decorator(fn):
59-
n = list(fn.func_code.co_varnames).index(argname)
60-
61-
@wraps(fn)
62-
def nfn(*args, **kwargs):
63-
if not (len(args) > n and isinstance(args[n], contexttype)):
64-
context = contexttype.get_context()
65-
args = args[:n] + (context,) + args[n:]
66-
return fn(*args,**kwargs)
67-
68-
return nfn
69-
return decorator
70-
41+
def modelcontext(model):
42+
if model is None:
43+
return Model.get_context()
44+
return model
7145

7246
class Model(Context):
7347
"""
@@ -142,20 +116,22 @@ def TransformedVar(model, name, dist, trans):
142116
def AddPotential(model, potential):
143117
model.factors.append(potential)
144118

145-
withmodel = withcontext(Model, 'model')
146119

147-
@withmodel
148-
def Point(model, *args,**kwargs):
120+
def Point(*args,**kwargs):
149121
"""
150122
Build a point. Uses same args as dict() does.
151123
Filters out variables not in the model. All keys are strings.
152124
153125
Parameters
154126
----------
155-
model : Model (in context)
156127
*args, **kwargs
157128
arguments to build a dict
158129
"""
130+
if 'model' in kwargs :
131+
model = kwargs['model']
132+
del kwargs['model']
133+
else:
134+
model = Model.get_context()
159135

160136
d = dict(*args, **kwargs)
161137
varnames = map(str, model.vars)

pymc/sample.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99
__all__ = ['sample', 'psample']
1010

11-
@withmodel
12-
def sample(model, draws, step, start = None, trace = None, track_progress = True):
11+
def sample(draws, step, start = None, trace = None, track_progress = True, model = None):
1312
"""
1413
Draw a number of samples using the given step method.
1514
Multiple step methods supported via compound step method
@@ -36,6 +35,7 @@ def sample(model, draws, step, start = None, trace = None, track_progress = True
3635
>>> an example
3736
3837
"""
38+
model = modelcontext(model)
3939
draws = int(draws)
4040
if start is None:
4141
start = trace[-1]
@@ -66,11 +66,12 @@ def argsample(args):
6666
""" defined at top level so it can be pickled"""
6767
return sample(*args)
6868

69-
@withmodel
70-
def psample(model, draws, step, start, mtrace = None, threads = None, track = None):
69+
def psample(draws, step, start, mtrace = None, threads = None, track = None, model = None):
7170
"""draw a number of samples using the given step method. Multiple step methods supported via compound step method
7271
returns the amount of time taken"""
7372

73+
model = modelcontext(model)
74+
7475
if not threads:
7576
threads = max(mp.cpu_count() - 2, 1)
7677

pymc/step_methods/gibbs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class ElemwiseCategoricalStep(ArrayStep):
1919
2020
"""
2121
#TODO: It would be great to come up with a way to make ElemwiseCategoricalStep more general (handling more complex elementwise variables)
22-
@withmodel
23-
def __init__(self, model, var, values):
22+
def __init__(self, var, values, model = None):
23+
model = modelcontext(model)
2424
self.sh = ones(var.dshape, var.dtype)
2525
self.values = values
2626
self.var = var

pymc/step_methods/hmc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,10 @@ def unif(step_size, elow = .85, ehigh = 1.15):
2020

2121

2222
class HamiltonianMC(ArrayStep):
23-
@withmodel
24-
def __init__(self, model, vars, C, step_scale = .25, path_length = 2., is_cov = False, step_rand = unif, state = None):
23+
def __init__(self, vars, C, step_scale = .25, path_length = 2., is_cov = False, step_rand = unif, state = None, model = None):
2524
"""
2625
Parameters
2726
----------
28-
model : Model
2927
vars : list of theano variables
3028
C : array_like, ndim = {1,2}
3129
Scaling for momentum distribution. 1d arrays interpreted matrix diagonal.
@@ -37,7 +35,11 @@ def __init__(self, model, vars, C, step_scale = .25, path_length = 2., is_cov =
3735
Treat C as a covariance matrix/vector if True, else treat it as a precision matrix/vector
3836
step_rand : function float -> float, default=unif
3937
A function which takes the step size and returns an new one used to randomize the step size at each iteration.
38+
state
39+
State object
40+
model : Model
4041
"""
42+
model = modelcontext(model)
4143
n = C.shape[0]
4244

4345
self.step_size = step_scale / n**(1/4.)

pymc/step_methods/metropolis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
# TODO Implement tuning for Metropolis step
1616
class Metropolis(ArrayStep):
17-
@withmodel
18-
def __init__(self, model, vars, C, scaling=.25, is_cov = False):
17+
def __init__(self, vars, C, scaling=.25, is_cov = False, model = None):
18+
model = modelcontext(model)
1919

2020
self.potential = quad_potential(C, not is_cov)
2121
self.scaling = scaling

pymc/tuning/scaling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010
__all__ = ['approx_hess', 'find_hessian', 'trace_cov']
1111

12-
@withmodel
13-
def approx_hess(model, point, vars=None):
12+
def approx_hess(point, vars=None, model = None):
1413
"""
1514
Returns an approximation of the Hessian at the current chain location.
1615
@@ -21,6 +20,7 @@ def approx_hess(model, point, vars=None):
2120
vars : list
2221
Variables for which Hessian is to be calculated.
2322
"""
23+
model = modelcontext(model)
2424
if vars is None :
2525
vars = model.cont_vars
2626

@@ -40,8 +40,7 @@ def grad_logp(point):
4040
'''
4141
return -nd.Jacobian(grad_logp)(bij.map(point))
4242

43-
@withmodel
44-
def find_hessian(model, point, vars = None):
43+
def find_hessian(point, vars = None, model = None):
4544
"""
4645
Returns Hessian of logp at the point passed.
4746
@@ -52,6 +51,7 @@ def find_hessian(model, point, vars = None):
5251
vars : list
5352
Variables for which Hessian is to be calculated.
5453
"""
54+
model = modelcontext(model)
5555
H = model.d2logpc(vars)
5656
return H(Point(model, point))
5757

pymc/tuning/starting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,26 @@
1111
__all__ = ['find_MAP', 'scipyminimize']
1212

1313

14-
@withmodel
15-
def find_MAP(model, start = None, vars=None, fmin = optimize.fmin_bfgs, return_raw = False, disp = False, *args, **kwargs):
14+
def find_MAP(start = None, vars=None, fmin = optimize.fmin_bfgs, return_raw = False, disp = False, model = None, model = None, *args, **kwargs):
1615
"""
1716
Sets state to the local maximum a posteriori point given a model.
1817
Current default of fmin_Hessian does not deal well with optimizing close
1918
to sharp edges, especially if they are the minimum.
2019
2120
Parameters
2221
----------
23-
model : Model (optional if in `with` context)
2422
start : dict of parameter values (Defaults to model.test_point)
2523
vars : list
2624
List of variables to set to MAP point (Defaults to all continuous).
2725
fmin : function
2826
Optimization algorithm (Defaults to `scipy.optimize.fmin_l_bfgs_b`).
2927
return_raw : Bool
3028
Whether to return extra value returned by fmin (Defaults to False)
29+
model : Model (optional if in `with` context)
3130
*args, **kwargs
3231
Extra args passed to fmin
3332
"""
33+
model = modelcontext(model)
3434
if start is None:
3535
start = model.test_point
3636

tests/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
def simple_init():
66
start, model, moments = simple_model()
77

8-
step = Metropolis(model, model.vars, np.diag([1.]))
8+
step = Metropolis(model.vars, np.diag([1.]), model = model)
99
return model, start, step, moments
1010

1111

tests/test_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ def check_stat(name, trace, var, stat, value, bound):
1111
def test_step_continuous():
1212
start, model, (mu, C) = mv_simple()
1313

14-
hmc = pm.HamiltonianMC(model, model.vars, C, is_cov = True)
15-
mh = pm.Metropolis(model, model.vars , C, is_cov = True, scaling = 2)
14+
hmc = pm.HamiltonianMC(model.vars, C, is_cov = True, model = model)
15+
mh = pm.Metropolis(model.vars , C, is_cov = True, scaling = 2, model = model)
1616
compound = pm.CompoundStep([hmc, mh])
1717

1818
steps = [mh, hmc, compound]
@@ -23,10 +23,10 @@ def test_step_continuous():
2323

2424
for st in steps:
2525
np.random.seed(1)
26-
h = sample(model, 8000, st, start)
26+
h = sample(8000, st, start, model = model)
2727
for (var, stat, val, bound) in check:
2828
np.random.seed(1)
29-
h = sample(model, 8000, st, start)
29+
h = sample(8000, st, start, model = model)
3030

3131
yield check_stat,repr(st), h, var, stat, val, bound
3232

tests/test_trace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def check_trace(model, trace, n, step, start):
1414
#try using a trace object a few times
1515
for i in range(2):
16-
trace = sample(model, n, step, start, trace)
16+
trace = sample(n, step, start, trace, model = model)
1717

1818
for (var, val) in start.iteritems():
1919

@@ -45,7 +45,7 @@ def test_multitrace():
4545
def check_multi_trace(model, trace, n, step, start):
4646

4747
for i in range(2):
48-
trace = psample(model, n, step, start, trace)
48+
trace = psample(n, step, start, trace, model = model)
4949

5050

5151
for (var, val) in start.iteritems():

0 commit comments

Comments
 (0)