diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py index b717e0ba9..73ad741df 100644 --- a/pymc_extras/inference/laplace_approx/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -198,6 +198,7 @@ def find_MAP( include_transformed: bool = True, gradient_backend: GradientBackend = "pytensor", compile_kwargs: dict | None = None, + compute_hessian: bool = False, **optimizer_kwargs, ) -> ( dict[str, np.ndarray] @@ -239,6 +240,10 @@ def find_MAP( Whether to include transformed variable values in the returned dictionary. Defaults to True. gradient_backend: str, default "pytensor" Which backend to use to compute gradients. Must be one of "pytensor" or "jax". + compute_hessian: bool + If True, the inverse Hessian matrix at the optimum will be computed and included in the returned + InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for + high-dimensional problems. Defaults to False. compile_kwargs: dict, optional Additional options to pass to the ``pytensor.function`` function when compiling loss functions. **optimizer_kwargs @@ -316,14 +321,17 @@ def find_MAP( **optimizer_kwargs, ) - H_inv = _compute_inverse_hessian( - optimizer_result=optimizer_result, - optimal_point=None, - f_fused=f_fused, - f_hessp=f_hessp, - use_hess=use_hess, - method=method, - ) + if compute_hessian: + H_inv = _compute_inverse_hessian( + optimizer_result=optimizer_result, + optimal_point=None, + f_fused=f_fused, + f_hessp=f_hessp, + use_hess=use_hess, + method=method, + ) + else: + H_inv = None raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True) diff --git a/pymc_extras/inference/laplace_approx/idata.py b/pymc_extras/inference/laplace_approx/idata.py index 0d81d64bb..610b28b00 100644 --- a/pymc_extras/inference/laplace_approx/idata.py +++ b/pymc_extras/inference/laplace_approx/idata.py @@ -136,7 +136,10 @@ def map_results_to_inference_data( def add_fit_to_inference_data( - idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None + idata: az.InferenceData, + mu: RaveledVars, + H_inv: np.ndarray | None, + model: pm.Model | None = None, ) -> az.InferenceData: """ Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object. @@ -147,7 +150,7 @@ def add_fit_to_inference_data( An InferenceData object containing the approximated posterior samples. mu: RaveledVars The MAP estimate of the model parameters. - H_inv: np.ndarray + H_inv: np.ndarray, optional The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. model: Model, optional A PyMC model. If None, the model is taken from the current model context. diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index 85e4cddaa..ab5203587 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -389,6 +389,7 @@ def fit_laplace( include_transformed=include_transformed, gradient_backend=gradient_backend, compile_kwargs=compile_kwargs, + compute_hessian=True, **optimizer_kwargs, ) diff --git a/tests/inference/laplace_approx/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py index f1406ca62..3529f972d 100644 --- a/tests/inference/laplace_approx/test_find_map.py +++ b/tests/inference/laplace_approx/test_find_map.py @@ -133,8 +133,8 @@ def compute_z(x): ], ) @pytest.mark.parametrize( - "backend, gradient_backend, include_transformed", - [("jax", "jax", True), ("jax", "pytensor", False)], + "backend, gradient_backend, include_transformed, compute_hessian", + [("jax", "jax", True, True), ("jax", "pytensor", False, False)], ids=str, ) def test_find_MAP( @@ -145,6 +145,7 @@ def test_find_MAP( backend, gradient_backend: GradientBackend, include_transformed, + compute_hessian, rng, ): pytest.importorskip("jax") @@ -164,6 +165,7 @@ def test_find_MAP( include_transformed=include_transformed, compile_kwargs={"mode": backend.upper()}, maxiter=5, + compute_hessian=compute_hessian, ) assert hasattr(idata, "posterior") @@ -184,6 +186,8 @@ def test_find_MAP( else: assert not hasattr(idata, "unconstrained_posterior") + assert ("covariance_matrix" in idata.fit) == compute_hessian + def test_find_map_outside_model_context(): """