Skip to content

Commit 4c2529d

Browse files
Reconcile the two laplace approximation functions
1 parent 2d21403 commit 4c2529d

File tree

5 files changed

+642
-528
lines changed

5 files changed

+642
-528
lines changed

pymc_experimental/inference/find_map.py

Lines changed: 22 additions & 255 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,41 @@
11
import logging
22

33
from collections.abc import Callable
4-
from typing import Literal, cast
4+
from typing import cast
55

6-
import arviz as az
76
import jax
87
import numpy as np
98
import pymc as pm
109
import pytensor
1110
import pytensor.tensor as pt
12-
import xarray as xr
1311

14-
from arviz import dict_to_dataset
1512
from better_optimize import minimize
16-
from better_optimize.constants import minimize_method
17-
from pymc.backends.arviz import (
18-
coords_and_dims_for_inferencedata,
19-
find_constants,
20-
find_observations,
21-
)
13+
from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
2214
from pymc.blocking import DictToArrayBijection, RaveledVars
2315
from pymc.initial_point import make_initial_point_fn
24-
from pymc.model.transform.conditioning import remove_value_transforms
2516
from pymc.model.transform.optimization import freeze_dims_and_data
2617
from pymc.pytensorf import join_nonshared_inputs
2718
from pymc.util import get_default_varnames
2819
from pytensor.compile import Function
2920
from pytensor.tensor import TensorVariable
30-
from scipy import stats
3121
from scipy.optimize import OptimizeResult
3222

3323
_log = logging.getLogger(__name__)
3424

3525

26+
def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
27+
method_info = MINIMIZE_MODE_KWARGS[method].copy()
28+
29+
use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
30+
use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
31+
use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
32+
33+
if use_hess and use_hessp:
34+
use_hess = False
35+
36+
return use_grad, use_hess, use_hessp
37+
38+
3639
def get_nearest_psd(A: np.ndarray) -> np.ndarray:
3740
"""
3841
Compute the nearest positive semi-definite matrix to a given matrix.
@@ -60,7 +63,9 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
6063

6164
def _unconstrained_vector_to_constrained_rvs(model):
6265
constrained_rvs, unconstrained_vector = join_nonshared_inputs(
63-
model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars
66+
model.initial_point(),
67+
inputs=model.value_vars,
68+
outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False),
6469
)
6570

6671
unconstrained_vector.name = "unconstrained_vector"
@@ -289,247 +294,6 @@ def scipy_optimize_funcs_from_loss(
289294
return f_loss, f_hess, f_hessp
290295

291296

292-
def fit_mvn_to_MAP(
293-
optimized_point: dict[str, np.ndarray],
294-
model: pm.Model,
295-
on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
296-
transform_samples: bool = True,
297-
use_jax_gradients: bool = False,
298-
zero_tol: float = 1e-8,
299-
diag_jitter: float | None = 1e-8,
300-
compile_kwargs: dict | None = None,
301-
) -> tuple[RaveledVars, np.ndarray]:
302-
"""
303-
Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior
304-
evaluated at the MAP estimate. This is the basis of the Laplace approximation.
305-
306-
Parameters
307-
----------
308-
optimized_point : dict[str, np.ndarray]
309-
Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
310-
model : Model
311-
A PyMC model
312-
on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
313-
What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
314-
If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
315-
If 'error', an error will be raised.
316-
transform_samples : bool
317-
Whether to transform the samples back to the original parameter space. Default is True.
318-
zero_tol: float
319-
Value below which an element of the Hessian matrix is counted as 0.
320-
This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
321-
diag_jitter: float | None
322-
A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
323-
If None, no jitter is added. Default is 1e-8.
324-
325-
Returns
326-
-------
327-
map_estimate: RaveledVars
328-
The MAP estimate of the model parameters, raveled into a 1D array.
329-
330-
inverse_hessian: np.ndarray
331-
The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
332-
"""
333-
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
334-
frozen_model = freeze_dims_and_data(model)
335-
336-
if not transform_samples:
337-
untransformed_model = remove_value_transforms(frozen_model)
338-
logp = untransformed_model.logp(jacobian=False)
339-
variables = untransformed_model.continuous_value_vars
340-
else:
341-
logp = frozen_model.logp(jacobian=True)
342-
variables = frozen_model.continuous_value_vars
343-
344-
variable_names = {var.name for var in variables}
345-
optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names}
346-
mu = DictToArrayBijection.map(optimized_free_params)
347-
348-
_, f_hess, _ = scipy_optimize_funcs_from_loss(
349-
loss=-logp,
350-
inputs=variables,
351-
initial_point_dict=frozen_model.initial_point(),
352-
use_grad=True,
353-
use_hess=True,
354-
use_hessp=False,
355-
use_jax_gradients=use_jax_gradients,
356-
compile_kwargs=compile_kwargs,
357-
)
358-
359-
H = -f_hess(mu.data)
360-
H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))
361-
362-
def stabilize(x, jitter):
363-
return x + np.eye(x.shape[0]) * jitter
364-
365-
H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter)
366-
367-
try:
368-
np.linalg.cholesky(H_inv)
369-
except np.linalg.LinAlgError:
370-
if on_bad_cov == "error":
371-
raise np.linalg.LinAlgError(
372-
"Inverse Hessian not positive-semi definite at the provided point"
373-
)
374-
H_inv = get_nearest_psd(H_inv)
375-
if on_bad_cov == "warn":
376-
_log.warning(
377-
"Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
378-
"matrix in L1-norm instead"
379-
)
380-
381-
return mu, H_inv
382-
383-
384-
def laplace(
385-
mu: RaveledVars,
386-
H_inv: np.ndarray,
387-
model: pm.Model,
388-
chains: int = 2,
389-
draws: int = 500,
390-
transform_samples: bool = True,
391-
progressbar: bool = True,
392-
**compile_kwargs,
393-
) -> az.InferenceData:
394-
"""
395-
396-
Parameters
397-
----------
398-
mu
399-
H_inv
400-
model : Model
401-
A PyMC model
402-
chains : int
403-
The number of sampling chains running in parallel. Default is 2.
404-
draws : int
405-
The number of samples to draw from the approximated posterior. Default is 500.
406-
transform_samples : bool
407-
Whether to transform the samples back to the original parameter space. Default is True.
408-
409-
Returns
410-
-------
411-
idata: az.InferenceData
412-
An InferenceData object containing the approximated posterior samples.
413-
"""
414-
posterior_dist = stats.multivariate_normal(mean=mu.data, cov=H_inv, allow_singular=True)
415-
posterior_draws = posterior_dist.rvs(size=(chains, draws))
416-
417-
if transform_samples:
418-
constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
419-
batched_values = pt.tensor(
420-
"batched_values",
421-
shape=(chains, draws, *unconstrained_vector.type.shape),
422-
dtype=unconstrained_vector.type.dtype,
423-
)
424-
batched_rvs = pytensor.graph.vectorize_graph(
425-
constrained_rvs, replace={unconstrained_vector: batched_values}
426-
)
427-
428-
f_constrain = pm.compile_pymc(
429-
inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
430-
)
431-
posterior_draws = f_constrain(posterior_draws)
432-
433-
else:
434-
info = mu.point_map_info
435-
flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
436-
slices = [
437-
slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
438-
]
439-
440-
posterior_draws = [
441-
posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
442-
for idx, (name, shape, dtype) in zip(slices, info)
443-
]
444-
445-
def make_rv_coords(name):
446-
coords = {"chain": range(chains), "draw": range(draws)}
447-
extra_dims = model.named_vars_to_dims.get(name)
448-
if extra_dims is None:
449-
return coords
450-
return coords | {dim: list(model.coords[dim]) for dim in extra_dims}
451-
452-
def make_rv_dims(name):
453-
dims = ["chain", "draw"]
454-
extra_dims = model.named_vars_to_dims.get(name)
455-
if extra_dims is None:
456-
return dims
457-
return dims + list(extra_dims)
458-
459-
idata = {
460-
name: xr.DataArray(
461-
data=draws.squeeze(),
462-
coords=make_rv_coords(name),
463-
dims=make_rv_dims(name),
464-
name=name,
465-
)
466-
for (name, _, _), draws in zip(mu.point_map_info, posterior_draws)
467-
}
468-
469-
coords, dims = coords_and_dims_for_inferencedata(model)
470-
idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)
471-
472-
if model.deterministics:
473-
idata.posterior = pm.compute_deterministics(
474-
idata.posterior,
475-
model=model,
476-
merge_dataset=True,
477-
progressbar=progressbar,
478-
compile_kwargs=compile_kwargs,
479-
)
480-
481-
observed_data = dict_to_dataset(
482-
find_observations(model),
483-
library=pm,
484-
coords=coords,
485-
dims=dims,
486-
default_dims=[],
487-
)
488-
489-
constant_data = dict_to_dataset(
490-
find_constants(model),
491-
library=pm,
492-
coords=coords,
493-
dims=dims,
494-
default_dims=[],
495-
)
496-
497-
idata.add_groups(
498-
{"observed_data": observed_data, "constant_data": constant_data},
499-
coords=coords,
500-
dims=dims,
501-
)
502-
503-
return idata
504-
505-
506-
def fit_laplace(
507-
optimized_point: dict[str, np.ndarray],
508-
model: pm.Model,
509-
chains: int = 2,
510-
draws: int = 500,
511-
on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
512-
transform_samples: bool = True,
513-
zero_tol: float = 1e-8,
514-
diag_jitter: float | None = 1e-8,
515-
progressbar: bool = True,
516-
compile_kwargs: dict | None = None,
517-
) -> az.InferenceData:
518-
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
519-
520-
mu, H_inv = fit_mvn_to_MAP(
521-
optimized_point=optimized_point,
522-
model=model,
523-
on_bad_cov=on_bad_cov,
524-
transform_samples=transform_samples,
525-
zero_tol=zero_tol,
526-
diag_jitter=diag_jitter,
527-
compile_kwargs=compile_kwargs,
528-
)
529-
530-
return laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar)
531-
532-
533297
def find_MAP(
534298
method: minimize_method,
535299
*,
@@ -605,9 +369,12 @@ def find_MAP(
605369
initial_params = DictToArrayBijection.map(
606370
{var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
607371
)
372+
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
373+
method, use_grad, use_hess, use_hessp
374+
)
608375

609376
f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
610-
loss=-frozen_model.logp(),
377+
loss=-frozen_model.logp(jacobian=False),
611378
inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
612379
initial_point_dict=start_dict,
613380
use_grad=use_grad,

pymc_experimental/inference/fit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ def fit(method, **kwargs):
3939
return fit_pathfinder(**kwargs)
4040

4141
if method == "laplace":
42-
from pymc_experimental.inference.laplace import laplace
42+
from pymc_experimental.inference.laplace import fit_laplace
4343

44-
return laplace(**kwargs)
44+
return fit_laplace(**kwargs)

0 commit comments

Comments
 (0)