Skip to content

Commit 2cd7fcb

Browse files
Rename compute_covariance to compute_hessian
1 parent 3fdd101 commit 2cd7fcb

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def find_MAP(
198198
include_transformed: bool = True,
199199
gradient_backend: GradientBackend = "pytensor",
200200
compile_kwargs: dict | None = None,
201-
compute_covariance: bool = False,
201+
compute_hessian: bool = False,
202202
**optimizer_kwargs,
203203
) -> (
204204
dict[str, np.ndarray]
@@ -240,7 +240,7 @@ def find_MAP(
240240
Whether to include transformed variable values in the returned dictionary. Defaults to True.
241241
gradient_backend: str, default "pytensor"
242242
Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
243-
compute_covariance: bool
243+
compute_hessian: bool
244244
If True, the inverse Hessian matrix at the optimum will be computed and included in the returned
245245
InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for
246246
high-dimensional problems. Defaults to False.
@@ -321,7 +321,7 @@ def find_MAP(
321321
**optimizer_kwargs,
322322
)
323323

324-
if compute_covariance:
324+
if compute_hessian:
325325
H_inv = _compute_inverse_hessian(
326326
optimizer_result=optimizer_result,
327327
optimal_point=None,

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def fit_laplace(
389389
include_transformed=include_transformed,
390390
gradient_backend=gradient_backend,
391391
compile_kwargs=compile_kwargs,
392-
compute_covariance=True,
392+
compute_hessian=True,
393393
**optimizer_kwargs,
394394
)
395395

tests/inference/laplace_approx/test_find_map.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def compute_z(x):
133133
],
134134
)
135135
@pytest.mark.parametrize(
136-
"backend, gradient_backend, include_transformed, compute_covariance",
136+
"backend, gradient_backend, include_transformed, compute_hessian",
137137
[("jax", "jax", True, True), ("jax", "pytensor", False, False)],
138138
ids=str,
139139
)
@@ -145,7 +145,7 @@ def test_find_MAP(
145145
backend,
146146
gradient_backend: GradientBackend,
147147
include_transformed,
148-
compute_covariance,
148+
compute_hessian,
149149
rng,
150150
):
151151
pytest.importorskip("jax")
@@ -165,7 +165,7 @@ def test_find_MAP(
165165
include_transformed=include_transformed,
166166
compile_kwargs={"mode": backend.upper()},
167167
maxiter=5,
168-
compute_covariance=compute_covariance,
168+
compute_hessian=compute_hessian,
169169
)
170170

171171
assert hasattr(idata, "posterior")
@@ -186,7 +186,7 @@ def test_find_MAP(
186186
else:
187187
assert not hasattr(idata, "unconstrained_posterior")
188188

189-
assert ("covariance_matrix" in idata.fit) == compute_covariance
189+
assert ("covariance_matrix" in idata.fit) == compute_hessian
190190

191191

192192
def test_find_map_outside_model_context():

0 commit comments

Comments
 (0)