|
39 | 39 | from pymc.model.transform.conditioning import remove_value_transforms |
40 | 40 | from pymc.model.transform.optimization import freeze_dims_and_data |
41 | 41 | from pymc.util import get_default_varnames |
| 42 | +from pytensor.tensor import TensorVariable |
| 43 | +from pytensor.tensor.optimize import minimize |
42 | 44 | from scipy import stats |
43 | 45 |
|
44 | 46 | from pymc_extras.inference.find_map import ( |
@@ -415,6 +417,31 @@ def sample_laplace_posterior( |
415 | 417 | return idata |
416 | 418 |
|
417 | 419 |
|
| 420 | +def find_mode( |
| 421 | + inputs: list[TensorVariable], |
| 422 | + x: TensorVariable | None = None, |
| 423 | + model: pm.Model | None = None, |
| 424 | + method: minimize_method = "BFGS", |
| 425 | + optimizer_kwargs: dict | None = None, |
| 426 | +): # Unsure of the return type, I'd assume it would be a list of pt tensors of some kind |
| 427 | + model = pm.modelcontext(model) |
| 428 | + if x is None: |
| 429 | + raise NotImplementedError("Currently assumes user specifies the Gaussian latent field x") |
| 430 | + |
| 431 | + # Minimise negative log likelihood |
| 432 | + loss_x = -model.logp() |
| 433 | + # TODO Need to think about how to get inputs (i.e. a collection of all the input variables) to go along with the specific |
| 434 | + # variable x, i.e. f(x, *args). I assume I can't assume that the inputs arg will be ordered to have x first. May need to sort it somehow |
| 435 | + loss = pytensor.function(inputs, loss_x) |
| 436 | + |
| 437 | + grad = pytensor.gradient.grad(loss, inputs) |
| 438 | + hess = pytensor.gradient.jacobian(grad, inputs)[0] |
| 439 | + |
| 440 | + # Need to play around with scipy.optimize.minimize with pytensor a little so I can figure out if it's "x" or "inputs" that goes here |
| 441 | + res = minimize(loss, x, method, grad, hess, optimizer_kwargs) |
| 442 | + return res.x, res.hess |
| 443 | + |
| 444 | + |
418 | 445 | def fit_laplace( |
419 | 446 | optimize_method: minimize_method | Literal["basinhopping"] = "BFGS", |
420 | 447 | *, |
|
0 commit comments