Skip to content

Commit 30090ed

Browse files
set up skeleton for find_mode
1 parent 009b5ac commit 30090ed

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

pymc_extras/inference/laplace.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from pymc.model.transform.conditioning import remove_value_transforms
4040
from pymc.model.transform.optimization import freeze_dims_and_data
4141
from pymc.util import get_default_varnames
42+
from pytensor.tensor import TensorVariable
43+
from pytensor.tensor.optimize import minimize
4244
from scipy import stats
4345

4446
from pymc_extras.inference.find_map import (
@@ -415,6 +417,31 @@ def sample_laplace_posterior(
415417
return idata
416418

417419

420+
def find_mode(
421+
inputs: list[TensorVariable],
422+
x: TensorVariable | None = None,
423+
model: pm.Model | None = None,
424+
method: minimize_method = "BFGS",
425+
optimizer_kwargs: dict | None = None,
426+
): # Unsure of the return type, I'd assume it would be a list of pt tensors of some kind
427+
model = pm.modelcontext(model)
428+
if x is None:
429+
raise NotImplementedError("Currently assumes user specifies the Gaussian latent field x")
430+
431+
# Minimise negative log likelihood
432+
loss_x = -model.logp()
433+
# TODO Need to think about how to get inputs (i.e. a collection of all the input variables) to go along with the specific
434+
# variable x, i.e. f(x, *args). I assume I can't assume that the inputs arg will be ordered to have x first. May need to sort it somehow
435+
loss = pytensor.function(inputs, loss_x)
436+
437+
grad = pytensor.gradient.grad(loss, inputs)
438+
hess = pytensor.gradient.jacobian(grad, inputs)[0]
439+
440+
# Need to play around with scipy.optimize.minimize with pytensor a little so I can figure out if it's "x" or "inputs" that goes here
441+
res = minimize(loss, x, method, grad, hess, optimizer_kwargs)
442+
return res.x, res.hess
443+
444+
418445
def fit_laplace(
419446
optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
420447
*,

0 commit comments

Comments
 (0)