You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Refactor Pathfinder VI: Default to PSIS, Add Concurrency, and Improved Computational Performance
- Significantly computational efficiency by combining 3 computational graphs into 1 larger compile. Removed non-shared inputs and used with for significant performance gains.
- Set default importance sampling method to 'psis' for more stable posterior results, avoiding local peaks seen with 'psir'.
- Introduce concurrency options ('thread' and 'process') for multithreading and multiprocessing. Defaults to No concurrency as there haven't been any/or much reduction to the compute time.
- Adjusted default from 8 to 4 and from 1.0 to 2.0 and maxcor to max(3*log(N), 5). This default setting lessens computational time and and the degree by which the posterior variance is being underestimated.
number of draws to return where num_draws <= samples.shape[0]
58
+
method : str, optional
59
+
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
55
60
random_seed : int | None
56
61
57
62
Returns
@@ -74,30 +79,50 @@ def psir(
74
79
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
75
80
"""
76
81
77
-
psislw, pareto_k=PSIS()(logiw)
78
-
pareto_k=pareto_k.eval()
79
-
ifpareto_k<0.5:
80
-
pass
81
-
elif0.5<=pareto_k<0.70:
82
-
logger.warning(
83
-
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
84
-
)
85
-
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
86
-
elifpareto_k>=0.7:
82
+
ifmethod=="psis":
83
+
replace=False
84
+
logiw, pareto_k=PSIS()(logiw)
85
+
elifmethod=="psir":
86
+
replace=True
87
+
logiw, pareto_k=PSIS()(logiw)
88
+
elifmethod=="identity":
89
+
replace=False
90
+
logiw=logiw
91
+
pareto_k=None
92
+
elifmethod=="none":
87
93
logger.warning(
88
-
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
89
-
)
90
-
logger.info(
91
-
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
92
-
)
93
-
else:
94
-
logger.warning(
95
-
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."
96
-
)
97
-
logger.info(
98
-
"Consider reparametrising the model all together or ensure the input data are correct."
94
+
"importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
99
95
)
96
+
returnsamples
97
+
98
+
# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
99
+
# Pareto k may not be a good diagnostic for Pathfinder.
100
+
ifpareto_kisnotNone:
101
+
pareto_k=pareto_k.eval()
102
+
ifpareto_k<0.5:
103
+
pass
104
+
elif0.5<=pareto_k<0.70:
105
+
logger.info(
106
+
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
107
+
)
108
+
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
109
+
elifpareto_k>=0.7:
110
+
logger.info(
111
+
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
112
+
)
113
+
logger.info(
114
+
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
115
+
)
116
+
else:
117
+
logger.info(
118
+
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."
119
+
)
120
+
logger.info(
121
+
"Consider reparametrising the model all together or ensure the input data are correct."
0 commit comments