From a845475e8cab861d72d4999ab89ee9cca6699111 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 16 Sep 2025 09:26:20 -0500 Subject: [PATCH 1/3] Allow skipping covariance computation in `find_MAP` --- .../inference/laplace_approx/find_map.py | 24 ++++++++++++------- pymc_extras/inference/laplace_approx/idata.py | 7 ++++-- .../inference/laplace_approx/laplace.py | 1 + .../inference/laplace_approx/test_find_map.py | 8 +++++-- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py index b717e0ba9..fa4bab888 100644 --- a/pymc_extras/inference/laplace_approx/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -198,6 +198,7 @@ def find_MAP( include_transformed: bool = True, gradient_backend: GradientBackend = "pytensor", compile_kwargs: dict | None = None, + compute_covariance: bool = True, **optimizer_kwargs, ) -> ( dict[str, np.ndarray] @@ -239,6 +240,10 @@ def find_MAP( Whether to include transformed variable values in the returned dictionary. Defaults to True. gradient_backend: str, default "pytensor" Which backend to use to compute gradients. Must be one of "pytensor" or "jax". + compute_covariance: bool + If True, the inverse Hessian matrix at the optimum will be computed and included in the returned + InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for + high-dimensional problems. Defaults to True. compile_kwargs: dict, optional Additional options to pass to the ``pytensor.function`` function when compiling loss functions. **optimizer_kwargs @@ -316,14 +321,17 @@ def find_MAP( **optimizer_kwargs, ) - H_inv = _compute_inverse_hessian( - optimizer_result=optimizer_result, - optimal_point=None, - f_fused=f_fused, - f_hessp=f_hessp, - use_hess=use_hess, - method=method, - ) + if compute_covariance: + H_inv = _compute_inverse_hessian( + optimizer_result=optimizer_result, + optimal_point=None, + f_fused=f_fused, + f_hessp=f_hessp, + use_hess=use_hess, + method=method, + ) + else: + H_inv = None raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True) diff --git a/pymc_extras/inference/laplace_approx/idata.py b/pymc_extras/inference/laplace_approx/idata.py index 0d81d64bb..610b28b00 100644 --- a/pymc_extras/inference/laplace_approx/idata.py +++ b/pymc_extras/inference/laplace_approx/idata.py @@ -136,7 +136,10 @@ def map_results_to_inference_data( def add_fit_to_inference_data( - idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None + idata: az.InferenceData, + mu: RaveledVars, + H_inv: np.ndarray | None, + model: pm.Model | None = None, ) -> az.InferenceData: """ Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object. @@ -147,7 +150,7 @@ def add_fit_to_inference_data( An InferenceData object containing the approximated posterior samples. mu: RaveledVars The MAP estimate of the model parameters. - H_inv: np.ndarray + H_inv: np.ndarray, optional The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. model: Model, optional A PyMC model. If None, the model is taken from the current model context. diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index 85e4cddaa..6c2fd84c8 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -389,6 +389,7 @@ def fit_laplace( include_transformed=include_transformed, gradient_backend=gradient_backend, compile_kwargs=compile_kwargs, + compute_covariance=True, **optimizer_kwargs, ) diff --git a/tests/inference/laplace_approx/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py index f1406ca62..08f1ff501 100644 --- a/tests/inference/laplace_approx/test_find_map.py +++ b/tests/inference/laplace_approx/test_find_map.py @@ -133,8 +133,8 @@ def compute_z(x): ], ) @pytest.mark.parametrize( - "backend, gradient_backend, include_transformed", - [("jax", "jax", True), ("jax", "pytensor", False)], + "backend, gradient_backend, include_transformed, compute_covariance", + [("jax", "jax", True, True), ("jax", "pytensor", False, False)], ids=str, ) def test_find_MAP( @@ -145,6 +145,7 @@ def test_find_MAP( backend, gradient_backend: GradientBackend, include_transformed, + compute_covariance, rng, ): pytest.importorskip("jax") @@ -164,6 +165,7 @@ def test_find_MAP( include_transformed=include_transformed, compile_kwargs={"mode": backend.upper()}, maxiter=5, + compute_covariance=compute_covariance, ) assert hasattr(idata, "posterior") @@ -184,6 +186,8 @@ def test_find_MAP( else: assert not hasattr(idata, "unconstrained_posterior") + assert ("covariance_matrix" in idata.fit) == compute_covariance + def test_find_map_outside_model_context(): """ From 3fdd10126af6c66ad3557dc2aede8ad9ea7c1cc6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 16 Sep 2025 09:27:44 -0500 Subject: [PATCH 2/3] Set `compute_covariance` to False by default --- pymc_extras/inference/laplace_approx/find_map.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py index fa4bab888..376ebe8e8 100644 --- a/pymc_extras/inference/laplace_approx/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -198,7 +198,7 @@ def find_MAP( include_transformed: bool = True, gradient_backend: GradientBackend = "pytensor", compile_kwargs: dict | None = None, - compute_covariance: bool = True, + compute_covariance: bool = False, **optimizer_kwargs, ) -> ( dict[str, np.ndarray] @@ -243,7 +243,7 @@ def find_MAP( compute_covariance: bool If True, the inverse Hessian matrix at the optimum will be computed and included in the returned InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for - high-dimensional problems. Defaults to True. + high-dimensional problems. Defaults to False. compile_kwargs: dict, optional Additional options to pass to the ``pytensor.function`` function when compiling loss functions. **optimizer_kwargs From 2cd7fcb52faa30a26521978f9dc676cd838e9a6c Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 16 Sep 2025 21:53:31 -0500 Subject: [PATCH 3/3] Rename `compute_covariance` to `compute_hessian` --- pymc_extras/inference/laplace_approx/find_map.py | 6 +++--- pymc_extras/inference/laplace_approx/laplace.py | 2 +- tests/inference/laplace_approx/test_find_map.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py index 376ebe8e8..73ad741df 100644 --- a/pymc_extras/inference/laplace_approx/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -198,7 +198,7 @@ def find_MAP( include_transformed: bool = True, gradient_backend: GradientBackend = "pytensor", compile_kwargs: dict | None = None, - compute_covariance: bool = False, + compute_hessian: bool = False, **optimizer_kwargs, ) -> ( dict[str, np.ndarray] @@ -240,7 +240,7 @@ def find_MAP( Whether to include transformed variable values in the returned dictionary. Defaults to True. gradient_backend: str, default "pytensor" Which backend to use to compute gradients. Must be one of "pytensor" or "jax". - compute_covariance: bool + compute_hessian: bool If True, the inverse Hessian matrix at the optimum will be computed and included in the returned InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for high-dimensional problems. Defaults to False. @@ -321,7 +321,7 @@ def find_MAP( **optimizer_kwargs, ) - if compute_covariance: + if compute_hessian: H_inv = _compute_inverse_hessian( optimizer_result=optimizer_result, optimal_point=None, diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index 6c2fd84c8..ab5203587 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -389,7 +389,7 @@ def fit_laplace( include_transformed=include_transformed, gradient_backend=gradient_backend, compile_kwargs=compile_kwargs, - compute_covariance=True, + compute_hessian=True, **optimizer_kwargs, ) diff --git a/tests/inference/laplace_approx/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py index 08f1ff501..3529f972d 100644 --- a/tests/inference/laplace_approx/test_find_map.py +++ b/tests/inference/laplace_approx/test_find_map.py @@ -133,7 +133,7 @@ def compute_z(x): ], ) @pytest.mark.parametrize( - "backend, gradient_backend, include_transformed, compute_covariance", + "backend, gradient_backend, include_transformed, compute_hessian", [("jax", "jax", True, True), ("jax", "pytensor", False, False)], ids=str, ) @@ -145,7 +145,7 @@ def test_find_MAP( backend, gradient_backend: GradientBackend, include_transformed, - compute_covariance, + compute_hessian, rng, ): pytest.importorskip("jax") @@ -165,7 +165,7 @@ def test_find_MAP( include_transformed=include_transformed, compile_kwargs={"mode": backend.upper()}, maxiter=5, - compute_covariance=compute_covariance, + compute_hessian=compute_hessian, ) assert hasattr(idata, "posterior") @@ -186,7 +186,7 @@ def test_find_MAP( else: assert not hasattr(idata, "unconstrained_posterior") - assert ("covariance_matrix" in idata.fit) == compute_covariance + assert ("covariance_matrix" in idata.fit) == compute_hessian def test_find_map_outside_model_context():