12
12
import numpy as np
13
13
from functools import wraps
14
14
15
- __all__ = ['Model' , 'compilef' , 'gradient' , 'hessian' , 'withmodel ' , 'Point' ]
15
+ __all__ = ['Model' , 'compilef' , 'gradient' , 'hessian' , 'modelcontext ' , 'Point' ]
16
16
17
17
18
18
@@ -38,36 +38,10 @@ def get_context(cls):
38
38
except IndexError :
39
39
raise TypeError ("No context on context stack" )
40
40
41
- def withcontext (contexttype , argname ):
42
- """
43
- Returns a decorator for wrapping functions so they look for an argument in a specific argument slot.
44
- If not found, the decorated function searches the for a context and inserts it in that slot.
45
-
46
- Parameters
47
- ----------
48
- contexttype : type
49
- The type of context to search for
50
- argname : string
51
- The name of the argument slot where the context should go
52
-
53
- Returns
54
- -------
55
- decorator function
56
-
57
- """
58
- def decorator (fn ):
59
- n = list (fn .func_code .co_varnames ).index (argname )
60
-
61
- @wraps (fn )
62
- def nfn (* args , ** kwargs ):
63
- if not (len (args ) > n and isinstance (args [n ], contexttype )):
64
- context = contexttype .get_context ()
65
- args = args [:n ] + (context ,) + args [n :]
66
- return fn (* args ,** kwargs )
67
-
68
- return nfn
69
- return decorator
70
-
41
+ def modelcontext (model ):
42
+ if model is None :
43
+ return Model .get_context ()
44
+ return model
71
45
72
46
class Model (Context ):
73
47
"""
@@ -142,20 +116,22 @@ def TransformedVar(model, name, dist, trans):
142
116
def AddPotential (model , potential ):
143
117
model .factors .append (potential )
144
118
145
- withmodel = withcontext (Model , 'model' )
146
119
147
- @withmodel
148
- def Point (model , * args ,** kwargs ):
120
+ def Point (* args ,** kwargs ):
149
121
"""
150
122
Build a point. Uses same args as dict() does.
151
123
Filters out variables not in the model. All keys are strings.
152
124
153
125
Parameters
154
126
----------
155
- model : Model (in context)
156
127
*args, **kwargs
157
128
arguments to build a dict
158
129
"""
130
+ if 'model' in kwargs :
131
+ model = kwargs ['model' ]
132
+ del kwargs ['model' ]
133
+ else :
134
+ model = Model .get_context ()
159
135
160
136
d = dict (* args , ** kwargs )
161
137
varnames = map (str , model .vars )
0 commit comments