sample_and_log_prob _without_ reparameterization #323

@jmacglashan

As far as I can tell, if a distribution is capable of reparameterization (e.g., Normal), it always samples with it.

However, that is not always desirable. For example, if you want to use a REINFORCE-style gradient estimator, you must not differentiate through the samples; you differentiate only through the log_prob.
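To make the distinction concrete, here is a minimal plain-JAX sketch (it does not use the library's own classes). The helper `normal_log_prob`, the toy `reward`, and the two objective functions are hypothetical names introduced only for illustration. Both estimators target the gradient of E[x^2] with respect to the mean, which is 2 * mean for a Normal.

```python
import jax
import jax.numpy as jnp


def normal_log_prob(x, mean, scale):
    # Log-density of Normal(mean, scale) evaluated at x.
    z = (x - mean) / scale
    return -0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2.0 * jnp.pi)


def reward(x):
    # Toy objective (an assumption for this sketch): d/dmean E[x^2] = 2 * mean.
    return x**2


def pathwise_objective(mean, scale, key, n):
    # Reparameterized: the gradient flows through the sample itself.
    x = mean + scale * jax.random.normal(key, (n,))
    return jnp.mean(reward(x))


def score_function_objective(mean, scale, key, n):
    # REINFORCE-style surrogate: the sample is a constant, so the gradient
    # flows only through log_prob.
    x = jax.lax.stop_gradient(mean + scale * jax.random.normal(key, (n,)))
    log_p = normal_log_prob(x, mean, scale)
    return jnp.mean(jax.lax.stop_gradient(reward(x)) * log_p)


key = jax.random.PRNGKey(0)
mean, scale, n = 1.5, 1.0, 200_000
grad_pathwise = jax.grad(pathwise_objective)(mean, scale, key, n)
grad_score = jax.grad(score_function_objective)(mean, scale, key, n)
```

Both gradients should land near 2 * mean, with the score-function estimate being noticeably noisier.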

One option is to first call sample, then apply stop_gradient to the sample, and then call log_prob. However, this is not ideal when there are numerical stability issues (for example, a normal distribution with a tanh bijector where the mean heavily saturates the tanh). Really, we'd like to call sample_and_log_prob for better stability and efficiency, but then there is currently no way to stop the gradient through the sample computation.
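The stability problem can be reproduced in a plain-JAX sketch, assuming float32 and a hand-written tanh-of-normal rather than the library's own bijector (any clipping a real bijector applies on the inverse is not modeled here). All function names below are hypothetical. The separate path has to invert the tanh, which fails once the sample rounds to exactly 1; the joint path keeps the pre-tanh value around.

```python
import jax
import jax.numpy as jnp


def normal_log_prob(x, mean, scale):
    # Log-density of Normal(mean, scale) evaluated at x.
    z = (x - mean) / scale
    return -0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2.0 * jnp.pi)


def tanh_log_det(x):
    # log|d tanh(x)/dx| = log(1 - tanh(x)^2), written in a stable form.
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


def separate_sample_then_log_prob(mean, scale, key):
    # sample -> stop_gradient -> log_prob: log_prob only sees y, so it must
    # recover x with arctanh, which is infinite when y has rounded to 1.
    x = mean + scale * jax.random.normal(key, ())
    y = jax.lax.stop_gradient(jnp.tanh(x))
    x_recovered = jnp.arctanh(y)
    return y, normal_log_prob(x_recovered, mean, scale) - tanh_log_det(x_recovered)


def joint_sample_and_log_prob(mean, scale, key):
    # Joint computation: reuse the pre-tanh sample, no inverse needed.
    # Note that the sample here is still reparameterized.
    x = mean + scale * jax.random.normal(key, ())
    return jnp.tanh(x), normal_log_prob(x, mean, scale) - tanh_log_det(x)


key = jax.random.PRNGKey(0)
_, lp_separate = separate_sample_then_log_prob(20.0, 1.0, key)
_, lp_joint = joint_sample_and_log_prob(20.0, 1.0, key)
```

With the mean at 20 the tanh is fully saturated, so `lp_separate` is not finite while `lp_joint` is.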

I can special-case this operation for Transformed distribution objects by sampling from the underlying distribution, stopping the gradient on the sample, and then doing the remaining bijector and log-determinant calculations. However, even that is not ideal, because (1) it only handles the Transformed case, and (2) it still requires a plain sample and log_prob on the underlying distribution, which might also benefit from sample_and_log_prob.
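A rough sketch of that special case, again in plain JAX with a hand-written normal base and tanh transform standing in for the real objects (all names are hypothetical): the base sample is detached first, and the forward transform and log-determinant are then computed from it.

```python
import jax
import jax.numpy as jnp


def normal_log_prob(x, mean, scale):
    # Log-density of Normal(mean, scale) evaluated at x.
    z = (x - mean) / scale
    return -0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2.0 * jnp.pi)


def tanh_log_det(x):
    # log|d tanh(x)/dx| in a numerically stable form.
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


def sample_and_log_prob_no_reparam(mean, scale, key):
    # 1. Plain sample from the base distribution, then detach it.
    x = jax.lax.stop_gradient(mean + scale * jax.random.normal(key, ()))
    # 2. Plain log_prob on the base; gradients reach mean/scale only here.
    base_log_prob = normal_log_prob(x, mean, scale)
    # 3. Remaining bijector forward pass and log-determinant correction.
    return jnp.tanh(x), base_log_prob - tanh_log_det(x)


key = jax.random.PRNGKey(0)
mean, scale = 20.0, 1.0
eps = jax.random.normal(key, ())  # same noise the function draws from `key`
d_sample = jax.grad(lambda m: sample_and_log_prob_no_reparam(m, scale, key)[0])(mean)
d_log_prob = jax.grad(lambda m: sample_and_log_prob_no_reparam(m, scale, key)[1])(mean)
```

Here `d_sample` is zero and `d_log_prob` is the usual score term (x - mean) / scale^2, which stays finite even with the saturated mean. Steps 1 and 2 are exactly the plain sample and log_prob on the base distribution that point (2) above objects to.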

Ideally, then, these methods would accept a flag specifying whether reparameterization is allowed (it can default to True for consistency with the existing default behavior).
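As a strawman for what such a flag could look like, here is a toy standalone function for a normal distribution. The keyword `reparameterize` and the function itself are hypothetical; nothing like this exists in the library as far as I know.

```python
import jax
import jax.numpy as jnp


def normal_log_prob(x, mean, scale):
    # Log-density of Normal(mean, scale) evaluated at x.
    z = (x - mean) / scale
    return -0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2.0 * jnp.pi)


def sample_and_log_prob(mean, scale, key, reparameterize=True):
    # Hypothetical flag: True keeps today's behavior, False detaches the
    # sample so gradients flow only through log_prob.
    x = mean + scale * jax.random.normal(key, ())
    if not reparameterize:
        x = jax.lax.stop_gradient(x)
    return x, normal_log_prob(x, mean, scale)


key = jax.random.PRNGKey(0)
d_sample_default = jax.grad(lambda m: sample_and_log_prob(m, 1.0, key)[0])(0.5)
d_sample_detached = jax.grad(
    lambda m: sample_and_log_prob(m, 1.0, key, reparameterize=False)[0])(0.5)
```

With the default, the sample's gradient with respect to the mean is 1; with `reparameterize=False` it is 0, which is what a REINFORCE-style estimator needs.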

Any thoughts on adding this flag or alternative approaches?
