Support infer_discrete for Predictive #1086
    with substitute(data=data):
        return model(*args, **kwargs)
    return data
I revised this a bit so that the function can be used in Predictive (and to avoid tracing the model twice).
    model_trace = prototype_trace
    temperature = 1
    pred_samples = _sample_posterior(
        config_enumerate(condition(model, samples)),
I'm surprised you're automatically configuring the model for enumeration here. In Pyro we let the user decide which variables are enumerated (though maybe this is too burdensome). That way a guide might sample some discrete latents, and the model might enumerate others. I would expect Predictive to use the guide's samples if provided, and only enumerate sites already marked for enumeration by the user.
I guess if in NumPyro you automatically wrap models with @config_enumerate inside MCMC, then it would also make sense to automatically wrap them here. Still, it would be nice to support SVI with guides that sample some or all discrete latent variables.
> let the user decide which variables are enumerated
We'll aim for this, except for algorithms that do not work with discrete latent sites (in that case, it makes sense to enable enumeration by default and raise errors for invalid models). For example, DiscreteHMCGibbs currently performs Gibbs updates for discrete latent sites that are not marked "enumerated". The posterior samples then include all latent variables except those marked "enumerated".
When infer_discrete=True, I assumed that the latent sites that are not available in posterior_samples or the guide are enumerated (the available ones are the sites in the samples variable in the code above). So it makes sense to me to apply config_enumerate to them by default (config_enumerate will skip observed sites, including the sites in samples). What do you think? (This does not seem to contradict the use case in your comment.)
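To make the rule in this reply concrete, here is a minimal, library-free sketch of which sites would end up enumerated. It is not NumPyro code: the `sites_to_enumerate` helper and the dictionary layout for sites are hypothetical and exist only for illustration. It assumes, as described above, that conditioning on `samples` makes those sites behave like observed sites, so only the remaining discrete latent sites are marked.

```python
def sites_to_enumerate(sites, samples):
    """Return names of discrete latent sites that would be marked for enumeration.

    `sites` maps a site name to a dict with boolean keys "discrete" and
    "is_observed" (a hypothetical stand-in for a model trace).
    `samples` maps site names to values taken from the posterior or guide.
    """
    marked = []
    for name, site in sites.items():
        # Conditioning on `samples` makes those sites act as observed sites.
        treated_as_observed = site["is_observed"] or name in samples
        if site["discrete"] and not treated_as_observed:
            marked.append(name)
    return marked


# Toy trace: one continuous latent, two discrete latents, one observed site.
sites = {
    "theta": {"discrete": False, "is_observed": False},
    "z": {"discrete": True, "is_observed": False},
    "c": {"discrete": True, "is_observed": False},
    "y": {"discrete": True, "is_observed": True},
}

# The guide (or MCMC) already provides values for "theta" and "c".
samples = {"theta": [0.3, 0.7], "c": [1, 0]}

enumerated = sites_to_enumerate(sites, samples)
print(enumerated)
```

In this toy setup only `z` is marked: `c` is discrete but already sampled, `theta` is continuous, and `y` is observed. A guide that samples some discrete latents therefore keeps control over them, which matches the use case raised in the review.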
fritzo left a comment
LGTM as long as you're ok wrapping models with @config_enumerate.
Thanks for reviewing, @fritzo! I thought about it for a while and I think we can merge this and revisit the API later if needed (especially when TraceEnumELBO is available). We can check the model and raise a deprecation warning if
Following up on #1043, this PR adds the arguments infer_discrete and temperature to Predictive to make it easier for users to get samples for discrete latent variables (illustrated in the annotation example).
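As a rough illustration of what recovering a discrete latent from posterior samples involves, the sketch below handles a single mixture assignment in plain Python. It is not the PR's implementation: `infer_assignment` and `normal_logpdf` are hypothetical helpers, and the sketch assumes temperature has the same meaning as in Pyro's infer_discrete, where 1 draws a sample from the posterior over the discrete site and 0 returns its most likely value.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of a Normal(loc, scale) distribution at x."""
    return (
        -0.5 * ((x - loc) / scale) ** 2
        - math.log(scale)
        - 0.5 * math.log(2 * math.pi)
    )


def infer_assignment(x, weights, locs, scale, temperature=1, rng=random):
    """Infer the mixture component z for one observation x.

    The continuous parameters (weights, locs, scale) play the role of a
    single posterior sample; z is the discrete latent to recover.
    """
    # Unnormalized log posterior: log p(z) + log p(x | z).
    logits = [
        math.log(w) + normal_logpdf(x, loc, scale)
        for w, loc in zip(weights, locs)
    ]
    if temperature == 0:
        # MAP: pick the most probable component.
        return max(range(len(logits)), key=logits.__getitem__)
    # temperature == 1: sample z in proportion to its posterior probability.
    top = max(logits)
    probs = [math.exp(logit - top) for logit in logits]
    return rng.choices(range(len(logits)), weights=probs)[0]


weights, locs, scale = [0.5, 0.5], [-3.0, 3.0], 1.0
map_right = infer_assignment(2.9, weights, locs, scale, temperature=0)
map_left = infer_assignment(-2.5, weights, locs, scale, temperature=0)
drawn = infer_assignment(0.1, weights, locs, scale, rng=random.Random(0))
print(map_right, map_left, drawn)
```

With temperature=0 the result is deterministic, while temperature=1 gives a draw that varies with the random state, which matters most for observations lying between the two components.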