@@ -425,7 +425,7 @@ def fit_laplace(
425
425
chains : int = 2 ,
426
426
draws : int = 500 ,
427
427
on_bad_cov : Literal ["warn" , "error" , "ignore" ] = "ignore" ,
428
- transform_samples : bool = False ,
428
+ fit_in_unconstrained_space : bool = False ,
429
429
zero_tol : float = 1e-8 ,
430
430
diag_jitter : float | None = 1e-8 ,
431
431
optimizer_kwargs : dict | None = None ,
@@ -464,8 +464,17 @@ def fit_laplace(
464
464
Variables whose initial values should be jittered. If None, all variables are jittered.
465
465
progressbar : bool, optional
466
466
Whether to display a progress bar during optimization. Defaults to True.
467
- include_transformed: bool, optional
468
- Whether to include transformed variable values in the returned dictionary. Defaults to True.
467
+ fit_in_unconstrained_space: bool, default False
468
+ Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn
469
+ from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will
470
+ then be transformed back to the original parameter space. This will guarantee that the samples will respect
471
+ the domain of prior distributions (for exmaple, samples from a Beta distribution will be strictly between 0
472
+ and 1).
473
+
474
+ .. warning::
475
+ This argumnet should be considered highly experimental. It has not been verified if this method produces
476
+ valid draws from the posterior. **Use at your own risk**.
477
+
469
478
gradient_backend: str, default "pytensor"
470
479
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
471
480
chains: int, default: 2
@@ -476,15 +485,17 @@ def fit_laplace(
476
485
What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
477
486
If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
478
487
If 'error', an error will be raised.
479
- transform_samples : bool
480
- Whether to transform the samples back to the original parameter space. Default is True.
481
488
zero_tol: float
482
489
Value below which an element of the Hessian matrix is counted as 0.
483
490
This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
484
491
diag_jitter: float | None
485
492
A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
486
493
If None, no jitter is added. Default is 1e-8.
487
- compile_kwargs: optional
494
+ optimizer_kwargs: dict, optional
495
+ Additional keyword arguments to pass to scipy.minimize. See the documentation for scipy.optimize.minimize for
496
+ details. Arguments that are typically passed via ``options`` will be automatically extracted without the need
497
+ to use a nested dictionary.
498
+ compile_kwargs: dict, optional
488
499
Additional keyword arguments to pass to pytensor.function.
489
500
490
501
Returns
@@ -540,7 +551,7 @@ def fit_laplace(
540
551
optimized_point = optimized_point ,
541
552
model = model ,
542
553
on_bad_cov = on_bad_cov ,
543
- transform_samples = transform_samples ,
554
+ transform_samples = fit_in_unconstrained_space ,
544
555
zero_tol = zero_tol ,
545
556
diag_jitter = diag_jitter ,
546
557
compile_kwargs = compile_kwargs ,
@@ -552,7 +563,7 @@ def fit_laplace(
552
563
model = model ,
553
564
chains = chains ,
554
565
draws = draws ,
555
- transform_samples = transform_samples ,
566
+ transform_samples = fit_in_unconstrained_space ,
556
567
progressbar = progressbar ,
557
568
random_seed = random_seed ,
558
569
compile_kwargs = compile_kwargs ,
0 commit comments