Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions pymc_extras/inference/laplace_approx/find_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def find_MAP(
include_transformed: bool = True,
gradient_backend: GradientBackend = "pytensor",
compile_kwargs: dict | None = None,
compute_hessian: bool = False,
**optimizer_kwargs,
) -> (
dict[str, np.ndarray]
Expand Down Expand Up @@ -239,6 +240,10 @@ def find_MAP(
Whether to include transformed variable values in the returned dictionary. Defaults to True.
gradient_backend: str, default "pytensor"
Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
compute_hessian: bool
If True, the inverse Hessian matrix at the optimum will be computed and included in the returned
InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for
high-dimensional problems. Defaults to False.
compile_kwargs: dict, optional
Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
**optimizer_kwargs
Expand Down Expand Up @@ -316,14 +321,17 @@ def find_MAP(
**optimizer_kwargs,
)

H_inv = _compute_inverse_hessian(
optimizer_result=optimizer_result,
optimal_point=None,
f_fused=f_fused,
f_hessp=f_hessp,
use_hess=use_hess,
method=method,
)
if compute_hessian:
H_inv = _compute_inverse_hessian(
optimizer_result=optimizer_result,
optimal_point=None,
f_fused=f_fused,
f_hessp=f_hessp,
use_hess=use_hess,
method=method,
)
else:
H_inv = None

raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
Expand Down
7 changes: 5 additions & 2 deletions pymc_extras/inference/laplace_approx/idata.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ def map_results_to_inference_data(


def add_fit_to_inference_data(
idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
idata: az.InferenceData,
mu: RaveledVars,
H_inv: np.ndarray | None,
model: pm.Model | None = None,
) -> az.InferenceData:
"""
Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
Expand All @@ -147,7 +150,7 @@ def add_fit_to_inference_data(
An InferenceData object containing the approximated posterior samples.
mu: RaveledVars
The MAP estimate of the model parameters.
H_inv: np.ndarray
H_inv: np.ndarray, optional
The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
model: Model, optional
A PyMC model. If None, the model is taken from the current model context.
Expand Down
1 change: 1 addition & 0 deletions pymc_extras/inference/laplace_approx/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def fit_laplace(
include_transformed=include_transformed,
gradient_backend=gradient_backend,
compile_kwargs=compile_kwargs,
compute_hessian=True,
**optimizer_kwargs,
)

Expand Down
8 changes: 6 additions & 2 deletions tests/inference/laplace_approx/test_find_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def compute_z(x):
],
)
@pytest.mark.parametrize(
"backend, gradient_backend, include_transformed",
[("jax", "jax", True), ("jax", "pytensor", False)],
"backend, gradient_backend, include_transformed, compute_hessian",
[("jax", "jax", True, True), ("jax", "pytensor", False, False)],
ids=str,
)
def test_find_MAP(
Expand All @@ -145,6 +145,7 @@ def test_find_MAP(
backend,
gradient_backend: GradientBackend,
include_transformed,
compute_hessian,
rng,
):
pytest.importorskip("jax")
Expand All @@ -164,6 +165,7 @@ def test_find_MAP(
include_transformed=include_transformed,
compile_kwargs={"mode": backend.upper()},
maxiter=5,
compute_hessian=compute_hessian,
)

assert hasattr(idata, "posterior")
Expand All @@ -184,6 +186,8 @@ def test_find_MAP(
else:
assert not hasattr(idata, "unconstrained_posterior")

assert ("covariance_matrix" in idata.fit) == compute_hessian


def test_find_map_outside_model_context():
"""
Expand Down
Loading