Skip to content

Commit ba85587

Browse files
committed
Display summary of results, Improve error handling, General improvements
Changes: - Added rich table summary display for results - Added PathStatus and LBFGSStatus for error handling, status tracking, and displaying results - Changed importance_sampling return type to ImportanceSamplingResult - Changed multipath_pathfinder return type to MultiPathfinderResult - Added dataclass containers for results (ImportanceSamplingResult, PathfinderResult, MultiPathfinderResult) - Refactored LBFGS by removing PyTensor Op classes in favor of pure functions - Added timing and configuration tracking - Improved concurrency with better error handling - Improved docstrings and type hints - Simplified logp and gradient computation by combining them into a single function - Added compile_kwargs parameter for PyTensor compilation options
1 parent 885afaa commit ba85587

File tree

4 files changed

+1093
-314
lines changed

4 files changed

+1093
-314
lines changed
Lines changed: 72 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,47 @@
11
import logging
22
import warnings
33

4+
from dataclasses import dataclass, field
45
from typing import Literal
56

67
import arviz as az
78
import numpy as np
8-
import pytensor.tensor as pt
99

10-
from pytensor.graph import Apply, Op
10+
from numpy.typing import NDArray
11+
from scipy.special import logsumexp
1112

1213
logger = logging.getLogger(__name__)
1314

1415

15-
class PSIS(Op):
16-
__props__ = ()
16+
@dataclass(frozen=True)
17+
class ImportanceSamplingResult:
18+
"""container for importance sampling results"""
1719

18-
def make_node(self, inputs):
19-
logweights = pt.as_tensor(inputs)
20-
psislw = pt.dvector()
21-
pareto_k = pt.dscalar()
22-
return Apply(self, [logweights], [psislw, pareto_k])
23-
24-
def perform(self, node: Apply, inputs, outputs) -> None:
25-
with warnings.catch_warnings():
26-
warnings.filterwarnings(
27-
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
28-
)
29-
logweights = inputs[0]
30-
psislw, pareto_k = az.psislw(logweights)
31-
outputs[0][0] = psislw
32-
outputs[1][0] = pareto_k
20+
samples: NDArray
21+
pareto_k: float | None = None
22+
warnings: list[str] = field(default_factory=list)
23+
method: str = "none"
3324

3425

3526
def importance_sampling(
36-
samples: np.ndarray,
37-
logP: np.ndarray,
38-
logQ: np.ndarray,
27+
samples: NDArray,
28+
logP: NDArray,
29+
logQ: NDArray,
3930
num_draws: int,
40-
method: Literal["psis", "psir", "identity", "none"],
41-
logiw: np.ndarray | None = None,
31+
method: Literal["psis", "psir", "identity", "none"] | None,
4232
random_seed: int | None = None,
43-
) -> np.ndarray:
33+
):
4434
"""Pareto Smoothed Importance Resampling (PSIR)
4535
This implements the Pareto Smoothed Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). PSIR follows a similar approach to the PSIS diagnostic in Algorithm 1 of Yao et al. (2018). However, before computing the importance ratio r_s, logP and logQ are adjusted to account for the number of estimators (or paths). The process involves resampling from the original sample with replacement, with probabilities proportional to the importance weights computed by PSIS.
4636
4737
Parameters
4838
----------
49-
samples : np.ndarray
50-
samples from proposal distribution
51-
logP : np.ndarray
52-
log probability of target distribution
53-
logQ : np.ndarray
54-
log probability of proposal distribution
39+
samples : NDArray
40+
samples from proposal distribution, shape (L, M, N)
41+
logP : NDArray
42+
log probability values of target distribution, shape (L, M)
43+
logQ : NDArray
44+
log probability values of proposal distribution, shape (L, M)
5545
num_draws : int
5646
number of draws to return where num_draws <= L * M (i.e. samples.shape[0] after the paths are flattened) when sampling without replacement
5747
method : str, optional
@@ -60,7 +50,7 @@ def importance_sampling(
6050
6151
Returns
6252
-------
63-
np.ndarray
53+
NDArray
6454
importance sampled draws
6555
6656
Future work!
@@ -78,13 +68,14 @@ def importance_sampling(
7868
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
7969
"""
8070

81-
num_paths, num_pdraws, N = samples.shape
71+
warning_msgs = []
72+
num_paths, _, N = samples.shape
8273

8374
if method == "none":
84-
logger.warning(
85-
"importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
75+
warning_msgs.append(
76+
"Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
8677
)
87-
return samples
78+
return ImportanceSamplingResult(samples=samples, warnings=warning_msgs)
8879
else:
8980
samples = samples.reshape(-1, N)
9081
logP = logP.ravel()
@@ -96,47 +87,53 @@ def importance_sampling(
9687
logQ -= log_I
9788
logiw = logP - logQ
9889

99-
if method == "psis":
100-
replace = False
101-
logiw, pareto_k = PSIS()(logiw)
102-
elif method == "psir":
103-
replace = True
104-
logiw, pareto_k = PSIS()(logiw)
105-
elif method == "identity":
106-
replace = False
107-
logiw = logiw
108-
pareto_k = None
109-
else:
110-
raise ValueError(f"Invalid importance sampling method: {method}")
90+
with warnings.catch_warnings():
91+
warnings.filterwarnings(
92+
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
93+
)
94+
if method == "psis":
95+
replace = False
96+
logiw, pareto_k = az.psislw(logiw)
97+
elif method == "psir":
98+
replace = True
99+
logiw, pareto_k = az.psislw(logiw)
100+
elif method == "identity":
101+
replace = False
102+
pareto_k = None
103+
else:
104+
raise ValueError(f"Invalid importance sampling method: {method}")
111105

112106
# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
113107
# Pareto k may not be a good diagnostic for Pathfinder.
114-
if pareto_k is not None:
115-
pareto_k = pareto_k.eval()
116-
if pareto_k < 0.5:
117-
pass
118-
elif 0.5 <= pareto_k < 0.70:
119-
logger.info(
120-
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
121-
)
122-
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
123-
elif pareto_k >= 0.7:
124-
logger.info(
125-
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
126-
)
127-
logger.info(
128-
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
129-
)
130-
else:
131-
logger.info(
132-
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."
133-
)
134-
logger.info(
135-
"Consider reparametrising the model all together or ensure the input data are correct."
136-
)
108+
# TODO: Find replacement diagnostics for Pathfinder.
137109

138-
logger.warning(f"Pareto k value: {pareto_k:.2f}")
139-
140-
p = pt.exp(logiw - pt.logsumexp(logiw)).eval()
110+
p = np.exp(logiw - logsumexp(logiw))
141111
rng = np.random.default_rng(random_seed)
142-
return rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
112+
113+
try:
114+
resampled = rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
115+
return ImportanceSamplingResult(
116+
samples=resampled, pareto_k=pareto_k, warnings=warning_msgs, method=method
117+
)
118+
except ValueError as e1:
119+
if "Fewer non-zero entries in p than size" in str(e1):
120+
num_nonzero = np.where(np.nonzero(p)[0], 1, 0).sum()
121+
warning_msgs.append(
122+
f"Not enough valid samples: {num_nonzero} available out of {num_draws} requested. Switching to psir importance sampling."
123+
)
124+
try:
125+
resampled = rng.choice(
126+
samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
127+
)
128+
return ImportanceSamplingResult(
129+
samples=resampled, pareto_k=pareto_k, warnings=warning_msgs, method=method
130+
)
131+
except ValueError as e2:
132+
logger.error(
133+
"Importance sampling failed even with psir importance sampling. "
134+
"This might indicate invalid probability weights or insufficient valid samples."
135+
)
136+
raise ValueError(
137+
"Importance sampling failed with both with and without replacement"
138+
) from e2
139+
raise

0 commit comments

Comments
 (0)