Skip to content

Commit e4b8996

Browse files
committed
Refactor Pathfinder VI: Default to PSIS, Add Concurrency, and Improved Computational Performance
- Significantly computational efficiency by combining 3 computational graphs into 1 larger compile. Removed non-shared inputs and used with for significant performance gains. - Set default importance sampling method to 'psis' for more stable posterior results, avoiding local peaks seen with 'psir'. - Introduce concurrency options ('thread' and 'process') for multithreading and multiprocessing. Defaults to No concurrency as there haven't been any/or much reduction to the compute time. - Adjusted default from 8 to 4 and from 1.0 to 2.0 and maxcor to max(3*log(N), 5). This default setting lessens computational time and and the degree by which the posterior variance is being underestimated.
1 parent 2815c4f commit e4b8996

File tree

3 files changed

+258
-173
lines changed

3 files changed

+258
-173
lines changed

pymc_experimental/inference/pathfinder/importance_sampling.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import logging
22
import warnings
33

4+
from typing import Literal
5+
46
import arviz as az
57
import numpy as np
68
import pytensor.tensor as pt
@@ -31,12 +33,13 @@ def perform(self, node: Apply, inputs, outputs) -> None:
3133
outputs[1][0] = pareto_k
3234

3335

34-
def psir(
36+
def importance_sampling(
3537
samples: TensorVariable,
3638
# logP: TensorVariable,
3739
# logQ: TensorVariable,
3840
logiw: TensorVariable,
39-
num_draws: int = 1000,
41+
num_draws: int,
42+
method: Literal["psis", "psir", "identity", "none"],
4043
random_seed: int | None = None,
4144
) -> np.ndarray:
4245
"""Pareto Smoothed Importance Resampling (PSIR)
@@ -52,6 +55,8 @@ def psir(
5255
log probability of proposal distribution
5356
num_draws : int
5457
number of draws to return where num_draws <= samples.shape[0]
58+
method : str, optional
59+
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
5560
random_seed : int | None
5661
5762
Returns
@@ -74,30 +79,50 @@ def psir(
7479
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
7580
"""
7681

77-
psislw, pareto_k = PSIS()(logiw)
78-
pareto_k = pareto_k.eval()
79-
if pareto_k < 0.5:
80-
pass
81-
elif 0.5 <= pareto_k < 0.70:
82-
logger.warning(
83-
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
84-
)
85-
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
86-
elif pareto_k >= 0.7:
82+
if method == "psis":
83+
replace = False
84+
logiw, pareto_k = PSIS()(logiw)
85+
elif method == "psir":
86+
replace = True
87+
logiw, pareto_k = PSIS()(logiw)
88+
elif method == "identity":
89+
replace = False
90+
logiw = logiw
91+
pareto_k = None
92+
elif method == "none":
8793
logger.warning(
88-
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
89-
)
90-
logger.info(
91-
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
92-
)
93-
else:
94-
logger.warning(
95-
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."
96-
)
97-
logger.info(
98-
"Consider reparametrising the model all together or ensure the input data are correct."
94+
"importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
9995
)
96+
return samples
97+
98+
# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
99+
# Pareto k may not be a good diagnostic for Pathfinder.
100+
if pareto_k is not None:
101+
pareto_k = pareto_k.eval()
102+
if pareto_k < 0.5:
103+
pass
104+
elif 0.5 <= pareto_k < 0.70:
105+
logger.info(
106+
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
107+
)
108+
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
109+
elif pareto_k >= 0.7:
110+
logger.info(
111+
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
112+
)
113+
logger.info(
114+
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
115+
)
116+
else:
117+
logger.info(
118+
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."
119+
)
120+
logger.info(
121+
"Consider reparametrising the model all together or ensure the input data are correct."
122+
)
123+
124+
logger.warning(f"Pareto k value: {pareto_k:.2f}")
100125

101-
p = pt.exp(psislw - pt.logsumexp(psislw)).eval()
126+
p = pt.exp(logiw - pt.logsumexp(logiw)).eval()
102127
rng = np.random.default_rng(random_seed)
103-
return rng.choice(samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0)
128+
return rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)

pymc_experimental/inference/pathfinder/lbfgs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def __init__(self, message=None):
6969

7070

7171
class LBFGSOp(Op):
72+
__props__ = ("fn", "grad_fn", "maxcor", "maxiter", "ftol", "gtol", "maxls")
73+
7274
def __init__(self, fn, grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
7375
self.fn = fn
7476
self.grad_fn = grad_fn

0 commit comments

Comments
 (0)