16
16
17
17
from numpy import exp , log , sqrt
18
18
19
- from pymc .aesaraf import hessian_diag , inputvars
19
+ from pymc .aesaraf import hessian_diag
20
20
from pymc .blocking import DictToArrayBijection
21
21
from pymc .model import Point , modelcontext
22
22
from pymc .util import get_var_name
23
23
24
24
__all__ = ["find_hessian" , "trace_cov" , "guess_scaling" ]
25
25
26
26
27
- def fixed_hessian (point , vars = None , model = None ):
27
+ def fixed_hessian (point , model = None ):
28
28
"""
29
29
Returns a fixed Hessian for any chain location.
30
30
@@ -37,10 +37,6 @@ def fixed_hessian(point, vars=None, model=None):
37
37
"""
38
38
39
39
model = modelcontext (model )
40
- if vars is None :
41
- vars = model .cont_vars
42
- vars = inputvars (vars )
43
-
44
40
point = Point (point , model = model )
45
41
46
42
rval = np .ones (DictToArrayBijection .map (point ).size ) / 10
@@ -84,7 +80,7 @@ def guess_scaling(point, vars=None, model=None, scaling_bound=1e-8):
84
80
try :
85
81
h = find_hessian_diag (point , vars , model = model )
86
82
except NotImplementedError :
87
- h = fixed_hessian (point , vars , model = model )
83
+ h = fixed_hessian (point , model = model )
88
84
return adjust_scaling (h , scaling_bound )
89
85
90
86
0 commit comments