-
Notifications
You must be signed in to change notification settings - Fork 30
Open
Description
AbstractGPs.jl has the following default implementation (https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/blob/04e4c53a3a66c62c83cdb4f47aabcfb04543acf4/src/abstract_gp/finite_gp.jl#L298-L309):
function Distributions._logpdf(f::FiniteGP, y::AbstractVector{<:Real})
return first(logpdf(f, reshape(y, :, 1)))
end
...
function Distributions.logpdf(f::FiniteGP, Y::AbstractMatrix{<:Real})
m, C_mat = mean_and_cov(f)
C = cholesky(_symmetric(C_mat))
T = promote_type(eltype(m), eltype(C), eltype(Y))
return -((size(Y, 1) * T(log(2π)) + logdet(C)) .+ diag_Xt_invA_X(C, Y .- m)) ./ 2
end

This means that
DistributionsAD.jl/src/zygote.jl
Lines 70 to 79 in fe20700
ZygoteRules.@adjoint function Distributions.logpdf(
    dist::MultivariateDistribution,
    X::AbstractMatrix{<:Real},
)
    size(X, 1) == length(dist) ||
        throw(DimensionMismatch("Inconsistent array dimensions."))
    return ZygoteRules.pullback(dist, X) do dist, X
        return map(i -> Distributions._logpdf(dist, view(X, :, i)), axes(X, 2))
    end
end
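will recurse indefinitely 😕: the adjoint calls `Distributions._logpdf` on each column of `X`, and for a `FiniteGP` that method reshapes the vector into a one-column matrix and calls the matrix `logpdf` again, which re-enters the same adjoint.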
Should we maybe remove the adjoint above? TBH it seems a bit too general.
Metadata
Metadata
Assignees
Labels
No labels