Skip to content

Commit 862627e

Browse files
committed
Improve pathfinder error handling and type hints
- Add proper type hints throughout pathfinder module - Improve error handling in concurrent execution paths - Better handling of when all paths are fail by displaying results before Assertion - Changed Australian English spelling to US - Update compile_pymc usage to handle deprecation warning - Add tests for concurrent execution and seed reproducibility - Clean up imports and remove redundant code - Improve docstrings and error messages
1 parent baad3d9 commit 862627e

File tree

4 files changed

+249
-139
lines changed

4 files changed

+249
-139
lines changed

pymc_extras/inference/pathfinder/importance_sampling.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
import warnings
2+
import warnings as _warnings
33

44
from dataclasses import dataclass, field
55
from typing import Literal
@@ -30,7 +30,7 @@ def importance_sampling(
3030
num_draws: int,
3131
method: Literal["psis", "psir", "identity", "none"] | None,
3232
random_seed: int | None = None,
33-
):
33+
) -> ImportanceSamplingResult:
3434
"""Pareto Smoothed Importance Resampling (PSIR)
3535
This implements the Pareto Smooth Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). The PSIR follows a similar approach to Algorithm 1 PSIS diagnostic from Yao et al., (2018). However, before computing the the importance ratio r_s, the logP and logQ are adjusted to account for the number multiple estimators (or paths). The process involves resampling from the original sample with replacement, with probabilities proportional to the computed importance weights from PSIS.
3636
@@ -50,8 +50,8 @@ def importance_sampling(
5050
5151
Returns
5252
-------
53-
NDArray
54-
importance sampled draws
53+
ImportanceSamplingResult
54+
importance sampled draws and other info based on the specified method
5555
5656
Future work!
5757
----------
@@ -68,14 +68,14 @@ def importance_sampling(
6868
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
6969
"""
7070

71-
warning_msgs = []
71+
warnings = []
7272
num_paths, _, N = samples.shape
7373

7474
if method == "none":
75-
warning_msgs.append(
75+
warnings.append(
7676
"Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
7777
)
78-
return ImportanceSamplingResult(samples=samples, warnings=warning_msgs)
78+
return ImportanceSamplingResult(samples=samples, warnings=warnings)
7979
else:
8080
samples = samples.reshape(-1, N)
8181
logP = logP.ravel()
@@ -87,8 +87,8 @@ def importance_sampling(
8787
logQ -= log_I
8888
logiw = logP - logQ
8989

90-
with warnings.catch_warnings():
91-
warnings.filterwarnings(
90+
with _warnings.catch_warnings():
91+
_warnings.filterwarnings(
9292
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
9393
)
9494
if method == "psis":
@@ -113,27 +113,27 @@ def importance_sampling(
113113
try:
114114
resampled = rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
115115
return ImportanceSamplingResult(
116-
samples=resampled, pareto_k=pareto_k, warnings=warning_msgs, method=method
116+
samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
117117
)
118118
except ValueError as e1:
119119
if "Fewer non-zero entries in p than size" in str(e1):
120120
num_nonzero = np.where(np.nonzero(p)[0], 1, 0).sum()
121-
warning_msgs.append(
121+
warnings.append(
122122
f"Not enough valid samples: {num_nonzero} available out of {num_draws} requested. Switching to psir importance sampling."
123123
)
124124
try:
125125
resampled = rng.choice(
126126
samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
127127
)
128128
return ImportanceSamplingResult(
129-
samples=resampled, pareto_k=pareto_k, warnings=warning_msgs, method=method
129+
samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
130130
)
131131
except ValueError as e2:
132132
logger.error(
133133
"Importance sampling failed even with psir importance sampling. "
134134
"This might indicate invalid probability weights or insufficient valid samples."
135135
)
136136
raise ValueError(
137-
"Importance sampling failed with both with and without replacement"
137+
"Importance sampling failed for both with and without replacement"
138138
) from e2
139139
raise

pymc_extras/inference/pathfinder/lbfgs.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def __init__(self, message=None):
106106

107107

108108
class LBFGS:
109-
"""L-BFGS optimiser wrapper around scipy's implementation.
109+
"""L-BFGS optimizer wrapper around scipy's implementation.
110110
111111
Parameters
112112
----------
@@ -124,16 +124,18 @@ class LBFGS:
124124
maximum number of line search steps, defaults to 1000
125125
"""
126126

127-
def __init__(self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
127+
def __init__(
128+
self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
129+
) -> None:
128130
self.value_grad_fn = value_grad_fn
129131
self.maxcor = maxcor
130132
self.maxiter = maxiter
131133
self.ftol = ftol
132134
self.gtol = gtol
133135
self.maxls = maxls
134136

135-
def minimise(self, x0):
136-
"""minimises objective function starting from initial position.
137+
def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
138+
"""minimizes objective function starting from initial position.
137139
138140
Parameters
139141
----------

0 commit comments

Comments
 (0)