Computing pointwise log-likelihoods without factorizing the likelihood #1038

@sethaxen

Currently pointwise_loglikelihoods determines the "points" based on a ~ expression. So if we have a fully factorized log-likelihood expression like

for i in eachindex(y, μ)
    y[i] ~ Normal(μ[i], σ)
end

then we end up with one point for each value of i. However, if we write the model in a "non-factorized" (but still technically factorizable) form as

y ~ product_distribution(Normal.(μ, σ))

then each pointwise log-likelihood is the joint log-likelihood of all entries of y.
I don't think there's any context in which this is the behavior we want (for this, we have likelihood), but I think this is the correct behavior in general (I'll say why later). However, for most users this requires that they maintain two different implementations of the model: one for sampling (using product_distribution or arraydist) and one for pointwise log-likelihoods (using the explicit loop). The primary use case I know of for pointwise log-likelihoods is model comparison, e.g. with leave-one-out (LOO) or leave-one-group-out (LOGO) cross-validation.
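To make the difference concrete, here is a minimal check using only Distributions.jl. The values of μ, σ, and y are made up for illustration; the point is that both formulations give the same total, but only the factorized one exposes a term per observation.

```julia
using Distributions

# Illustrative values, not taken from any real model.
μ = [0.0, 1.0, 2.0]
σ = 0.5
y = [0.1, 0.9, 2.3]

# Factorized form: one log-likelihood per observation.
pointwise = logpdf.(Normal.(μ, σ), y)

# Non-factorized form: a single joint log-likelihood for all of y.
joint = logpdf(product_distribution(Normal.(μ, σ)), y)

# The totals agree, but only the first form keeps the individual terms.
@assert length(pointwise) == length(y)
@assert sum(pointwise) ≈ joint
```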

I propose a method like pointwise_conditional_loglikelihoods(dist, y) for array-variate distribution dist and array y that returns an array log_like with the same shape as y, where log_like[i] contains $\log p(y_i \mid y_{-i}, \theta)$, $y_{-i}$ is all elements of y except y[i], and $\theta$ are the parameters of dist. This is precisely what we need for LOO, and for factorizable distributions (e.g. DiagNormal and ProductDistribution) it is the same as the pointwise log-likelihoods. However, it also generalizes to non-factorizable distributions (the only examples in Distributions that would ever need to be supported are MvNormal, MvNormalCanon, MvLogNormal, GenericMvTDist, and MatrixNormal). This function can be efficiently implemented for each of these distributions (PosteriorStats.pointwise_loglikelihoods already has implementations for these).
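As a rough sketch of what such a method could look like for MvNormal (this is a hypothetical implementation of the proposed function, not an existing API in DynamicPPL or Distributions, and it makes no attempt to reuse a cached factorization): with precision matrix Λ and g = Λ(y − μ), the conditional of y[i] given the rest is normal with mean y[i] − g[i]/Λ[i,i] and variance 1/Λ[i,i]. The test values are made up.

```julia
using Distributions, LinearAlgebra

# Hypothetical method for the proposed function; not an existing API.
function pointwise_conditional_loglikelihoods(dist::MvNormal, y::AbstractVector)
    Λ = inv(Symmetric(Matrix(cov(dist))))  # precision matrix
    g = Λ * (y - mean(dist))
    λ = diag(Λ)
    # y[i] | y[-i] is Normal with mean y[i] - g[i] / λ[i] and variance 1 / λ[i]
    return @. (log(λ) - log(2π) - g^2 / λ) / 2
end

# Illustrative values.
μ = [0.0, 1.0, 2.0]
σ = 0.5
y = [0.1, 0.9, 2.3]

# With a diagonal covariance this reduces to the ordinary pointwise log-likelihoods.
diag_dist = MvNormal(μ, Diagonal(fill(σ^2, 3)))
@assert pointwise_conditional_loglikelihoods(diag_dist, y) ≈ logpdf.(Normal.(μ, σ), y)

# With correlation, compare against the definition of the conditional:
# log p(y[1] | y[2:3]) = log p(y) - log p(y[2:3]).
Σ = [1.0 0.5 0.2; 0.5 1.0 0.3; 0.2 0.3 1.0]
dist = MvNormal(μ, Σ)
marginal = MvNormal(μ[2:3], Σ[2:3, 2:3])
@assert pointwise_conditional_loglikelihoods(dist, y)[1] ≈
    logpdf(dist, y) - logpdf(marginal, y[2:3])
```

A real implementation would presumably work from the distribution's stored Cholesky factor rather than forming the dense inverse.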

I propose further that there be a mode for pointwise log-likelihood accumulation that, given an observe statement like y ~ dist, computes and accumulates pointwise_conditional_loglikelihoods(dist, y) instead of loglikelihood(dist, y). Perhaps an API like

pointwise_loglikelihoods(dist, chain; factorize=true)

Then a user would never need to rewrite their model to get pointwise log-likelihoods.

An alternative, more flexible syntax would be something like

pointwise_loglikelihoods(dist, chain; groups)

where groups could be something like (; y=1:10) to support LOO or something like (; y=[1:3, 4:5, 6:10]) to support leave-one-group-out (LOGO) CV. But IMO this syntax is too complicated, requires we do things like checking that every index of y appears in exactly one group, and makes it tricky for us to identify if we can compute pointwise log-likelihoods for all i simultaneously. The loop syntax, though it may require a model rewrite, is probably more powerful for these cases.
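For a sense of the bookkeeping involved, the "exactly one group" check for a single vector-valued variable might look roughly like this. The helper name is_partition is hypothetical, and the sketch assumes y uses plain linear indices.

```julia
# Hypothetical helper: do `groups` cover each index of `y` exactly once?
function is_partition(groups, y)
    flat = sort!(collect(Iterators.flatten(groups)))
    return flat == collect(eachindex(y))
end

y = zeros(10)
@assert is_partition(1:10, y)                # LOO: every index is its own group
@assert is_partition([1:3, 4:5, 6:10], y)    # LOGO: three contiguous groups
@assert !is_partition([1:3, 3:5, 6:10], y)   # index 3 appears twice
@assert !is_partition([1:3, 6:10], y)        # indices 4 and 5 are missing
```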

I don't think the proposal should be the default though, precisely because for non-factorizable distributions like MvNormal, the sum of the pointwise log-likelihoods will not be the same as the log-likelihood of the joint distribution, and this point is subtle enough that it seems liable to be misunderstood, resulting in user error.
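A small worked case shows the discrepancy. Take a zero-mean bivariate normal with unit variances and correlation $\rho$, evaluated at $y = (0, 0)$ (values chosen purely for illustration). Each conditional $y_i \mid y_{-i}$ is normal with mean $\rho\, y_{-i} = 0$ and variance $1 - \rho^2$, so

$$
\begin{aligned}
\log p(y_1, y_2 \mid \theta) &= -\log(2\pi) - \tfrac{1}{2}\log(1 - \rho^2), \\
\log p(y_i \mid y_{-i}, \theta) &= -\tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}\log(1 - \rho^2), \quad i = 1, 2, \\
\sum_{i=1}^{2} \log p(y_i \mid y_{-i}, \theta) - \log p(y_1, y_2 \mid \theta) &= -\tfrac{1}{2}\log(1 - \rho^2).
\end{aligned}
$$

The difference is zero only when $\rho = 0$, i.e. when the distribution factorizes.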

Labels: enhancement