Skip to content

Commit f63d5ce

Browse files
committed
fix: importance sampling handling causing error when chosen method is "none" or None
- Moved importance sampling logic from `multipath_pathfinder` to `fit_pathfinder` to fix error method is "none" or None - Update docstrings to clarify importance sampling method behavior - Use match statement for method selection in importance_sampling
1 parent ec46270 commit f63d5ce

File tree

2 files changed

+54
-31
lines changed

2 files changed

+54
-31
lines changed

pymc_extras/inference/pathfinder/importance_sampling.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def importance_sampling(
2828
logP: NDArray,
2929
logQ: NDArray,
3030
num_draws: int,
31-
method: Literal["psis", "psir", "identity", "none"] | None,
31+
method: Literal["psis", "psir", "identity"] | None,
3232
random_seed: int | None = None,
3333
) -> ImportanceSamplingResult:
3434
"""Pareto Smoothed Importance Resampling (PSIR)
@@ -44,8 +44,15 @@ def importance_sampling(
4444
log probability values of proposal distribution, shape (L, M)
4545
num_draws : int
4646
number of draws to return where num_draws <= samples.shape[0]
47-
method : str, optional
48-
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
47+
method : str, None, optional
48+
Method to apply sampling based on log importance weights (logP - logQ).
49+
Options are:
50+
"psis" : Pareto Smoothed Importance Sampling (default)
51+
Recommended for more stable results.
52+
"psir" : Pareto Smoothed Importance Resampling
53+
Less stable than PSIS.
54+
"identity" : Applies log importance weights directly without resampling.
55+
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
4956
random_seed : int | None
5057
5158
Returns
@@ -71,7 +78,7 @@ def importance_sampling(
7178
warnings = []
7279
num_paths, _, N = samples.shape
7380

74-
if method == "none":
81+
if method is None:
7582
warnings.append(
7683
"Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
7784
)
@@ -91,17 +98,16 @@ def importance_sampling(
9198
_warnings.filterwarnings(
9299
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
93100
)
94-
if method == "psis":
95-
replace = False
96-
logiw, pareto_k = az.psislw(logiw)
97-
elif method == "psir":
98-
replace = True
99-
logiw, pareto_k = az.psislw(logiw)
100-
elif method == "identity":
101-
replace = False
102-
pareto_k = None
103-
else:
104-
raise ValueError(f"Invalid importance sampling method: {method}")
101+
match method:
102+
case "psis":
103+
replace = False
104+
logiw, pareto_k = az.psislw(logiw)
105+
case "psir":
106+
replace = True
107+
logiw, pareto_k = az.psislw(logiw)
108+
case "identity":
109+
replace = False
110+
pareto_k = None
105111

106112
# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
107113
# Pareto k may not be a good diagnostic for Pathfinder.

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def convert_flat_trace_to_idata(
156156
postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
157157
inference_backend: Literal["pymc", "blackjax"] = "pymc",
158158
model: Model | None = None,
159-
importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
159+
importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
160160
) -> az.InferenceData:
161161
"""convert flattened samples to arviz InferenceData format.
162162
@@ -181,7 +181,7 @@ def convert_flat_trace_to_idata(
181181
arviz inference data object
182182
"""
183183

184-
if importance_sampling == "none":
184+
if importance_sampling is None:
185185
# samples.ndim == 3 in this case, otherwise ndim == 2
186186
num_paths, num_pdraws, N = samples.shape
187187
samples = samples.reshape(-1, N)
@@ -220,7 +220,7 @@ def convert_flat_trace_to_idata(
220220
fn.trust_input = True
221221
result = fn(*list(trace.values()))
222222

223-
if importance_sampling == "none":
223+
if importance_sampling is None:
224224
result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
225225

226226
elif inference_backend == "blackjax":
@@ -1189,7 +1189,7 @@ class MultiPathfinderResult:
11891189
elbo_argmax: NDArray | None = None
11901190
lbfgs_status: Counter = field(default_factory=Counter)
11911191
path_status: Counter = field(default_factory=Counter)
1192-
importance_sampling: str = "none"
1192+
importance_sampling: str | None = "psis"
11931193
warnings: list[str] = field(default_factory=list)
11941194
pareto_k: float | None = None
11951195

@@ -1424,7 +1424,7 @@ def multipath_pathfinder(
14241424
num_elbo_draws: int,
14251425
jitter: float,
14261426
epsilon: float,
1427-
importance_sampling: Literal["psis", "psir", "identity", "none"] | None,
1427+
importance_sampling: Literal["psis", "psir", "identity"] | None,
14281428
progressbar: bool,
14291429
concurrent: Literal["thread", "process"] | None,
14301430
random_seed: RandomSeed,
@@ -1460,8 +1460,14 @@ def multipath_pathfinder(
14601460
Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
14611461
epsilon: float
14621462
value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
1463-
importance_sampling : str, optional
1464-
importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
1463+
importance_sampling : str, None, optional
1464+
Method to apply sampling based on log importance weights (logP - logQ).
1465+
"psis" : Pareto Smoothed Importance Sampling (default)
1466+
Recommended for more stable results.
1467+
"psir" : Pareto Smoothed Importance Resampling
1468+
Less stable than PSIS.
1469+
"identity" : Applies log importance weights directly without resampling.
1470+
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
14651471
progressbar : bool, optional
14661472
Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
14671473
random_seed : RandomSeed, optional
@@ -1483,12 +1489,6 @@ def multipath_pathfinder(
14831489
The result containing samples and other information from the Multi-Path Pathfinder algorithm.
14841490
"""
14851491

1486-
valid_importance_sampling = ["psis", "psir", "identity", "none", None]
1487-
if importance_sampling is None:
1488-
importance_sampling = "none"
1489-
if importance_sampling.lower() not in valid_importance_sampling:
1490-
raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
1491-
14921492
*path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)
14931493

14941494
pathfinder_config = PathfinderConfig(
@@ -1622,7 +1622,7 @@ def fit_pathfinder(
16221622
num_elbo_draws: int = 10, # K
16231623
jitter: float = 2.0,
16241624
epsilon: float = 1e-8,
1625-
importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
1625+
importance_sampling: Literal["psis", "psir", "identity", "none"] | None = "psis",
16261626
progressbar: bool = True,
16271627
concurrent: Literal["thread", "process"] | None = None,
16281628
random_seed: RandomSeed | None = None,
@@ -1662,8 +1662,15 @@ def fit_pathfinder(
16621662
Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
16631663
epsilon: float
16641664
value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
1665-
importance_sampling : str, optional
1666-
importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
1665+
importance_sampling : str, None, optional
1666+
Method to apply sampling based on log importance weights (logP - logQ).
1667+
Options are:
1668+
"psis" : Pareto Smoothed Importance Sampling (default)
1669+
Recommended for more stable results.
1670+
"psir" : Pareto Smoothed Importance Resampling
1671+
Less stable than PSIS.
1672+
"identity" : Applies log importance weights directly without resampling.
1673+
None or "none" : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
16671674
progressbar : bool, optional
16681675
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
16691676
random_seed : RandomSeed, optional
@@ -1690,6 +1697,16 @@ def fit_pathfinder(
16901697
"""
16911698

16921699
model = modelcontext(model)
1700+
1701+
valid_importance_sampling = {"psis", "psir", "identity", None}
1702+
1703+
if importance_sampling is not None:
1704+
importance_sampling = importance_sampling.lower()
1705+
importance_sampling = None if importance_sampling == "none" else importance_sampling
1706+
1707+
if importance_sampling not in valid_importance_sampling:
1708+
raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
1709+
16931710
N = DictToArrayBijection.map(model.initial_point()).data.shape[0]
16941711

16951712
if maxcor is None:

0 commit comments

Comments
 (0)