1 | 1 | import logging |
2 | 2 |
3 | 3 | from collections.abc import Callable |
4 | | -from typing import Literal, cast |
| 4 | +from typing import cast |
5 | 5 |
6 | | -import arviz as az |
7 | 6 | import jax |
8 | 7 | import numpy as np |
9 | 8 | import pymc as pm |
10 | 9 | import pytensor |
11 | 10 | import pytensor.tensor as pt |
12 | | -import xarray as xr |
13 | 11 |
14 | | -from arviz import dict_to_dataset |
15 | 12 | from better_optimize import minimize |
16 | | -from better_optimize.constants import minimize_method |
17 | | -from pymc.backends.arviz import ( |
18 | | - coords_and_dims_for_inferencedata, |
19 | | - find_constants, |
20 | | - find_observations, |
21 | | -) |
| 13 | +from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method |
22 | 14 | from pymc.blocking import DictToArrayBijection, RaveledVars |
23 | 15 | from pymc.initial_point import make_initial_point_fn |
24 | | -from pymc.model.transform.conditioning import remove_value_transforms |
25 | 16 | from pymc.model.transform.optimization import freeze_dims_and_data |
26 | 17 | from pymc.pytensorf import join_nonshared_inputs |
27 | 18 | from pymc.util import get_default_varnames |
28 | 19 | from pytensor.compile import Function |
29 | 20 | from pytensor.tensor import TensorVariable |
30 | | -from scipy import stats |
31 | 21 | from scipy.optimize import OptimizeResult |
32 | 22 |
33 | 23 | _log = logging.getLogger(__name__) |
34 | 24 |
35 | 25 |
| 26 | +def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp): |
| 27 | + method_info = MINIMIZE_MODE_KWARGS[method].copy() |
| 28 | + |
| 29 | + use_grad = use_grad if use_grad is not None else method_info["uses_grad"] |
| 30 | + use_hess = use_hess if use_hess is not None else method_info["uses_hess"] |
| 31 | + use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"] |
| 32 | + |
| 33 | + if use_hess and use_hessp: |
| 34 | + use_hess = False |
| 35 | + |
| 36 | + return use_grad, use_hess, use_hessp |
| 37 | + |
| 38 | + |
36 | 39 | def get_nearest_psd(A: np.ndarray) -> np.ndarray: |
37 | 40 | """ |
38 | 41 | Compute the nearest positive semi-definite matrix to a given matrix. |
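
Editor's note: a minimal usage sketch for the set_optimizer_function_defaults helper introduced in the hunk above. It is hedged: it assumes MINIMIZE_MODE_KWARGS maps each method name to booleans under "uses_grad", "uses_hess", and "uses_hessp" (as the code implies), and the method name "trust-ncg" is illustrative only, not taken from this diff.

    # Flags passed as None fall back to the capabilities recorded for the
    # chosen method; explicit True/False values are kept as given.
    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
        "trust-ncg", use_grad=None, use_hess=None, use_hessp=True
    )
    # Because use_hessp is enabled, the tie-break at the end of the helper
    # guarantees use_hess comes back False: the full Hessian and the
    # Hessian-vector product are never both compiled.
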
@@ -60,7 +63,9 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray: |
60 | 63 |
61 | 64 | def _unconstrained_vector_to_constrained_rvs(model): |
62 | 65 | constrained_rvs, unconstrained_vector = join_nonshared_inputs( |
63 | | - model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars |
| 66 | + model.initial_point(), |
| 67 | + inputs=model.value_vars, |
| 68 | + outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False), |
64 | 69 | ) |
65 | 70 |
66 | 71 | unconstrained_vector.name = "unconstrained_vector" |
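
Editor's note: the hunk above narrows the compiled outputs to untransformed variables. A hedged illustration of the filtering (the model and variable names are hypothetical; get_default_varnames is the pymc.util helper imported in this module, and the internal "sigma_log__" naming is the usual pymc convention):

    import pymc as pm
    from pymc.util import get_default_varnames

    with pm.Model() as m:
        sigma = pm.HalfNormal("sigma")  # sampled internally on the log scale

    # include_transformed=False drops internal value variables such as
    # "sigma_log__", keeping only user-facing variables like "sigma".
    names = [
        v.name
        for v in get_default_varnames(m.unobserved_value_vars, include_transformed=False)
    ]
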
@@ -289,247 +294,6 @@ def scipy_optimize_funcs_from_loss( |
289 | 294 | return f_loss, f_hess, f_hessp |
290 | 295 |
291 | 296 |
292 | | -def fit_mvn_to_MAP( |
293 | | - optimized_point: dict[str, np.ndarray], |
294 | | - model: pm.Model, |
295 | | - on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", |
296 | | - transform_samples: bool = True, |
297 | | - use_jax_gradients: bool = False, |
298 | | - zero_tol: float = 1e-8, |
299 | | - diag_jitter: float | None = 1e-8, |
300 | | - compile_kwargs: dict | None = None, |
301 | | -) -> tuple[RaveledVars, np.ndarray]: |
302 | | - """ |
303 | | - Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior |
304 | | - evaluated at the MAP estimate. This is the basis of the Laplace approximation. |
305 | | -
306 | | - Parameters |
307 | | - ---------- |
308 | | - optimized_point : dict[str, np.ndarray] |
309 | | - Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map |
310 | | - model : Model |
311 | | - A PyMC model |
312 | | - on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' |
313 | | - What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. |
314 | | - If 'ignore' or 'warn', the closest positive semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. |
315 | | - If 'error', an error will be raised. |
316 | | - transform_samples : bool |
317 | | - Whether to transform the samples back to the original parameter space. Default is True. |
318 | | - zero_tol : float |
319 | | - Value below which an element of the Hessian matrix is counted as 0. |
320 | | - This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. |
321 | | - diag_jitter : float | None |
322 | | - A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. |
323 | | - If None, no jitter is added. Default is 1e-8. |
324 | | -
325 | | - Returns |
326 | | - ------- |
327 | | - map_estimate : RaveledVars |
328 | | - The MAP estimate of the model parameters, raveled into a 1D array. |
329 | | -
330 | | - inverse_hessian : np.ndarray |
331 | | - The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. |
332 | | - """ |
333 | | - compile_kwargs = {} if compile_kwargs is None else compile_kwargs |
334 | | - frozen_model = freeze_dims_and_data(model) |
335 | | - |
336 | | - if not transform_samples: |
337 | | - untransformed_model = remove_value_transforms(frozen_model) |
338 | | - logp = untransformed_model.logp(jacobian=False) |
339 | | - variables = untransformed_model.continuous_value_vars |
340 | | - else: |
341 | | - logp = frozen_model.logp(jacobian=True) |
342 | | - variables = frozen_model.continuous_value_vars |
343 | | - |
344 | | - variable_names = {var.name for var in variables} |
345 | | - optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} |
346 | | - mu = DictToArrayBijection.map(optimized_free_params) |
347 | | - |
348 | | - _, f_hess, _ = scipy_optimize_funcs_from_loss( |
349 | | - loss=-logp, |
350 | | - inputs=variables, |
351 | | - initial_point_dict=frozen_model.initial_point(), |
352 | | - use_grad=True, |
353 | | - use_hess=True, |
354 | | - use_hessp=False, |
355 | | - use_jax_gradients=use_jax_gradients, |
356 | | - compile_kwargs=compile_kwargs, |
357 | | - ) |
358 | | - |
359 | | - H = -f_hess(mu.data) |
360 | | - H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) |
361 | | - |
362 | | - def stabilize(x, jitter): |
363 | | - return x + np.eye(x.shape[0]) * jitter |
364 | | - |
365 | | - H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter) |
366 | | - |
367 | | - try: |
368 | | - np.linalg.cholesky(H_inv) |
369 | | - except np.linalg.LinAlgError: |
370 | | - if on_bad_cov == "error": |
371 | | - raise np.linalg.LinAlgError( |
372 | | - "Inverse Hessian not positive semi-definite at the provided point" |
373 | | - ) |
374 | | - H_inv = get_nearest_psd(H_inv) |
375 | | - if on_bad_cov == "warn": |
376 | | - _log.warning( |
377 | | - "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD " |
378 | | - "matrix in L1-norm instead" |
379 | | - ) |
380 | | - |
381 | | - return mu, H_inv |
382 | | - |
383 | | - |
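
Editor's note: fit_mvn_to_MAP (removed above) computed the mean and covariance of the Laplace approximation; the laplace function below sampled from it. In the notation of the removed docstrings, the pair implemented the standard approximation

\[
p(\theta \mid y) \;\approx\; \mathcal{N}\!\left(\hat{\theta}_{\mathrm{MAP}},\, H^{-1}\right),
\qquad
H \;=\; -\,\nabla^{2}_{\theta} \log p(\theta \mid y)\,\Big|_{\theta=\hat{\theta}_{\mathrm{MAP}}}.
\]
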
384 | | -def laplace( |
385 | | - mu: RaveledVars, |
386 | | - H_inv: np.ndarray, |
387 | | - model: pm.Model, |
388 | | - chains: int = 2, |
389 | | - draws: int = 500, |
390 | | - transform_samples: bool = True, |
391 | | - progressbar: bool = True, |
392 | | - **compile_kwargs, |
393 | | -) -> az.InferenceData: |
394 | | - """ |
395 | | - Draw samples from the Laplace (multivariate normal) approximation to the posterior and return them as an InferenceData object. |
396 | | - Parameters |
397 | | - ---------- |
398 | | - mu : RaveledVars |
399 | | - H_inv : np.ndarray |
400 | | - model : Model |
401 | | - A PyMC model |
402 | | - chains : int |
403 | | - The number of sampling chains running in parallel. Default is 2. |
404 | | - draws : int |
405 | | - The number of samples to draw from the approximated posterior. Default is 500. |
406 | | - transform_samples : bool |
407 | | - Whether to transform the samples back to the original parameter space. Default is True. |
408 | | -
409 | | - Returns |
410 | | - ------- |
411 | | - idata : az.InferenceData |
412 | | - An InferenceData object containing the approximated posterior samples. |
413 | | - """ |
414 | | - posterior_dist = stats.multivariate_normal(mean=mu.data, cov=H_inv, allow_singular=True) |
415 | | - posterior_draws = posterior_dist.rvs(size=(chains, draws)) |
416 | | - |
417 | | - if transform_samples: |
418 | | - constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) |
419 | | - batched_values = pt.tensor( |
420 | | - "batched_values", |
421 | | - shape=(chains, draws, *unconstrained_vector.type.shape), |
422 | | - dtype=unconstrained_vector.type.dtype, |
423 | | - ) |
424 | | - batched_rvs = pytensor.graph.vectorize_graph( |
425 | | - constrained_rvs, replace={unconstrained_vector: batched_values} |
426 | | - ) |
427 | | - |
428 | | - f_constrain = pm.compile_pymc( |
429 | | - inputs=[batched_values], outputs=batched_rvs, **compile_kwargs |
430 | | - ) |
431 | | - posterior_draws = f_constrain(posterior_draws) |
432 | | - |
433 | | - else: |
434 | | - info = mu.point_map_info |
435 | | - flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info] |
436 | | - slices = [ |
437 | | - slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes)) |
438 | | - ] |
439 | | - |
440 | | - posterior_draws = [ |
441 | | - posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype) |
442 | | - for idx, (name, shape, dtype) in zip(slices, info) |
443 | | - ] |
444 | | - |
445 | | - def make_rv_coords(name): |
446 | | - coords = {"chain": range(chains), "draw": range(draws)} |
447 | | - extra_dims = model.named_vars_to_dims.get(name) |
448 | | - if extra_dims is None: |
449 | | - return coords |
450 | | - return coords | {dim: list(model.coords[dim]) for dim in extra_dims} |
451 | | - |
452 | | - def make_rv_dims(name): |
453 | | - dims = ["chain", "draw"] |
454 | | - extra_dims = model.named_vars_to_dims.get(name) |
455 | | - if extra_dims is None: |
456 | | - return dims |
457 | | - return dims + list(extra_dims) |
458 | | - |
459 | | - idata = { |
460 | | - name: xr.DataArray( |
461 | | - data=draws.squeeze(), |
462 | | - coords=make_rv_coords(name), |
463 | | - dims=make_rv_dims(name), |
464 | | - name=name, |
465 | | - ) |
466 | | - for (name, _, _), draws in zip(mu.point_map_info, posterior_draws) |
467 | | - } |
468 | | - |
469 | | - coords, dims = coords_and_dims_for_inferencedata(model) |
470 | | - idata = az.convert_to_inference_data(idata, coords=coords, dims=dims) |
471 | | - |
472 | | - if model.deterministics: |
473 | | - idata.posterior = pm.compute_deterministics( |
474 | | - idata.posterior, |
475 | | - model=model, |
476 | | - merge_dataset=True, |
477 | | - progressbar=progressbar, |
478 | | - compile_kwargs=compile_kwargs, |
479 | | - ) |
480 | | - |
481 | | - observed_data = dict_to_dataset( |
482 | | - find_observations(model), |
483 | | - library=pm, |
484 | | - coords=coords, |
485 | | - dims=dims, |
486 | | - default_dims=[], |
487 | | - ) |
488 | | - |
489 | | - constant_data = dict_to_dataset( |
490 | | - find_constants(model), |
491 | | - library=pm, |
492 | | - coords=coords, |
493 | | - dims=dims, |
494 | | - default_dims=[], |
495 | | - ) |
496 | | - |
497 | | - idata.add_groups( |
498 | | - {"observed_data": observed_data, "constant_data": constant_data}, |
499 | | - coords=coords, |
500 | | - dims=dims, |
501 | | - ) |
502 | | - |
503 | | - return idata |
504 | | - |
505 | | - |
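
Editor's note: the removed laplace function used pytensor.graph.vectorize_graph to constrain all (chains, draws) samples in one compiled call. A self-contained sketch of that technique (the shapes, names, and softmax-like transform here are illustrative, not from this diff):

    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(3,))
    y = pt.exp(x) / pt.exp(x).sum()  # stand-in for a constraining transform

    # Replace the input with a batched tensor; the graph is mapped over the
    # extra leading dimensions automatically, so one compiled function
    # handles every (chain, draw) sample at once.
    xb = pt.tensor("xb", shape=(None, None, 3))
    yb = pytensor.graph.vectorize_graph(y, replace={x: xb})
    f = pytensor.function([xb], yb)
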
506 | | -def fit_laplace( |
507 | | - optimized_point: dict[str, np.ndarray], |
508 | | - model: pm.Model, |
509 | | - chains: int = 2, |
510 | | - draws: int = 500, |
511 | | - on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", |
512 | | - transform_samples: bool = True, |
513 | | - zero_tol: float = 1e-8, |
514 | | - diag_jitter: float | None = 1e-8, |
515 | | - progressbar: bool = True, |
516 | | - compile_kwargs: dict | None = None, |
517 | | -) -> az.InferenceData: |
518 | | - compile_kwargs = {} if compile_kwargs is None else compile_kwargs |
519 | | - |
520 | | - mu, H_inv = fit_mvn_to_MAP( |
521 | | - optimized_point=optimized_point, |
522 | | - model=model, |
523 | | - on_bad_cov=on_bad_cov, |
524 | | - transform_samples=transform_samples, |
525 | | - zero_tol=zero_tol, |
526 | | - diag_jitter=diag_jitter, |
527 | | - compile_kwargs=compile_kwargs, |
528 | | - ) |
529 | | - |
530 | | - return laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar) |
531 | | - |
532 | | - |
533 | 297 | def find_MAP( |
534 | 298 | method: minimize_method, |
535 | 299 | *, |
@@ -605,9 +369,12 @@ def find_MAP( |
605 | 369 | initial_params = DictToArrayBijection.map( |
606 | 370 | {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} |
607 | 371 | ) |
| 372 | + use_grad, use_hess, use_hessp = set_optimizer_function_defaults( |
| 373 | + method, use_grad, use_hess, use_hessp |
| 374 | + ) |
608 | 375 |
609 | 376 | f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss( |
610 | | - loss=-frozen_model.logp(), |
| 377 | + loss=-frozen_model.logp(jacobian=False), |
611 | 378 | inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, |
612 | 379 | initial_point_dict=start_dict, |
613 | 380 | use_grad=use_grad, |
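
Editor's note: a hedged end-to-end sketch of how the revised find_MAP resolves its derivative flags. Only method, the three flags, and the start_dict handling are visible in this diff; calling find_MAP from inside a model context, and the toy model and data below, are assumptions for illustration.

    import numpy as np
    import pymc as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu, 1.0, observed=np.array([0.1, -0.3, 0.2]))

        # Flags left as None are filled in per-method by
        # set_optimizer_function_defaults before the loss is compiled.
        result = find_MAP(
            method="L-BFGS-B", use_grad=None, use_hess=None, use_hessp=None
        )
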