Description
Feature Summary
#2139 added the PSIS diagnostic from "Yes, but Did It Work?: Evaluating Variational Inference" (Yao et al., 2018). That paper actually proposes two complementary diagnostics for evaluating variational inference that are meant to be used together. The other is variational simulation-based calibration (VSBC); see Section 3 of Yao et al. (2018). I'd like to add this second diagnostic.
Why is this needed?
The paper recommends using both diagnostics together because they catch different failure modes. PSIS evaluates the full joint posterior for a fixed dataset, while VSBC evaluates whether marginal point estimates are systematically biased on average across data generated from the model. From Section 3 of Yao et al. (2018): "... while the VI posterior may be a poor approximation to the full posterior, point estimates that are derived from it may still have good statistical properties." In short, PSIS assesses "the quality of the entire variational posterior for a particular data set," while VSBC assesses "the average bias of a point estimate produced under correct model specification."
There are a few design questions I'd want input on before submitting a PR.
- The VSBC algorithm (Section 3, Algorithm 2) requires simulating $M > 1$ data sets $\{y_j\}_{j=1}^M$ and, for each data set, constructing a variational approximation to $p(\theta \mid y_j)$ and computing the marginal calibration probabilities. This is computationally expensive; should we add an argument to run the fits in parallel?
- Should the function accept a pre-configured SVI object, or take the model/guide/optimizer/loss separately?
- Should this live alongside `psis_diagnostic` in `numpyro/infer/importance.py`, or in a separate module, since VSBC doesn't involve importance sampling?