Skip to content

Commit 885afaa

Browse files
committed
Improvements to Importance Sampling and InferenceData shape
- Handle different importance sampling methods for reshaping and adjusting log densities. - Modified to return InferenceData with chain dim of size num_paths when
1 parent e4b8996 commit 885afaa

File tree

3 files changed

+53
-34
lines changed

3 files changed

+53
-34
lines changed

pymc_experimental/inference/pathfinder/importance_sampling.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pytensor.tensor as pt
99

1010
from pytensor.graph import Apply, Op
11-
from pytensor.tensor.variable import TensorVariable
1211

1312
logger = logging.getLogger(__name__)
1413

@@ -34,12 +33,12 @@ def perform(self, node: Apply, inputs, outputs) -> None:
3433

3534

3635
def importance_sampling(
37-
samples: TensorVariable,
38-
# logP: TensorVariable,
39-
# logQ: TensorVariable,
40-
logiw: TensorVariable,
36+
samples: np.ndarray,
37+
logP: np.ndarray,
38+
logQ: np.ndarray,
4139
num_draws: int,
4240
method: Literal["psis", "psir", "identity", "none"],
41+
logiw: np.ndarray | None = None,
4342
random_seed: int | None = None,
4443
) -> np.ndarray:
4544
"""Pareto Smoothed Importance Resampling (PSIR)
@@ -79,21 +78,36 @@ def importance_sampling(
7978
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
8079
"""
8180

82-
if method == "psis":
83-
replace = False
84-
logiw, pareto_k = PSIS()(logiw)
85-
elif method == "psir":
86-
replace = True
87-
logiw, pareto_k = PSIS()(logiw)
88-
elif method == "identity":
89-
replace = False
90-
logiw = logiw
91-
pareto_k = None
92-
elif method == "none":
81+
num_paths, num_pdraws, N = samples.shape
82+
83+
if method == "none":
9384
logger.warning(
9485
"importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
9586
)
9687
return samples
88+
else:
89+
samples = samples.reshape(-1, N)
90+
logP = logP.ravel()
91+
logQ = logQ.ravel()
92+
93+
# adjust log densities
94+
log_I = np.log(num_paths)
95+
logP -= log_I
96+
logQ -= log_I
97+
logiw = logP - logQ
98+
99+
if method == "psis":
100+
replace = False
101+
logiw, pareto_k = PSIS()(logiw)
102+
elif method == "psir":
103+
replace = True
104+
logiw, pareto_k = PSIS()(logiw)
105+
elif method == "identity":
106+
replace = False
107+
logiw = logiw
108+
pareto_k = None
109+
else:
110+
raise ValueError(f"Invalid importance sampling method: {method}")
97111

98112
# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
99113
# Pareto k may not be a good diagnostic for Pathfinder.
@@ -121,7 +135,7 @@ def importance_sampling(
121135
"Consider reparametrising the model all together or ensure the input data are correct."
122136
)
123137

124-
logger.warning(f"Pareto k value: {pareto_k:.2f}")
138+
logger.warning(f"Pareto k value: {pareto_k:.2f}")
125139

126140
p = pt.exp(logiw - pt.logsumexp(logiw)).eval()
127141
rng = np.random.default_rng(random_seed)

pymc_experimental/inference/pathfinder/pathfinder.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,13 @@ def convert_flat_trace_to_idata(
118118
postprocessing_backend="cpu",
119119
inference_backend="pymc",
120120
model=None,
121+
importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
121122
):
123+
if importance_sampling == "none":
124+
# samples.ndim == 3 in this case, otherwise ndim == 2
125+
num_paths, num_pdraws, N = samples.shape
126+
samples = samples.reshape(-1, N)
127+
122128
model = modelcontext(model)
123129
ip = model.initial_point()
124130
ip_point_map_info = DictToArrayBijection.map(ip).point_map_info
@@ -152,6 +158,10 @@ def convert_flat_trace_to_idata(
152158
)
153159
fn.trust_input = True
154160
result = fn(*list(trace.values()))
161+
162+
if importance_sampling == "none":
163+
result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
164+
155165
elif inference_backend == "blackjax":
156166
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
157167
result = jax.vmap(jax.vmap(jax_fn))(
@@ -731,7 +741,6 @@ def multipath_pathfinder(
731741
**pathfinder_kwargs,
732742
):
733743
*path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)
734-
N = DictToArrayBijection.map(model.initial_point()).data.shape[0]
735744

736745
single_pathfinder_fn = make_single_pathfinder_fn(
737746
model,
@@ -808,19 +817,11 @@ def multipath_pathfinder(
808817
logP = np.concatenate(logP)
809818
logQ = np.concatenate(logQ)
810819

811-
samples = samples.reshape(-1, N)
812-
logP = logP.ravel()
813-
logQ = logQ.ravel()
814-
815-
# adjust log densities
816-
log_I = np.log(num_paths)
817-
logP -= log_I
818-
logQ -= log_I
819-
logiw = logP - logQ
820-
821820
return _importance_sampling(
822821
samples=samples,
823-
logiw=logiw,
822+
logP=logP,
823+
logQ=logQ,
824+
# logiw=logiw,
824825
num_draws=num_draws,
825826
method=importance_sampling,
826827
random_seed=choice_seed,
@@ -881,7 +882,7 @@ def fit_pathfinder(
881882
epsilon: float
882883
value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
883884
importance_sampling : str, optional
884-
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
885+
importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
885886
progressbar : bool, optional
886887
Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
887888
random_seed : RandomSeed, optional
@@ -974,5 +975,6 @@ def fit_pathfinder(
974975
postprocessing_backend=postprocessing_backend,
975976
inference_backend=inference_backend,
976977
model=model,
978+
importance_sampling=importance_sampling,
977979
)
978980
return idata

tests/test_pathfinder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,22 @@ def test_pathfinder(inference_backend):
4545
with model:
4646
idata = pmx.fit(
4747
method="pathfinder",
48-
num_paths=20,
48+
num_paths=50,
49+
jitter=10.0,
4950
random_seed=41,
5051
inference_backend=inference_backend,
5152
)
5253

5354
assert idata.posterior["mu"].shape == (1, 1000)
5455
assert idata.posterior["tau"].shape == (1, 1000)
5556
assert idata.posterior["theta"].shape == (1, 1000, 8)
56-
# NOTE: Pathfinder tends to return means around 7 and tau around 0.58. So need to increase atol by a large amount.
5757
if inference_backend == "pymc":
58-
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=2.5)
59-
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=3.8)
58+
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
59+
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)
6060

6161

6262
def test_bfgs_sample():
63+
import pytensor
6364
import pytensor.tensor as pt
6465

6566
from pymc_experimental.inference.pathfinder.pathfinder import (
@@ -73,6 +74,7 @@ def test_bfgs_sample():
7374
L = Lp1 - 1
7475
J = 6
7576
num_samples = 1000
77+
rng = pytensor.shared(np.random.default_rng(42), name="rng")
7678

7779
# mock data
7880
x_data = np.random.randn(Lp1, N)
@@ -90,6 +92,7 @@ def test_bfgs_sample():
9092

9193
# sample
9294
phi, logq = bfgs_sample(
95+
rng=rng,
9396
num_samples=num_samples,
9497
x=x,
9598
g=g,

0 commit comments

Comments
 (0)