Skip to content
Draft
1 change: 1 addition & 0 deletions docs/src/mixture.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ var(::UnivariateMixture)
length(::MultivariateMixture)
pdf(::AbstractMixtureModel, ::Any)
logpdf(::AbstractMixtureModel, ::Any)
gradlogpdf(::AbstractMixtureModel, ::Any)
rand(::AbstractMixtureModel)
rand!(::AbstractMixtureModel, ::AbstractArray)
```
Expand Down
1 change: 1 addition & 0 deletions docs/src/truncate.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ are defined for all truncated univariate distributions:
- [`insupport(::UnivariateDistribution, x::Any)`](@ref)
- [`pdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logpdf(::UnivariateDistribution, ::Real)`](@ref)
- [`gradlogpdf(::UnivariateDistribution, ::Real)`](@ref)
- [`cdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logcdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logdiffcdf(::UnivariateDistribution, ::T, ::T) where {T <: Real}`](@ref)
Expand Down
1 change: 1 addition & 0 deletions docs/src/univariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pdfsquaredL2norm
insupport(::UnivariateDistribution, x::Any)
pdf(::UnivariateDistribution, ::Real)
logpdf(::UnivariateDistribution, ::Real)
gradlogpdf(::UnivariateDistribution, ::Real)
loglikelihood(::UnivariateDistribution, ::AbstractArray)
cdf(::UnivariateDistribution, ::Real)
logcdf(::UnivariateDistribution, ::Real)
Expand Down
70 changes: 70 additions & 0 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ Here, `x` can be a single sample or an array of multiple samples.
"""
logpdf(d::AbstractMixtureModel, x::Any)

"""
gradlogpdf(d::Union{UnivariateMixture, MultivariateMixture}, x)

Evaluate the gradient of the logarithm of the (mixed) probability density function over a single sample `x`.
"""
gradlogpdf(d::AbstractMixtureModel, x::Any)

"""
rand(d::Union{UnivariateMixture, MultivariateMixture})

Expand Down Expand Up @@ -362,6 +369,38 @@ end
pdf(d::UnivariateMixture, x::Real) = _mixpdf1(d, x)
logpdf(d::UnivariateMixture, x::Real) = _mixlogpdf1(d, x)

function gradlogpdf(d::UnivariateMixture, x::Real)
ps = probs(d)
cs = components(d)

# `d` is expected to have at least one distribution, otherwise this will just error
psi, idxps = iterate(ps)
csi, idxcs = iterate(cs)
pdfx1 = pdf(csi, x)
pdfx = psi * pdfx1
glp = pdfx * gradlogpdf(csi, x)
if iszero(psi) || iszero(pdfx)
glp = zero(glp)
end

while (iterps = iterate(ps, idxps)) !== nothing && (itercs = iterate(cs, idxcs)) !== nothing
psi, idxps = iterps
csi, idxcs = itercs
if !iszero(psi)
pdfxi = pdf(csi, x)
if !iszero(pdfxi)
pipdfxi = psi * pdfxi
pdfx += pipdfxi
glp += pipdfxi * gradlogpdf(csi, x)
end
end
end
if !iszero(pdfx) # else glp is already zero
glp /= pdfx
end
Comment on lines +398 to +400
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it correct to return a gradlogpdf of zero if x is not in the support of the mixture distribution?

Copy link
Author

@rmsrosa rmsrosa Jan 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wondered about that, but decided to follow the current behavior already implemented. For example:

julia> insupport(Beta(0.5, 0.5), -1)
false

julia> logpdf(Beta(0.5, 0.5), -1)
-Inf

julia> gradlogpdf(Beta(0.5, 0.5), -1)
0.0

I don't know. If it is constant -Inf, then the derivative is zero (except that (-Inf) - (-Inf) is not defined, but what matters is that the rate of change is zero...)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should make this a separate issue and hopefully standardize the behavior.

return glp
end

_pdf!(r::AbstractArray{<:Real}, d::UnivariateMixture{Discrete}, x::UnitRange) = _mixpdf!(r, d, x)
_pdf!(r::AbstractArray{<:Real}, d::UnivariateMixture, x::AbstractArray{<:Real}) = _mixpdf!(r, d, x)
_logpdf!(r::AbstractArray{<:Real}, d::UnivariateMixture, x::AbstractArray{<:Real}) = _mixlogpdf!(r, d, x)
Expand All @@ -371,6 +410,37 @@ _logpdf(d::MultivariateMixture, x::AbstractVector{<:Real}) = _mixlogpdf1(d, x)
_pdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real}) = _mixpdf!(r, d, x)
_logpdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real}) = _mixlogpdf!(r, d, x)

function gradlogpdf(d::MultivariateMixture, x::AbstractVector{<:Real})
ps = probs(d)
cs = components(d)

# `d` is expected to have at least one distribution, otherwise this will just error
psi, idxps = iterate(ps)
csi, idxcs = iterate(cs)
pdfx1 = pdf(csi, x)
pdfx = psi * pdfx1
glp = pdfx * gradlogpdf(csi, x)
if iszero(psi) || iszero(pdfx)
fill!(glp, zero(eltype(glp)))
end

while (iterps = iterate(ps, idxps)) !== nothing && (itercs = iterate(cs, idxcs)) !== nothing
psi, idxps = iterps
csi, idxcs = itercs
if !iszero(psi)
pdfxi = pdf(csi, x)
if !iszero(pdfxi)
pipdfxi = psi * pdfxi
pdfx += pipdfxi
glp .+= pipdfxi .* gradlogpdf(csi, x)
end
end
end
if !iszero(pdfx) # else glp is already zero
glp ./= pdfx
end
return glp
end

## component-wise pdf and logpdf

Expand Down
58 changes: 58 additions & 0 deletions test/gradlogpdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,61 @@ using Test
[0.191919191919192, 1.080808080808081] ,atol=1.0e-8)
@test isapprox(gradlogpdf(MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.]), [0.7, 0.9]),
[0.2150711513583442, 1.2111901681759383] ,atol=1.0e-8)

# Test for gradlogpdf on univariate mixture distributions

x = [-0.2, 0.3, 0.8, 1.0, 1.3, 10.5]
delta = 0.0001

for di in (
Normal(-4.5, 2.0),
Exponential(2.0),
Uniform(0.0, 1.0),
Beta(2.0, 3.0),
Beta(0.5, 0.5)
)
d = MixtureModel([di], [1.0])
glp1 = gradlogpdf.(d, x)
glp2 = gradlogpdf.(di, x)
@info "Testing `gradlogpdf` on $d"
@test isapprox(glp1, glp2, atol = 0.01)
end

for d in (
MixtureModel([Normal(1//1, 2//1), Beta(2//1, 3//1), Exponential(3//2)], [3//10, 4//10, 3//10]),
MixtureModel([Normal(-2.0, 3.5), Normal(-4.5, 2.0)], [0.0, 1.0]),
MixtureModel([Beta(1.5, 3.0), Chi(5.0), Chisq(7.0)], [0.4, 0.3, 0.3]),
MixtureModel([Exponential(2.0), Gamma(9.0, 0.5), Gumbel(3.5, 1.0), Laplace(7.0)], [0.3, 0.2, 0.4, 0.1]),
MixtureModel([Logistic(-6.0), LogNormal(5.5), TDist(8.0), Weibull(2.0)], [0.3, 0.2, 0.4, 0.1])
)

# finite differences don't handle when not in the interior of the support
xs = filter(s -> all(insupport.(d, [s - delta, s, s + delta])), x)

glp1 = gradlogpdf.(d, xs)
glp2 = ( logpdf.(d, xs .+ delta) - logpdf.(d, xs .- delta) ) ./ 2delta
@info "Testing `gradlogpdf` on $d"
@test isapprox(glp1, glp2, atol = 0.01)
end

# Test for gradlogpdf on multivariate mixture distributions against centered finite-difference on logpdf

x = [[0.2, 0.3], [0.8, 1.3], [-1.0, 10.5]]
delta = 0.001

for d in (
MixtureModel([MvNormal([1., 2.], [1. 0.1; 0.1 1.])], [1.0]),
MixtureModel([MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvNormal([2.0, 1.0], [0.3 0.1; 0.1 0.4])], [0.4, 0.6]),
MixtureModel([MvNormal([3.0, 2.0], [0.2 0.3; 0.3 0.5]), MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvNormal([2.0, 1.0], [0.3 0.1; 0.1 0.4])], [0.0, 1.0, 0.0]),
MixtureModel([MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.])], [1.0]),
MixtureModel([MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.])], [0.4, 0.6])
)
xs = filter(s -> insupport(d, s), x)
for xi in xs
glp = gradlogpdf(d, xi)
glpx = ( logpdf(d, xi .+ [delta, 0]) - logpdf(d, xi .- [delta, 0]) ) ./ 2delta
glpy = ( logpdf(d, xi .+ [0, delta]) - logpdf(d, xi .- [0, delta]) ) ./ 2delta
@test isapprox(glp[1], glpx, atol=delta)
@test isapprox(glp[2], glpy, atol=delta)
end
end