@@ -121,11 +121,13 @@ def get_conditional_gaussian_approximation(
121
121
122
122
# f = log(p(y | x, params))
123
123
f_x = model .logp ()
124
- jac = pytensor .gradient .grad (f_x , x )
125
- hess = pytensor .gradient .jacobian (jac .flatten (), x )
124
+ # jac = pytensor.gradient.grad(f_x, x)
125
+ # hess = pytensor.gradient.jacobian(jac.flatten(), x)
126
126
127
127
# log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
128
- log_x_posterior = f_x - 0.5 * (x - mu ).T @ Q @ (x - mu )
128
+ log_x_posterior = f_x - 0.5 * (x - mu ).T @ Q @ (
129
+ x - mu
130
+ ) # TODO could be f + x.logp - IS X.LOGP DUPLICATED IN F?
129
131
130
132
# Maximize log(p(x | y, params)) wrt x to find mode x0
131
133
x0 , _ = minimize (
@@ -138,11 +140,13 @@ def get_conditional_gaussian_approximation(
138
140
)
139
141
140
142
# require f'(x0) and f''(x0) for Laplace approx
141
- jac = pytensor .graph .replace .graph_replace (jac , {x : x0 })
143
+ # jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
144
+ jac = pytensor .gradient .grad (f_x , x )
145
+ hess = pytensor .gradient .jacobian (jac .flatten (), x )
142
146
hess = pytensor .graph .replace .graph_replace (hess , {x : x0 })
143
147
144
148
# Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
145
- _ , logdetQ = pt .nlinalg .slogdet (Q )
149
+ # _, logdetQ = pt.nlinalg.slogdet(Q)
146
150
# conditional_gaussian_approx = (
147
151
# -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
148
152
# )
@@ -152,7 +156,11 @@ def get_conditional_gaussian_approximation(
152
156
153
157
# Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
154
158
# far from the mode x0 or in a neighbourhood which results in poor convergence.
155
- return pytensor .function (args , [x0 , pm .MvNormal (mu = x0 , tau = tau )])
159
+ return (
160
+ x0 ,
161
+ pm .MvNormal (f"{ x .name } _laplace_approx" , mu = x0 , tau = tau ),
162
+ tau ,
163
+ ) # pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)])
156
164
157
165
158
166
def laplace_draws_to_inferencedata (
0 commit comments