Description
As far as I can tell, if a distribution supports reparameterization (e.g., Normal), it always samples with it.
However, that is not always desirable. For example, a REINFORCE-style (score-function) gradient estimator must not differentiate through the samples; it should differentiate only through log_prob.
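To make the distinction concrete, here is a minimal sketch of a REINFORCE surrogate in plain JAX. It assumes a hand-written Normal (the helpers `normal_sample`, `normal_log_prob`, and the objective `f` are hypothetical and are not distrax calls). The sample is drawn with the reparameterization `x = mu + sigma * eps`, and `jax.lax.stop_gradient` then blocks the pathwise gradient so that only log_prob carries gradient.

```python
import jax
import jax.numpy as jnp

def normal_sample(key, mu, sigma, n):
    # Reparameterized draw: x = mu + sigma * eps, so gradients would flow to mu/sigma.
    eps = jax.random.normal(key, (n,))
    return mu + sigma * eps

def normal_log_prob(x, mu, sigma):
    # Log-density of N(mu, sigma^2), written out by hand for this sketch.
    return -0.5 * ((x - mu) / sigma) ** 2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)

def f(x):
    # Hypothetical black-box objective; REINFORCE never needs its derivative.
    return x ** 2

def reinforce_surrogate(params, key, n=1000):
    mu, sigma = params
    x = normal_sample(key, mu, sigma, n)
    x = jax.lax.stop_gradient(x)  # block the pathwise (reparameterized) gradient
    # The gradient of mean(f(x) * log_prob(x)) is the score-function estimator
    # mean(f(x) * d/dtheta log_prob(x)); f(x) is constant because x is stopped.
    return jnp.mean(f(x) * normal_log_prob(x, mu, sigma))
```

Calling `jax.grad(reinforce_surrogate)((mu, sigma), key)` then returns the score-function estimate rather than the pathwise one.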
One option is to call sample first, apply stop_gradient to the sample, and then call log_prob. However, this is not ideal when there are numerical stability issues (for example, a Normal distribution with a Tanh bijector where the mean heavily saturates the tanh). Really, we'd like to call sample_and_log_prob for better stability and efficiency, except that there is currently no way to stop the gradient through the sample computation.
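A minimal sketch of the stability problem, assuming a hand-written tanh-squashed Normal in plain JAX (all function names are hypothetical, not distrax's). A separate log_prob call has to invert the bijector with `arctanh`, which returns `inf` once `tanh` has saturated to exactly 1.0 in float32, so the result becomes NaN. A joint sample-and-log-prob can instead reuse the pre-tanh sample `x` and evaluate the log-determinant stably via the identity log(1 - tanh(x)^2) = 2(log 2 - x - softplus(-2x)).

```python
import jax
import jax.numpy as jnp

def normal_log_prob(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)

def tanh_log_det_jacobian(x):
    # log|d tanh(x)/dx| = log(1 - tanh(x)^2), rewritten to stay finite for large |x|.
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))

def log_prob_via_inverse(y, mu, sigma):
    # What a standalone log_prob must do: recover x by inverting tanh.
    x = jnp.arctanh(y)  # inf when y has rounded to exactly 1.0
    return normal_log_prob(x, mu, sigma) - tanh_log_det_jacobian(x)

def sample_and_log_prob(key, mu, sigma):
    # Joint version: keeps the pre-tanh sample x and never inverts tanh.
    x = mu + sigma * jax.random.normal(key)
    y = jnp.tanh(x)
    return y, normal_log_prob(x, mu, sigma) - tanh_log_det_jacobian(x)
```

With `mu = 10.0` and `sigma = 0.1`, `log_prob_via_inverse(1.0, 10.0, 0.1)` evaluates to NaN (it computes `-inf - (-inf)`), while `sample_and_log_prob` returns a finite log-probability for the same saturated region.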
I can special-case this operation for Transformed distribution objects by sampling from the underlying distribution, applying stop_gradient to the sample, and then doing the remaining bijector and log-determinant calculations. However, even that is not ideal, because (1) it only handles the Transformed distribution case, and (2) it still requires calling plain sample and log_prob on the underlying distribution, which might itself benefit from sample_and_log_prob.
Ideally, then, these methods would accept a flag specifying whether reparameterization is allowed (it could default to True, for consistency with the default behavior of previous versions).
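A rough sketch of what the proposed flag could look like, on a toy tanh-squashed Normal written in plain JAX. The class `ToyTanhNormal` and the keyword `reparameterize` are hypothetical and do not exist in distrax. When the flag is False, the base sample is passed through `jax.lax.stop_gradient` before the bijector and log-determinant are applied, so gradients reach the parameters only through log_prob while the stable joint computation is preserved.

```python
import jax
import jax.numpy as jnp

class ToyTanhNormal:
    """Hypothetical tanh-squashed Normal; illustrates the flag, not the distrax API."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def _base_log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z ** 2 - jnp.log(self.scale) - 0.5 * jnp.log(2 * jnp.pi)

    def sample_and_log_prob(self, seed, reparameterize=True):
        # Base draw via the reparameterization x = loc + scale * eps.
        x = self.loc + self.scale * jax.random.normal(seed)
        if not reparameterize:
            # Hypothetical flag: cut the pathwise gradient (REINFORCE use case).
            x = jax.lax.stop_gradient(x)
        y = jnp.tanh(x)
        # Stable log|d tanh/dx| computed from the pre-tanh sample.
        log_det = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
        return y, self._base_log_prob(x) - log_det
```

Defaulting `reparameterize=True` keeps existing call sites unchanged; only code that opts out (e.g. `dist.sample_and_log_prob(key, reparameterize=False)`) sees zero gradient through `y`.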
Any thoughts on adding this flag or alternative approaches?