@@ -1,38 +1,41 @@
 import logging
 
 from collections.abc import Callable
-from typing import Literal, cast
+from typing import cast
 
-import arviz as az
 import jax
 import numpy as np
 import pymc as pm
 import pytensor
 import pytensor.tensor as pt
-import xarray as xr
 
-from arviz import dict_to_dataset
 from better_optimize import minimize
-from better_optimize.constants import minimize_method
-from pymc.backends.arviz import (
-    coords_and_dims_for_inferencedata,
-    find_constants,
-    find_observations,
-)
+from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
-from pymc.model.transform.conditioning import remove_value_transforms
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.pytensorf import join_nonshared_inputs
 from pymc.util import get_default_varnames
 from pytensor.compile import Function
 from pytensor.tensor import TensorVariable
-from scipy import stats
 from scipy.optimize import OptimizeResult
 
 _log = logging.getLogger(__name__)
 
 
+def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
+    method_info = MINIMIZE_MODE_KWARGS[method].copy()
+
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
+    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
+
+    if use_hess and use_hessp:
+        use_hess = False
+
+    return use_grad, use_hess, use_hessp
+
+
 def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     """
     Compute the nearest positive semi-definite matrix to a given matrix.
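
The new `set_optimizer_function_defaults` helper fills in any of `use_grad`/`use_hess`/`use_hessp` that the caller left as `None`, using the per-method capability table from `better_optimize`, and breaks a hess/hessp tie in favor of the Hessian-vector product. A minimal sketch of that resolution logic; the `MINIMIZE_MODE_KWARGS` entries below are illustrative stand-ins, not the library's actual table:

```python
# Illustrative stand-in for better_optimize.constants.MINIMIZE_MODE_KWARGS.
MINIMIZE_MODE_KWARGS = {
    "BFGS": {"uses_grad": True, "uses_hess": False, "uses_hessp": False},
    "trust-ncg": {"uses_grad": True, "uses_hess": True, "uses_hessp": True},
}


def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
    method_info = MINIMIZE_MODE_KWARGS[method].copy()

    # Only fill in flags the caller did not set explicitly.
    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

    # If both are requested, prefer the Hessian-vector product: it never
    # materializes the dense Hessian, so it scales better with dimension.
    if use_hess and use_hessp:
        use_hess = False

    return use_grad, use_hess, use_hessp


print(set_optimizer_function_defaults("trust-ncg", None, None, None))
# (True, False, True) -- hessp wins the tie-break
```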
@@ -60,7 +63,9 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
 
 def _unconstrained_vector_to_constrained_rvs(model):
     constrained_rvs, unconstrained_vector = join_nonshared_inputs(
-        model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars
+        model.initial_point(),
+        inputs=model.value_vars,
+        outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False),
     )
 
     unconstrained_vector.name = "unconstrained_vector"
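
Routing the outputs through `get_default_varnames(..., include_transformed=False)` means the compiled graph only returns constrained-space variables, dropping the automatically generated transformed value variables. A hedged sketch of what that filter does; the model and variable names are hypothetical:

```python
import pymc as pm
from pymc.util import get_default_varnames

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma")  # gets a log-transformed value variable
    mu = pm.Normal("mu")

kept = get_default_varnames(model.unobserved_value_vars, include_transformed=False)
print([v.name for v in kept])
# expected to keep 'mu' and 'sigma' while dropping 'sigma_log__'
```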
@@ -289,247 +294,6 @@ def scipy_optimize_funcs_from_loss(
     return f_loss, f_hess, f_hessp
 
 
-def fit_mvn_to_MAP(
-    optimized_point: dict[str, np.ndarray],
-    model: pm.Model,
-    on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
-    transform_samples: bool = True,
-    use_jax_gradients: bool = False,
-    zero_tol: float = 1e-8,
-    diag_jitter: float | None = 1e-8,
-    compile_kwargs: dict | None = None,
-) -> tuple[RaveledVars, np.ndarray]:
-    """
-    Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior
-    evaluated at the MAP estimate. This is the basis of the Laplace approximation.
-
-    Parameters
-    ----------
-    optimized_point : dict[str, np.ndarray]
-        Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
-    model : Model
-        A PyMC model
-    on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
-        What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
-        If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
-        If 'error', an error will be raised.
-    transform_samples : bool
-        Whether to transform the samples back to the original parameter space. Default is True.
-    zero_tol: float
-        Value below which an element of the Hessian matrix is counted as 0.
-        This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
-    diag_jitter: float | None
-        A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
-        If None, no jitter is added. Default is 1e-8.
-
-    Returns
-    -------
-    map_estimate: RaveledVars
-        The MAP estimate of the model parameters, raveled into a 1D array.
-
-    inverse_hessian: np.ndarray
-        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
-    """
-    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-    frozen_model = freeze_dims_and_data(model)
-
-    if not transform_samples:
-        untransformed_model = remove_value_transforms(frozen_model)
-        logp = untransformed_model.logp(jacobian=False)
-        variables = untransformed_model.continuous_value_vars
-    else:
-        logp = frozen_model.logp(jacobian=True)
-        variables = frozen_model.continuous_value_vars
-
-    variable_names = {var.name for var in variables}
-    optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names}
-    mu = DictToArrayBijection.map(optimized_free_params)
-
-    _, f_hess, _ = scipy_optimize_funcs_from_loss(
-        loss=-logp,
-        inputs=variables,
-        initial_point_dict=frozen_model.initial_point(),
-        use_grad=True,
-        use_hess=True,
-        use_hessp=False,
-        use_jax_gradients=use_jax_gradients,
-        compile_kwargs=compile_kwargs,
-    )
-
-    H = -f_hess(mu.data)
-    H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))
-
-    def stabilize(x, jitter):
-        return x + np.eye(x.shape[0]) * jitter
-
-    H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter)
-
-    try:
-        np.linalg.cholesky(H_inv)
-    except np.linalg.LinAlgError:
-        if on_bad_cov == "error":
-            raise np.linalg.LinAlgError(
-                "Inverse Hessian not positive-semi definite at the provided point"
-            )
-        H_inv = get_nearest_psd(H_inv)
-        if on_bad_cov == "warn":
-            _log.warning(
-                "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
-                "matrix in L1-norm instead"
-            )
-
-    return mu, H_inv
-
-
-def laplace(
-    mu: RaveledVars,
-    H_inv: np.ndarray,
-    model: pm.Model,
-    chains: int = 2,
-    draws: int = 500,
-    transform_samples: bool = True,
-    progressbar: bool = True,
-    **compile_kwargs,
-) -> az.InferenceData:
-    """
-
-    Parameters
-    ----------
-    mu
-    H_inv
-    model : Model
-        A PyMC model
-    chains : int
-        The number of sampling chains running in parallel. Default is 2.
-    draws : int
-        The number of samples to draw from the approximated posterior. Default is 500.
-    transform_samples : bool
-        Whether to transform the samples back to the original parameter space. Default is True.
-
-    Returns
-    -------
-    idata: az.InferenceData
-        An InferenceData object containing the approximated posterior samples.
-    """
-    posterior_dist = stats.multivariate_normal(mean=mu.data, cov=H_inv, allow_singular=True)
-    posterior_draws = posterior_dist.rvs(size=(chains, draws))
-
-    if transform_samples:
-        constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
-        batched_values = pt.tensor(
-            "batched_values",
-            shape=(chains, draws, *unconstrained_vector.type.shape),
-            dtype=unconstrained_vector.type.dtype,
-        )
-        batched_rvs = pytensor.graph.vectorize_graph(
-            constrained_rvs, replace={unconstrained_vector: batched_values}
-        )
-
-        f_constrain = pm.compile_pymc(
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
-        posterior_draws = f_constrain(posterior_draws)
-
-    else:
-        info = mu.point_map_info
-        flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
-        slices = [
-            slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
-        ]
-
-        posterior_draws = [
-            posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
-            for idx, (name, shape, dtype) in zip(slices, info)
-        ]
-
-    def make_rv_coords(name):
-        coords = {"chain": range(chains), "draw": range(draws)}
-        extra_dims = model.named_vars_to_dims.get(name)
-        if extra_dims is None:
-            return coords
-        return coords | {dim: list(model.coords[dim]) for dim in extra_dims}
-
-    def make_rv_dims(name):
-        dims = ["chain", "draw"]
-        extra_dims = model.named_vars_to_dims.get(name)
-        if extra_dims is None:
-            return dims
-        return dims + list(extra_dims)
-
-    idata = {
-        name: xr.DataArray(
-            data=draws.squeeze(),
-            coords=make_rv_coords(name),
-            dims=make_rv_dims(name),
-            name=name,
-        )
-        for (name, _, _), draws in zip(mu.point_map_info, posterior_draws)
-    }
-
-    coords, dims = coords_and_dims_for_inferencedata(model)
-    idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)
-
-    if model.deterministics:
-        idata.posterior = pm.compute_deterministics(
-            idata.posterior,
-            model=model,
-            merge_dataset=True,
-            progressbar=progressbar,
-            compile_kwargs=compile_kwargs,
-        )
-
-    observed_data = dict_to_dataset(
-        find_observations(model),
-        library=pm,
-        coords=coords,
-        dims=dims,
-        default_dims=[],
-    )
-
-    constant_data = dict_to_dataset(
-        find_constants(model),
-        library=pm,
-        coords=coords,
-        dims=dims,
-        default_dims=[],
-    )
-
-    idata.add_groups(
-        {"observed_data": observed_data, "constant_data": constant_data},
-        coords=coords,
-        dims=dims,
-    )
-
-    return idata
-
-
-def fit_laplace(
-    optimized_point: dict[str, np.ndarray],
-    model: pm.Model,
-    chains: int = 2,
-    draws: int = 500,
-    on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
-    transform_samples: bool = True,
-    zero_tol: float = 1e-8,
-    diag_jitter: float | None = 1e-8,
-    progressbar: bool = True,
-    compile_kwargs: dict | None = None,
-) -> az.InferenceData:
-    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-
-    mu, H_inv = fit_mvn_to_MAP(
-        optimized_point=optimized_point,
-        model=model,
-        on_bad_cov=on_bad_cov,
-        transform_samples=transform_samples,
-        zero_tol=zero_tol,
-        diag_jitter=diag_jitter,
-        compile_kwargs=compile_kwargs,
-    )
-
-    return laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar)
-
-
 def find_MAP(
     method: minimize_method,
     *,
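
For readers tracking the removal: the deleted `fit_mvn_to_MAP`/`laplace` pair implemented the standard Laplace approximation, i.e. a Gaussian centered at the MAP estimate whose covariance is the inverse of the negative log-posterior Hessian there, as the removed docstring states:

```latex
% Gaussian approximation of the posterior at the MAP point:
\[
  p(\theta \mid y) \;\approx\; \mathcal{N}\bigl(\theta_{\mathrm{MAP}},\, H^{-1}\bigr),
  \qquad
  H = -\,\nabla^2_{\theta} \log p(\theta \mid y)\,\Big|_{\theta = \theta_{\mathrm{MAP}}}.
\]
```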
@@ -605,9 +369,12 @@ def find_MAP(
     initial_params = DictToArrayBijection.map(
         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
     )
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        method, use_grad, use_hess, use_hessp
+    )
 
     f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=-frozen_model.logp(),
+        loss=-frozen_model.logp(jacobian=False),
        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
         initial_point_dict=start_dict,
         use_grad=use_grad,
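
With the defaults hook wired in, callers can leave the derivative flags unset and let the chosen method decide. The switch to `logp(jacobian=False)` also makes the optimizer target the log-density without the change-of-variables Jacobian term, so the optimum corresponds to the MAP of the original (constrained) parameterization rather than of the transformed space. A hedged usage sketch; the import path and any keyword besides `method` are assumptions, not confirmed by this diff:

```python
import numpy as np
import pymc as pm

# Assumed import location for the function changed in this diff.
from pymc_extras.inference.find_map import find_MAP

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=np.random.normal(size=10))

    # use_grad/use_hess/use_hessp can now be omitted (None) and are resolved
    # per method from MINIMIZE_MODE_KWARGS; for "L-BFGS-B" that should mean
    # gradients only, with no Hessian or Hessian-vector product compiled.
    map_point = find_MAP(method="L-BFGS-B")
```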