Skip to content

Commit 6d75a4f

Browse files
committed
fix: remove "none" as a valid importance sampling method
1 parent 3823b81 commit 6d75a4f

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

pymc_extras/inference/pathfinder/importance_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class ImportanceSamplingResult:
2020
samples: NDArray
2121
pareto_k: float | None = None
2222
warnings: list[str] = field(default_factory=list)
23-
method: str = "none"
23+
method: str = "psis"
2424

2525

2626
def importance_sampling(
@@ -82,7 +82,7 @@ def importance_sampling(
8282
warnings.append(
8383
"Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
8484
)
85-
return ImportanceSamplingResult(samples=samples, warnings=warnings)
85+
return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method)
8686
else:
8787
samples = samples.reshape(-1, N)
8888
logP = logP.ravel()

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,7 +1258,7 @@ def with_warnings(self, warnings: list[str]) -> Self:
12581258
def with_importance_sampling(
12591259
self,
12601260
num_draws: int,
1261-
method: Literal["psis", "psir", "identity", "none"] | None,
1261+
method: Literal["psis", "psir", "identity"] | None,
12621262
random_seed: int | None = None,
12631263
) -> Self:
12641264
"""perform importance sampling"""
@@ -1622,7 +1622,7 @@ def fit_pathfinder(
16221622
num_elbo_draws: int = 10, # K
16231623
jitter: float = 2.0,
16241624
epsilon: float = 1e-8,
1625-
importance_sampling: Literal["psis", "psir", "identity", "none"] | None = "psis",
1625+
importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
16261626
progressbar: bool = True,
16271627
concurrent: Literal["thread", "process"] | None = None,
16281628
random_seed: RandomSeed | None = None,
@@ -1670,7 +1670,7 @@ def fit_pathfinder(
16701670
"psir" : Pareto Smoothed Importance Resampling
16711671
Less stable than PSIS.
16721672
"identity" : Applies log importance weights directly without resampling.
1673-
None or "none" : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1673+
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
16741674
progressbar : bool, optional
16751675
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
16761676
random_seed : RandomSeed, optional
@@ -1702,7 +1702,6 @@ def fit_pathfinder(
17021702

17031703
if importance_sampling is not None:
17041704
importance_sampling = importance_sampling.lower()
1705-
importance_sampling = None if importance_sampling == "none" else importance_sampling
17061705

17071706
if importance_sampling not in valid_importance_sampling:
17081707
raise ValueError(f"Invalid importance sampling method: {importance_sampling}")

0 commit comments

Comments
 (0)