@@ -34,7 +34,7 @@
 
 from pymc_extras.inference.laplace_approx.find_map import (
     _compute_inverse_hessian,
-    _make_inital_point,
+    _make_initial_point,
     find_MAP,
 )
 from pymc_extras.inference.laplace_approx.scipy_interface import scipy_optimize_funcs_from_loss
@@ -228,7 +228,6 @@ def fit_laplace(
     use_hess: bool | None = None,
     initvals: dict | None = None,
     random_seed: int | np.random.Generator | None = None,
-    return_raw: bool = False,
     jitter_rvs: list[pt.TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
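
Note on the signature change above: with `return_raw` removed, `fit_laplace` no longer hands back the raw `scipy.optimize.minimize` result. A minimal call under the new signature might look like the sketch below (the toy model is hypothetical, not from this PR):

```python
import pymc as pm
from pymc_extras import fit_laplace

# A toy model purely for illustration.
with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu, sigma, observed=[1.0, 2.0, 3.0])

    # return_raw is no longer an accepted keyword here.
    idata = fit_laplace(random_seed=42, progressbar=False)
```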
@@ -268,23 +267,13 @@ def fit_laplace(
         If None, the model's default initial values are used.
     random_seed : None | int | np.random.Generator, optional
         Seed for the random number generator or a numpy Generator for reproducibility
-    return_raw: bool | False, optinal
-        Whether to also return the full output of `scipy.optimize.minimize`
     jitter_rvs : list of TensorVariables, optional
         Variables whose initial values should be jittered. If None, all variables are jittered.
     progressbar : bool, optional
         Whether to display a progress bar during optimization. Defaults to True.
-    fit_in_unconstrained_space: bool, default False
-        Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn
-        from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will
-        then be transformed back to the original parameter space. This will guarantee that the samples will respect
-        the domain of prior distributions (for exmaple, samples from a Beta distribution will be strictly between 0
-        and 1).
-
-        .. warning::
-            This argument should be considered highly experimental. It has not been verified if this method produces
-            valid draws from the posterior. **Use at your own risk**.
-
+    include_transformed: bool, default True
+        Whether to include transformed variables in the output. If True, transformed variables will be included in the
+        output InferenceData object. If False, only the original variables will be included.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
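
The new `include_transformed` entry documents an existing keyword rather than adding one. A rough sketch of what it controls, reusing the hypothetical model above (the exact variable names follow PyMC's transform naming convention, e.g. `sigma_log__`):

```python
# include_transformed=True: the returned InferenceData also carries the
# unconstrained (transformed) variables, e.g. a log-transformed sigma.
idata_full = fit_laplace(include_transformed=True, progressbar=False)

# include_transformed=False: only the original model variables are kept.
idata_slim = fit_laplace(include_transformed=False, progressbar=False)
```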
@@ -365,7 +354,7 @@ def fit_laplace(
         # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
         # we have to go back and compute the Hessian at the MAP point now.
         frozen_model = freeze_dims_and_data(model)
-        initial_params = _make_inital_point(frozen_model, initvals, random_seed, jitter_rvs)
+        initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
 
         _, f_hessp = scipy_optimize_funcs_from_loss(
             loss=-frozen_model.logp(jacobian=False),
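
For context on the hunk above: when the optimizer did not already supply an (inverse) Hessian, the code rebuilds one at the MAP point, since the Laplace approximation treats the posterior as N(theta_MAP, H^-1) with H the Hessian of the negative log posterior. A self-contained toy illustration of that idea using plain scipy (not this module's actual code path):

```python
import numpy as np
from scipy import optimize

# Toy negative log posterior: an anisotropic 2D Gaussian.
H_true = np.array([[2.0, 0.0], [0.0, 0.5]])
neg_logp = lambda x: 0.5 * x @ H_true @ x

res = optimize.minimize(neg_logp, x0=np.array([1.0, -1.0]), method="BFGS")
mu = res.x          # MAP estimate (here ~ [0, 0])
cov = res.hess_inv  # BFGS's approximate inverse Hessian = Laplace covariance

# Draws from the Laplace approximation N(mu, cov).
draws = np.random.default_rng(0).multivariate_normal(mu, cov, size=1_000)
```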
@@ -405,9 +394,9 @@ def fit_laplace(
             .rename({"temp_chain": "chain", "temp_draw": "draw"})
         )
 
-        new_posterior.update(unstack_laplace_draws(new_posterior, model)).drop_vars(
-            "laplace_approximation"
-        )
+        new_posterior.update(unstack_laplace_draws(new_posterior, model))
+        new_posterior = new_posterior.drop_vars("laplace_approximation")
+
         idata.posterior.update(new_posterior)
 
     return idata
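
The last hunk fixes a real bug, not just style: `xarray.Dataset.drop_vars` returns a new dataset rather than mutating in place, so the old chained call computed the drop and then discarded it. A minimal reproduction of the pitfall (toy variable names):

```python
import xarray as xr

ds = xr.Dataset({"keep": ("x", [1, 2]), "laplace_approximation": ("x", [3, 4])})

ds.drop_vars("laplace_approximation")       # result discarded; ds still has the var
assert "laplace_approximation" in ds

ds = ds.drop_vars("laplace_approximation")  # rebinding makes the drop stick
assert "laplace_approximation" not in ds
```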