@@ -43,7 +43,7 @@ def fixed_hessian(point, model=None):
43
43
return rval
44
44
45
45
46
- def find_hessian (point , vars = None , model = None ):
46
+ def find_hessian (point , vars = None , model = None , negate_output = True ):
47
47
"""
48
48
Returns Hessian of logp at the point passed.
49
49
@@ -55,11 +55,11 @@ def find_hessian(point, vars=None, model=None):
55
55
Variables for which Hessian is to be calculated.
56
56
"""
57
57
model = modelcontext (model )
58
- H = model .compile_d2logp (vars )
58
+ H = model .compile_d2logp (vars , negate_output = negate_output )
59
59
return H (Point (point , filter_model_vars = True , model = model ))
60
60
61
61
62
- def find_hessian_diag (point , vars = None , model = None ):
62
+ def find_hessian_diag (point , vars = None , model = None , negate_output = True ):
63
63
"""
64
64
Returns Hessian of logp at the point passed.
65
65
@@ -71,14 +71,14 @@ def find_hessian_diag(point, vars=None, model=None):
71
71
Variables for which Hessian is to be calculated.
72
72
"""
73
73
model = modelcontext (model )
74
- H = model .compile_fn (hessian_diag (model .logp (), vars ))
74
+ H = model .compile_fn (hessian_diag (model .logp (), vars , negate_output = negate_output ))
75
75
return H (Point (point , model = model ))
76
76
77
77
78
78
def guess_scaling (point , vars = None , model = None , scaling_bound = 1e-8 ):
79
79
model = modelcontext (model )
80
80
try :
81
- h = find_hessian_diag (point , vars , model = model )
81
+ h = - find_hessian_diag (point , vars , model = model , negate_output = False )
82
82
except NotImplementedError :
83
83
h = fixed_hessian (point , model = model )
84
84
return adjust_scaling (h , scaling_bound )
0 commit comments