Currently, `pointwise_loglikelihoods` determines the "points" based on `~` expressions. So if we have a fully factorized log-likelihood expression like
```julia
for i in eachindex(y, μ)
    y[i] ~ Normal(μ[i], σ)
end
```
then we end up with one point for each value of `i`. However, if we write the model in a "non-factorized" (but still technically factorizable) form as
```julia
y ~ product_distribution(Normal.(μ, σ))
```
then there is only a single point, whose pointwise log-likelihood is the joint log-likelihood of all entries of `y`.
I don't think there's any context in which this is the behavior we want (for that, we already have `likelihood`), but I think it is the correct behavior in general (I'll say why later). However, for most users this means maintaining two different implementations of the model: one for sampling (using `product_distribution` or `arraydist`) and one for pointwise log-likelihoods (using the explicit loop). The primary use case I know of for pointwise log-likelihoods is model comparison, e.g. with leave-one-out (LOO) or LOGO cross-validation.
I propose a method like `pointwise_conditional_loglikelihoods(dist, y)` for an array-variate distribution `dist` and an array `y` that returns an array `log_like` with the same shape as `y`, where `log_like[i]` contains the log-likelihood of `y[i]` conditioned on all entries of `y` except `y[i]`, and on `dist`. This is precisely what we need for LOO, and for factorizable distributions (e.g. `DiagNormal` and `ProductDistribution`) it is the same as the pointwise log-likelihoods. However, it also generalizes to non-factorizable distributions (the only examples in Distributions that would ever need to be supported are `MvNormal`, `MvNormalCanon`, `MvLogNormal`, `GenericMvTDist`, and `MatrixNormal`). This function can be implemented efficiently for each of these distributions (`PosteriorStats.pointwise_loglikelihoods` already has implementations for them).
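For the Gaussian case, these conditional log-likelihoods have a well-known closed form. If $y \sim \mathcal{N}(\mu, \Sigma)$ with precision matrix $\Lambda = \Sigma^{-1}$ and $g = \Lambda(y - \mu)$, then $y_i \mid y_{-i}$ is normal with variance $1/\Lambda_{ii}$ and $y_i - \mathbb{E}[y_i \mid y_{-i}] = g_i / \Lambda_{ii}$, so

$$
\log p(y_i \mid y_{-i}) = \frac{1}{2}\log \Lambda_{ii} - \frac{1}{2}\log(2\pi) - \frac{g_i^2}{2\Lambda_{ii}}.
$$

A minimal Julia sketch of this identity follows, assuming a plain mean vector and covariance matrix in place of an `MvNormal` object. It uses only the `LinearAlgebra` standard library; `mvnormal_loo_loglikelihoods` and `mvnormal_loo_bruteforce` are hypothetical names, not functions from Distributions, DynamicPPL, or PosteriorStats, and this is not necessarily how `PosteriorStats.pointwise_loglikelihoods` computes it. The brute-force version conditions on $y_{-i}$ explicitly, as a cross-check.

```julia
using LinearAlgebra

# Hypothetical helper: log p(y[i] | y[-i]) for every i, where y ~ N(μ, Σ).
# One Cholesky factorization serves all i.
function mvnormal_loo_loglikelihoods(μ::AbstractVector, Σ::AbstractMatrix, y::AbstractVector)
    Λ = inv(cholesky(Symmetric(Σ)))  # precision matrix Σ⁻¹
    g = Λ * (y - μ)                  # g = Λ (y - μ)
    λ = diag(Λ)                      # λ[i] = 1 / Var(y[i] | y[-i])
    # y[i] - E[y[i] | y[-i]] = g[i] / λ[i], giving the conditional log-density:
    return @. (log(λ) - log(2π) - g^2 / λ) / 2
end

# Hypothetical brute-force reference: condition on y[-i] one index at a time.
function mvnormal_loo_bruteforce(μ::AbstractVector, Σ::AbstractMatrix, y::AbstractVector)
    n = length(y)
    return map(1:n) do i
        rest = setdiff(1:n, i)
        β = Σ[i, rest]' / Σ[rest, rest]     # regression coefficients of y[i] on y[-i]
        m = μ[i] + β * (y[rest] - μ[rest])  # conditional mean
        v = Σ[i, i] - β * Σ[rest, i]        # conditional variance (Schur complement)
        (-log(2π * v) - (y[i] - m)^2 / v) / 2
    end
end

μ = [0.0, 1.0, -1.0]
Σ = [2.0 0.5 0.3; 0.5 1.0 0.2; 0.3 0.2 1.5]
y = [0.4, 0.7, -2.0]
ll = mvnormal_loo_loglikelihoods(μ, Σ, y)
```

With a diagonal `Σ`, `Λ` is diagonal and `g[i] / Λ[i, i] == y[i] - μ[i]`, so each entry reduces to the ordinary univariate normal log-density; this is the sense in which conditional and ordinary pointwise log-likelihoods agree for factorizable distributions.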
I further propose a mode for pointwise log-likelihood accumulation that, given an observe statement like `y ~ dist`, computes and accumulates `pointwise_conditional_loglikelihoods(dist, y)` instead of `loglikelihood(dist, y)`. Perhaps an API like

```julia
pointwise_loglikelihoods(model, chain; factorize=true)
```

Then a user would never need to rewrite their model to get pointwise log-likelihoods.
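To make the proposed switch concrete, here is a toy, Gaussian-only Julia sketch of what each mode would record for a single observe statement `y ~ MvNormal(μ, Σ)`. Everything here is hypothetical: `observe_pointwise` is not a DynamicPPL function, and the plain `(μ, Σ)` arguments stand in for a `Distribution` object. With `factorize=false` the statement contributes one entry (the joint log-likelihood); with `factorize=true` it contributes one conditional log-likelihood per element of `y`.

```julia
using LinearAlgebra

# Joint log-density of y under N(μ, Σ).
function mvnormal_logpdf(μ, Σ, y)
    C = cholesky(Symmetric(Σ))
    r = C.L \ (y - μ)  # whitened residual, so dot(r, r) = (y - μ)' Σ⁻¹ (y - μ)
    return -(length(y) * log(2π) + logdet(C) + dot(r, r)) / 2
end

# Leave-one-out conditional log-densities log p(y[i] | y[-i]) under N(μ, Σ).
function mvnormal_loo_loglikelihoods(μ, Σ, y)
    Λ = inv(cholesky(Symmetric(Σ)))  # precision matrix
    g = Λ * (y - μ)
    λ = diag(Λ)
    return @. (log(λ) - log(2π) - g^2 / λ) / 2
end

# Hypothetical: the entries one observe statement `y ~ MvNormal(μ, Σ)` would
# contribute to the pointwise log-likelihoods under each mode.
function observe_pointwise(μ, Σ, y; factorize::Bool=false)
    if factorize
        return mvnormal_loo_loglikelihoods(μ, Σ, y)  # one entry per element of y
    else
        return [mvnormal_logpdf(μ, Σ, y)]            # one entry for the whole statement
    end
end

μ = [0.0, 1.0]
Σ = [1.0 0.8; 0.8 1.0]
y = [0.5, 1.5]
observe_pointwise(μ, Σ, y)                  # 1 entry: the joint log-likelihood
observe_pointwise(μ, Σ, y; factorize=true)  # 2 entries: conditional log-likelihoods
```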
A more flexible alternative syntax would be something like

```julia
pointwise_loglikelihoods(model, chain; groups)
```

where `groups` could be something like `(; y=1:10)` to support LOO, or something like `(; y=[1:3, 4:5, 6:10])` to support leave-one-group-out (LOGO) CV. But IMO this syntax is too complicated: it requires us to check that every index of `y` appears in exactly one group, and it makes it tricky for us to determine whether we can compute pointwise log-likelihoods for all `i` simultaneously. The loop syntax, though it may require a model rewrite, is probably more powerful for these cases.
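The group check mentioned above is simple to state; the harder part is the per-distribution math. The Julia sketch below shows both for the Gaussian case, again with a plain mean vector and covariance matrix standing in for `MvNormal`. It relies on the block version of the LOO identity: with $\Lambda = \Sigma^{-1}$ and $r = \Lambda(y - \mu)$, $y_g \mid y_{-g}$ has precision $\Lambda_{gg}$ and $y_g - \mathbb{E}[y_g \mid y_{-g}] = \Lambda_{gg}^{-1} r_g$. `check_groups` and `mvnormal_logo_loglikelihoods` are hypothetical names, not part of DynamicPPL or PosteriorStats, and the sketch does not address detecting, for an arbitrary model, when all groups can be computed at once.

```julia
using LinearAlgebra

# Hypothetical validation: every index in 1:n appears in exactly one group.
function check_groups(groups, n::Integer)
    counts = zeros(Int, n)
    for g in groups, i in g
        1 <= i <= n || throw(ArgumentError("index $i is out of bounds for length $n"))
        counts[i] += 1
    end
    all(==(1), counts) || throw(ArgumentError("groups must partition 1:$n"))
    return nothing
end

# Hypothetical LOGO analogue for y ~ N(μ, Σ): log p(y[g] | y[-g]) for each group g.
function mvnormal_logo_loglikelihoods(μ, Σ, y, groups)
    check_groups(groups, length(y))
    Λ = inv(cholesky(Symmetric(Σ)))  # precision matrix
    r = Λ * (y - μ)                  # r = Λ (y - μ)
    return map(groups) do g
        C = cholesky(Symmetric(Λ[g, g]))  # conditional precision of y[g]
        z = C.L \ r[g]                    # dot(z, z) = r[g]' (Λ[g, g] \ r[g])
        (logdet(C) - length(g) * log(2π) - dot(z, z)) / 2
    end
end

μ = zeros(4)
Σ = [2.0 0.3 0.2 0.1;
     0.3 1.5 0.4 0.2;
     0.2 0.4 1.8 0.5;
     0.1 0.2 0.5 1.2]
y = [0.5, -1.0, 0.3, 2.0]
mvnormal_logo_loglikelihoods(μ, Σ, y, [1:2, 3:4])  # one value per group
```

Singleton groups such as `[[1], [2], [3], [4]]` recover the LOO values, and the single group `[1:4]` recovers the joint log-likelihood, so LOO and the current joint behavior are both special cases of the grouped computation.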
I don't think the proposed mode should be the default, though, precisely because for non-factorizable distributions like `MvNormal`, the sum of the pointwise conditional log-likelihoods is not, in general, equal to the log-likelihood of the joint distribution. This point is subtle enough that it seems liable to be misunderstood, resulting in user error.
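The two-variable case shows exactly how the sums differ. By the chain rule, $\log p(y_1, y_2) = \log p(y_1) + \log p(y_2 \mid y_1)$, so summing the two conditional log-likelihoods gives

$$
\sum_{i=1}^{2} \log p(y_i \mid y_{-i}) - \log p(y_1, y_2)
= \log p(y_1 \mid y_2) - \log p(y_1)
= \log \frac{p(y_1, y_2)}{p(y_1)\, p(y_2)}.
$$

The right-hand side is the pointwise mutual information between $y_1$ and $y_2$. It is zero for all values only when $y_1$ and $y_2$ are independent, i.e. when the distribution factorizes; for a bivariate normal with nonzero correlation it is generally nonzero.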