Skip to content

Commit 724289d

Browse files
jessegrabowski authored and andreacate committed
Allow method="basinhopping" in find_MAP and fit_laplace (pymc-devs#467)
1 parent a3df8c7 commit 724289d

File tree

3 files changed

+668
-16
lines changed

3 files changed

+668
-16
lines changed

pymc_extras/inference/find_map.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytensor
1010
import pytensor.tensor as pt
1111

12-
from better_optimize import minimize
12+
from better_optimize import basinhopping, minimize
1313
from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
1414
from pymc.blocking import DictToArrayBijection, RaveledVars
1515
from pymc.initial_point import make_initial_point_fn
@@ -335,7 +335,7 @@ def scipy_optimize_funcs_from_loss(
335335

336336

337337
def find_MAP(
338-
method: minimize_method,
338+
method: minimize_method | Literal["basinhopping"],
339339
*,
340340
model: pm.Model | None = None,
341341
use_grad: bool | None = None,
@@ -352,14 +352,17 @@ def find_MAP(
352352
**optimizer_kwargs,
353353
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
354354
"""
355-
Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
355+
Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.
356356
357357
Parameters
358358
----------
359359
model : pm.Model
360360
The PyMC model to be fit. If None, the current model context is used.
361361
method : str
362-
The optimization method to use. See scipy.optimize.minimize documentation for details.
362+
The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
363+
trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
364+
365+
See the scipy.optimize.minimize and scipy.optimize.basinhopping documentation for details.
363366
use_grad : bool | None, optional
364367
Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
365368
the ``method``.
@@ -387,7 +390,9 @@ def find_MAP(
387390
compile_kwargs: dict, optional
388391
Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
389392
**optimizer_kwargs
390-
Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
393+
Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
394+
``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
395+
``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
391396
392397
Returns
393398
-------
@@ -413,6 +418,18 @@ def find_MAP(
413418
initial_params = DictToArrayBijection.map(
414419
{var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
415420
)
421+
422+
do_basinhopping = method == "basinhopping"
423+
minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
424+
425+
if do_basinhopping:
426+
# For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
427+
# another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
428+
# if one isn't provided.
429+
430+
method = minimizer_kwargs.pop("method", "L-BFGS-B")
431+
minimizer_kwargs["method"] = method
432+
416433
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
417434
method, use_grad, use_hess, use_hessp
418435
)
@@ -431,17 +448,37 @@ def find_MAP(
431448
args = optimizer_kwargs.pop("args", None)
432449

433450
# better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
434-
# if so. That is why it is not set here, regardless of user settings.
435-
optimizer_result = minimize(
436-
f=f_logp,
437-
x0=cast(np.ndarray[float], initial_params.data),
438-
args=args,
439-
hess=f_hess,
440-
hessp=f_hessp,
441-
progressbar=progressbar,
442-
method=method,
443-
**optimizer_kwargs,
444-
)
451+
# if so. That is why the jac argument is not passed here in either branch.
452+
453+
if do_basinhopping:
454+
if "args" not in minimizer_kwargs:
455+
minimizer_kwargs["args"] = args
456+
if "hess" not in minimizer_kwargs:
457+
minimizer_kwargs["hess"] = f_hess
458+
if "hessp" not in minimizer_kwargs:
459+
minimizer_kwargs["hessp"] = f_hessp
460+
if "method" not in minimizer_kwargs:
461+
minimizer_kwargs["method"] = method
462+
463+
optimizer_result = basinhopping(
464+
func=f_logp,
465+
x0=cast(np.ndarray[float], initial_params.data),
466+
progressbar=progressbar,
467+
minimizer_kwargs=minimizer_kwargs,
468+
**optimizer_kwargs,
469+
)
470+
471+
else:
472+
optimizer_result = minimize(
473+
f=f_logp,
474+
x0=cast(np.ndarray[float], initial_params.data),
475+
args=args,
476+
hess=f_hess,
477+
hessp=f_hessp,
478+
progressbar=progressbar,
479+
method=method,
480+
**optimizer_kwargs,
481+
)
445482

446483
raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
447484
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)

0 commit comments

Comments
 (0)