Skip to content

Commit a81079b

Browse files
Improve docstring for fit_laplace
1 parent f2504e9 commit a81079b

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

pymc_experimental/inference/laplace.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def fit_laplace(
425425
chains: int = 2,
426426
draws: int = 500,
427427
on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
428-
transform_samples: bool = False,
428+
fit_in_unconstrained_space: bool = False,
429429
zero_tol: float = 1e-8,
430430
diag_jitter: float | None = 1e-8,
431431
optimizer_kwargs: dict | None = None,
@@ -464,8 +464,17 @@ def fit_laplace(
464464
Variables whose initial values should be jittered. If None, all variables are jittered.
465465
progressbar : bool, optional
466466
Whether to display a progress bar during optimization. Defaults to True.
467-
include_transformed: bool, optional
468-
Whether to include transformed variable values in the returned dictionary. Defaults to True.
467+
fit_in_unconstrained_space: bool, default False
468+
Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn
469+
from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will
470+
then be transformed back to the original parameter space. This will guarantee that the samples will respect
471+
the domain of prior distributions (for exmaple, samples from a Beta distribution will be strictly between 0
472+
and 1).
473+
474+
.. warning::
475+
This argumnet should be considered highly experimental. It has not been verified if this method produces
476+
valid draws from the posterior. **Use at your own risk**.
477+
469478
gradient_backend: str, default "pytensor"
470479
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
471480
chains: int, default: 2
@@ -476,15 +485,17 @@ def fit_laplace(
476485
What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
477486
If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
478487
If 'error', an error will be raised.
479-
transform_samples : bool
480-
Whether to transform the samples back to the original parameter space. Default is True.
481488
zero_tol: float
482489
Value below which an element of the Hessian matrix is counted as 0.
483490
This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
484491
diag_jitter: float | None
485492
A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
486493
If None, no jitter is added. Default is 1e-8.
487-
compile_kwargs: optional
494+
optimizer_kwargs: dict, optional
495+
Additional keyword arguments to pass to scipy.minimize. See the documentation for scipy.optimize.minimize for
496+
details. Arguments that are typically passed via ``options`` will be automatically extracted without the need
497+
to use a nested dictionary.
498+
compile_kwargs: dict, optional
488499
Additional keyword arguments to pass to pytensor.function.
489500
490501
Returns
@@ -540,7 +551,7 @@ def fit_laplace(
540551
optimized_point=optimized_point,
541552
model=model,
542553
on_bad_cov=on_bad_cov,
543-
transform_samples=transform_samples,
554+
transform_samples=fit_in_unconstrained_space,
544555
zero_tol=zero_tol,
545556
diag_jitter=diag_jitter,
546557
compile_kwargs=compile_kwargs,
@@ -552,7 +563,7 @@ def fit_laplace(
552563
model=model,
553564
chains=chains,
554565
draws=draws,
555-
transform_samples=transform_samples,
566+
transform_samples=fit_in_unconstrained_space,
556567
progressbar=progressbar,
557568
random_seed=random_seed,
558569
compile_kwargs=compile_kwargs,

0 commit comments

Comments
 (0)