Member

@fehiepsi fehiepsi commented Jul 8, 2021

Following up on #1043, this PR adds the arguments infer_discrete and temperature to Predictive, making it easier for users to get samples for discrete latent variables (illustrated in the annotation example).
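For readers unfamiliar with the temperature argument, here is a minimal sketch of the idea for a single discrete latent site. It assumes the same convention as Pyro's infer_discrete: temperature=1 draws a sample from the posterior, temperature=0 returns the most probable value. The helpers discrete_posterior and draw_discrete are hypothetical names used only for illustration; they are not part of NumPyro, which handles this through enumeration of the model rather than explicit probability lists.

```python
import math
import random


def discrete_posterior(log_prior, log_lik):
    """Normalized posterior probabilities for one discrete latent site."""
    log_joint = [p + l for p, l in zip(log_prior, log_lik)]
    m = max(log_joint)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in log_joint]
    total = sum(weights)
    return [w / total for w in weights]


def draw_discrete(probs, temperature=1, rng=random):
    """temperature=1 samples from the posterior; temperature=0 returns its mode."""
    if temperature == 0:
        return max(range(len(probs)), key=probs.__getitem__)
    if temperature == 1:
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]
    raise ValueError("temperature must be 0 or 1")


# Uniform prior over two states; the observation favors state 0.
probs = discrete_posterior(
    log_prior=[math.log(0.5), math.log(0.5)],
    log_lik=[math.log(0.9), math.log(0.1)],
)
map_value = draw_discrete(probs, temperature=0)  # mode of the posterior
sample = draw_discrete(probs, temperature=1, rng=random.Random(0))
```

With temperature=0 the result is deterministic, while temperature=1 gives a different draw per random state, which is what one usually wants when collecting posterior samples of the discrete sites.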

@fehiepsi fehiepsi requested a review from eb8680 July 8, 2021 12:12
@fehiepsi fehiepsi added this to the 0.7 milestone Jul 8, 2021

with substitute(data=data):
    return model(*args, **kwargs)
return data
Member Author

I revised this a bit so that the function can be used in Predictive (and to avoid tracing the model twice).

model_trace = prototype_trace
temperature = 1
pred_samples = _sample_posterior(
    config_enumerate(condition(model, samples)),
Member

@fritzo fritzo Jul 9, 2021

I'm surprised you're automatically configuring the model for enumeration here. In Pyro we let the user decide which variables are enumerated (though maybe this is too burdensome). That way a guide might sample some discrete latents, and the model might enumerate others. I would expect Predictive to use the guide's samples if provided, and only enumerate sites already marked for enumeration by the user.

I guess if in NumPyro you automatically wrap models with @config_enumerate inside MCMC, then it would also make sense to automatically wrap them here. Still, it would be nice to support SVI with guides that sample some or all discrete latent variables.

Member Author

@fehiepsi fehiepsi Jul 10, 2021

> let the user decide which variables are enumerated

We'll aim for this, except for algorithms that do not work with discrete latent sites (in that case, it makes sense to enable enumeration by default and raise errors for invalid models). For example, DiscreteHMCGibbs currently performs Gibbs updates for discrete latent sites that are not marked "enumerated", so the posterior samples include all latent variables except those marked "enumerated".

When infer_discrete=True, I assumed that the latent sites that are not available in posterior_samples or the guide are enumerated (the sites that are available are the ones in the samples variable in the code above). So it makes sense to me to config_enumerate them by default (config_enumerate skips observed sites, including the sites in samples). What do you think? (It does not seem to contradict the use case in your comment.)
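To make the argument concrete, here is a toy sketch of why wrapping the conditioned model leaves sampled sites untouched. It represents sites as plain dicts; toy_condition and toy_config_enumerate are hypothetical stand-ins written for this illustration, not the NumPyro handlers, and the site names are made up.

```python
def toy_condition(sites, data):
    """Toy stand-in: sites whose name appears in `data` become observed."""
    return [
        {**s, "is_observed": s["is_observed"] or s["name"] in data}
        for s in sites
    ]


def toy_config_enumerate(sites):
    """Toy stand-in: mark unobserved discrete sites for parallel enumeration."""
    out = []
    for s in sites:
        infer = dict(s.get("infer", {}))
        if s["is_discrete"] and not s["is_observed"]:
            # Keep a mark the user already set; otherwise use the default.
            infer.setdefault("enumerate", "parallel")
        out.append({**s, "infer": infer})
    return out


model_sites = [
    {"name": "probs", "is_discrete": False, "is_observed": False},
    {"name": "c", "is_discrete": True, "is_observed": False},  # sampled by a guide
    {"name": "z", "is_discrete": True, "is_observed": False},  # not in samples
    {"name": "obs", "is_discrete": False, "is_observed": True},
]
samples = {"probs": [0.3, 0.7], "c": 1}

configured = toy_config_enumerate(toy_condition(model_sites, samples))
marked = [s["name"] for s in configured if "enumerate" in s["infer"]]
```

Here only z ends up marked for enumeration: c is discrete but already present in samples, so conditioning turns it into an observed site and the enumeration step skips it.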

Member

@fritzo fritzo left a comment

LGTM as long as you're ok wrapping models with @config_enumerate.

@fehiepsi
Member Author

Thanks for reviewing, @fritzo! I thought about it for a while and I think we can merge this and revisit the API later if needed (especially when TraceEnumELBO is available). If config_enumerate becomes required in the future, we can check the model and raise a deprecation warning.
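One possible shape for such a check, sketched with the same dict-based toy sites rather than real model traces. The function warn_if_unmarked and the warning text are hypothetical, and FutureWarning is used here as an assumption because Python hides DeprecationWarning by default outside of `__main__`.

```python
import warnings


def warn_if_unmarked(sites):
    """Warn about discrete latent sites that lack an explicit enumerate mark."""
    unmarked = [
        s["name"]
        for s in sites
        if s["is_discrete"]
        and not s["is_observed"]
        and "enumerate" not in s.get("infer", {})
    ]
    if unmarked:
        warnings.warn(
            "Discrete latent sites {} are enumerated implicitly; "
            "mark them explicitly with config_enumerate.".format(unmarked),
            FutureWarning,
            stacklevel=2,
        )
    return unmarked


toy_sites = [
    {"name": "z", "is_discrete": True, "is_observed": False},
    {"name": "c", "is_discrete": True, "is_observed": False,
     "infer": {"enumerate": "parallel"}},
    {"name": "obs", "is_discrete": False, "is_observed": True},
]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    unmarked_names = warn_if_unmarked(toy_sites)
```

Returning the list of offending site names alongside the warning makes it easy to turn the same check into a hard error later.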

@fehiepsi fehiepsi merged commit 003424b into pyro-ppl:master Jul 10, 2021