
Commit 23b4970

moved notebook testing code into find_mode
1 parent 4b92331 commit 23b4970

1 file changed (+28 −14 lines)

pymc_extras/inference/laplace.py

@@ -419,28 +419,42 @@ def sample_laplace_posterior(
 
 def find_mode(
     inputs: list[TensorVariable],
+    params: dict,  # TODO Would be nice to automatically map this to inputs somehow: {k.name: ... for k in inputs}
+    x0: TensorVariable
+    | None = None,  # TODO This isn't a TensorVariable; not sure what the general datatype for numeric array-likes is
     x: TensorVariable | None = None,
     model: pm.Model | None = None,
     method: minimize_method = "BFGS",
+    jac: bool = True,
+    hess: bool = False,
     optimizer_kwargs: dict | None = None,
-):  # Unsure of the return type; I'd assume it would be a list of pt tensors of some kind
+):  # TODO Output type is a list of the same type as x0
     model = pm.modelcontext(model)
     if x is None:
-        raise NotImplementedError("Currently assumes user specifies the Gaussian latent field x")
+        raise UserWarning(
+            "Latent Gaussian field x unspecified. Assuming it is the first entry in inputs. Specify which input to obtain the mode over using the input x."
+        )
+        x = inputs[0]
+
+    if x0 is None:
+        # Should return a random numpy array of the same shape as x0; not sure how to get the shape of x0
+        raise NotImplementedError
 
     # Minimise the negative log likelihood
-    # TODO: NLL already computes jac by default. Need to check how to access it
-    loss_x = -model.logp()
-    # TODO Need to think about how to get inputs (i.e. a collection of all the input variables) to go along with the specific
-    # variable x, i.e. f(x, *args). I can't assume that the inputs arg will be ordered to have x first; may need to sort it somehow
-    loss = pytensor.function(inputs, loss_x)
-
-    grad = pytensor.gradient.grad(loss, inputs)
-    hess = pytensor.gradient.jacobian(grad, inputs)[0]
-
-    # Need to play around with scipy.optimize.minimize and pytensor a little to figure out whether "x" or "inputs" goes here
-    res = minimize(loss, x, method, grad, hess, optimizer_kwargs)
-    return res.x, res.hess
+    nll = -model.logp()
+    soln, _ = minimize(
+        objective=nll, x=x, method=method, jac=jac, hess=hess, optimizer_kwargs=optimizer_kwargs
+    )
+
+    get_mode = pytensor.function(inputs, soln)
+    mode = get_mode(x0, **params)
+
+    # Calculate the value of the Hessian at the mode
+    # TODO check if we can't pull this out of the soln graph when jac or hess=True
+    hess_x = pytensor.gradient.hessian(nll, x)
+    hess = pytensor.function(inputs, hess_x)
+
+    return mode, hess(mode, **params)
 
 
 def fit_laplace(
