Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/distributions/mv_normal_mean_covariance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanCovariance, x::AbstractVect
for i in 1:length(r)
@inbounds r[i] = μ[i] - x[i]
end
return dot(r, invcov(dist), r) # x' * A * x
return xT_A_y(r, invcov(dist), r) # x' * A * x
end

Base.eltype(::MvNormalMeanCovariance{T}) where {T} = T
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mv_normal_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanPrecision, x::AbstractVecto
for i in 1:length(r)
@inbounds r[i] = μ[i] - x[i]
end
return dot(r, invcov(dist), r)
return xT_A_y(r, invcov(dist), r) # x' * A * x
end

Base.eltype(::MvNormalMeanPrecision{T}) where {T} = T
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mv_normal_weighted_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function Distributions.sqmahal!(r, dist::MvNormalWeightedMeanPrecision, x::Abstr
for i in 1:length(r)
@inbounds r[i] = μ[i] - x[i]
end
return dot(r, invcov(dist), r)
return xT_A_y(r, invcov(dist), r) # x' * A * x
end

Base.eltype(::MvNormalWeightedMeanPrecision{T}) where {T} = T
Expand Down
7 changes: 6 additions & 1 deletion src/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ promote_variate_type(::Type{Multivariate}, ::Type{<:NormalMeanVariance})
promote_variate_type(::Type{Multivariate}, ::Type{<:NormalMeanPrecision}) = MvNormalMeanPrecision
promote_variate_type(::Type{Multivariate}, ::Type{<:NormalWeightedMeanPrecision}) = MvNormalWeightedMeanPrecision

# Conversion to Gaussian distributions from `Distributions.jl`

# Univariate case: the result of `mean_std(dist)` is splatted into the `Normal` constructor.
function Base.convert(::Type{Normal}, dist::UnivariateNormalDistributionsFamily)
    return Normal(mean_std(dist)...)
end

# Multivariate case: the result of `mean_cov(dist)` is splatted into the `MvNormal` constructor.
function Base.convert(::Type{MvNormal}, dist::MultivariateNormalDistributionsFamily)
    return MvNormal(mean_cov(dist)...)
end

# Conversion to mean - variance parametrisation

function Base.convert(::Type{NormalMeanVariance{T}}, dist::UnivariateNormalDistributionsFamily) where {T <: Real}
Expand Down Expand Up @@ -312,7 +317,7 @@ function Base.prod(
n = length(left)
v_inv, v_logdet = cholinv_logdet(v)
m = m_left - m_right
return -(v_logdet + n * log2π) / 2 - dot(m, v_inv, m) / 2
return -(v_logdet + n * log2π) / 2 - xT_A_y(m, v_inv, m) / 2
end

## Friendly functions
Expand Down
25 changes: 25 additions & 0 deletions src/helpers/algebra/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,31 @@ function v_a_vT(v1, a, v2)
return result
end

"""
    xT_A_y(x, A, y)

Compute the bilinear form `x' * A * y`, i.e. the same value as `dot(x, A, y)`.

The built-in Julia 3-arg `dot` is not compatible with the auto-differentiation packages,
such as `ForwardDiff`. For `AbstractVector` / `AbstractMatrix` / `AbstractVector` arguments
a hand-written loop is used instead; every other combination of arguments falls back to `dot`.

If `A` is empty (and its axes match `x` and `y`), the empty sum is returned as a zero of the
promoted element type of `x`, `A` and `y`.

# Throws
- `DimensionMismatch` if `axes(A)` differs from `(axes(x, 1), axes(y, 1))`.
"""
xT_A_y(x, A, y) = dot(x, A, y)

function xT_A_y(x::AbstractVector, A::AbstractMatrix, y::AbstractVector)
    expected_axes = (axes(x)..., axes(y)...)
    expected_axes == axes(A) || throw(
        DimensionMismatch("axes of A are $(axes(A)), but (axes(x, 1), axes(y, 1)) is $(expected_axes)")
    )

    # An empty sum is zero. Without this guard `first(...)` below fails on empty input.
    if isempty(A)
        return zero(promote_type(eltype(x), eltype(A), eltype(y)))
    end

    # Derive the accumulator type from one representative term of the sum
    T = typeof(dot(first(x), first(A), first(y)))
    s = zero(T)
    i₁ = first(eachindex(x))
    x₁ = first(x)
    # `@inbounds` is safe here: the axes of `x`, `y` and `A` were checked above
    @inbounds for j in eachindex(y)
        yj = y[j]
        # temp accumulates the j-th entry of A' * x
        temp = zero(adjoint(A[i₁, j]) * x₁)
        @simd for i in eachindex(x)
            temp += adjoint(A[i, j]) * x[i]
        end
        s += dot(temp, yj)
    end
    return s
end

"""
mvbeta(x)

Expand Down
6 changes: 3 additions & 3 deletions src/nodes/autoregressive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ default_meta(::Type{AR}) = error("Autoregressive node requires meta flag explici
my1, Vy1 = first(myx), first(Vyx)
Vy1x = ar_slice(getvform(meta), Vyx, 1, (order + 1):(2order))

# Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2
AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2
# Equivalent to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2
AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + xT_A_y(mx, Vθ, mx) + xT_A_y(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2

# correction
if is_multivariate(meta)
Expand All @@ -76,7 +76,7 @@ end

my1, Vy1 = first(my), first(Vy)

AE = -0.5mean(log, q_γ) + 0.5log2π + 0.5 * mγ * (Vy1 + my1^2 - 2 * mθ' * mx * my1 + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))
AE = -0.5mean(log, q_γ) + 0.5log2π + 0.5 * mγ * (Vy1 + my1^2 - 2 * mθ' * mx * my1 + mul_trace(Vθ, Vx) + xT_A_y(mx, Vθ, mx) + xT_A_y(mθ, Vx, mθ) + abs2(dot(mθ, mx)))

# correction
if is_multivariate(meta)
Expand Down
2 changes: 1 addition & 1 deletion src/rules/dot_product/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ end
@rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin
A = mean(m_in1)
in2_mean, in2_cov = mean_cov(m_in2)
return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A))
return NormalMeanVariance(dot(A, in2_mean), xT_A_y(A, in2_cov, A))
end
2 changes: 1 addition & 1 deletion src/rules/multiplication/A.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, dot(A, W_out, A))
W = correction!(meta, xT_A_y(A, W_out, A))
return NormalWeightedMeanPrecision(dot(A, ξ_out), W)
end

Expand Down
2 changes: 1 addition & 1 deletion src/rules/multiplication/in.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
@rule typeof(*)(:in, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_A::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin
A = mean(m_A)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, dot(A, W_out, A))
W = correction!(meta, xT_A_y(A, W_out, A))
return NormalWeightedMeanPrecision(dot(A, ξ_out), W)
end

Expand Down
12 changes: 12 additions & 0 deletions test/algebra/test_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ using LinearAlgebra
@test ReactiveMP.mul_trace(a, b) ≈ a * b
end
end

@testset "xT_A_y" begin
    import ReactiveMP: xT_A_y

    # Seeded generator, passed to every `rand` call so the inputs are reproducible
    rng = MersenneTwister(1234)
    for n in 2:5, T1 in (Float32, Float64), T2 in (Float32, Float64), T3 in (Float32, Float64)
        x = rand(rng, T1, n)
        A = rand(rng, T2, n, n)
        y = rand(rng, T3, n)
        # The custom implementation must agree with the built-in 3-arg `dot`
        @test dot(x, A, y) ≈ xT_A_y(x, A, y)
    end

    # Axes of the matrix must match the axes of both vectors
    @test_throws DimensionMismatch xT_A_y(rand(rng, 2), rand(rng, 3, 3), rand(rng, 2))
end
end

end
109 changes: 65 additions & 44 deletions test/distributions/test_normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,42 @@ using ReactiveMP
using Random
using LinearAlgebra
using Distributions
using ForwardDiff

import ReactiveMP: convert_eltype

@testset "Normal" begin
@testset "Univariate conversions" begin
check_basic_statistics = (left, right) -> begin
@test mean(left) ≈ mean(right)
@test median(left) ≈ median(right)
@test mode(left) ≈ mode(right)
@test weightedmean(left) ≈ weightedmean(right)
@test var(left) ≈ var(right)
@test std(left) ≈ std(right)
@test cov(left) ≈ cov(right)
@test invcov(left) ≈ invcov(right)
@test precision(left) ≈ precision(right)
@test entropy(left) ≈ entropy(right)
@test pdf(left, 1.0) ≈ pdf(right, 1.0)
@test pdf(left, -1.0) ≈ pdf(right, -1.0)
@test pdf(left, 0.0) ≈ pdf(right, 0.0)
@test logpdf(left, 1.0) ≈ logpdf(right, 1.0)
@test logpdf(left, -1.0) ≈ logpdf(right, -1.0)
@test logpdf(left, 0.0) ≈ logpdf(right, 0.0)
end
# Helper closure for the univariate conversion tests. Asserts that `left` and `right`
# agree on summary statistics, on `pdf`/`logpdf` at several points, and on the
# `ForwardDiff` gradient and Hessian of `logpdf` at those points.
# When `include_extended_methods` is `false`, the final group of checks is skipped.
check_basic_statistics =
(left, right; include_extended_methods = true) -> begin
@test mean(left) ≈ mean(right)
@test median(left) ≈ median(right)
@test mode(left) ≈ mode(right)
@test var(left) ≈ var(right)
@test std(left) ≈ std(right)
@test entropy(left) ≈ entropy(right)

# Evaluation points: three fixed values, both means, and one value drawn with `rand()`
for value in (1.0, -1.0, 0.0, mean(left), mean(right), rand())
@test pdf(left, value) ≈ pdf(right, value)
@test logpdf(left, value) ≈ logpdf(right, value)
# The scalar point is wrapped in a one-element vector for the ForwardDiff calls
@test all(ForwardDiff.gradient((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.gradient((x) -> logpdf(right, x[1]), [value]))
@test all(ForwardDiff.hessian((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.hessian((x) -> logpdf(right, x[1]), [value]))
end

# These methods are not defined for distributions from `Distributions.jl`
if include_extended_methods
@test cov(left) ≈ cov(right)
@test invcov(left) ≈ invcov(right)
@test weightedmean(left) ≈ weightedmean(right)
@test precision(left) ≈ precision(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
end
end

types = ReactiveMP.union_types(UnivariateNormalDistributionsFamily{Float64})
etypes = ReactiveMP.union_types(UnivariateNormalDistributionsFamily)
Expand All @@ -36,6 +49,7 @@ import ReactiveMP: convert_eltype

for type in types
left = convert(type, rand(rng, Float64), rand(rng, Float64))
check_basic_statistics(left, convert(Normal, left); include_extended_methods = false)
for type in [types..., etypes...]
right = convert(type, left)
check_basic_statistics(left, right)
Expand All @@ -56,32 +70,38 @@ import ReactiveMP: convert_eltype
end

@testset "Multivariate conversions" begin
check_basic_statistics = (left, right, dims) -> begin
@test mean(left) ≈ mean(right)
@test mode(left) ≈ mode(right)
@test weightedmean(left) ≈ weightedmean(right)
@test var(left) ≈ var(right)
@test cov(left) ≈ cov(right)
@test invcov(left) ≈ invcov(right)
@test logdetcov(left) ≈ logdetcov(right)
@test precision(left) ≈ precision(right)
@test length(left) === length(right)
@test ndims(left) === ndims(right)
@test size(left) === size(right)
@test entropy(left) ≈ entropy(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
@test pdf(left, fill(1.0, dims)) ≈ pdf(right, fill(1.0, dims))
@test pdf(left, fill(-1.0, dims)) ≈ pdf(right, fill(-1.0, dims))
@test pdf(left, fill(0.0, dims)) ≈ pdf(right, fill(0.0, dims))
@test logpdf(left, fill(1.0, dims)) ≈ logpdf(right, fill(1.0, dims))
@test logpdf(left, fill(-1.0, dims)) ≈ logpdf(right, fill(-1.0, dims))
@test logpdf(left, fill(0.0, dims)) ≈ logpdf(right, fill(0.0, dims))
end
# Helper closure for the multivariate conversion tests. Asserts that `left` and `right`
# agree on summary statistics and shape queries, on `pdf`/`logpdf` at several points of
# length `dims`, and on the `ForwardDiff` gradient and Hessian of `logpdf` at those points.
# When `include_extended_methods` is `false`, the final group of checks is skipped.
check_basic_statistics =
(left, right, dims; include_extended_methods = true) -> begin
@test mean(left) ≈ mean(right)
@test mode(left) ≈ mode(right)
@test var(left) ≈ var(right)
@test cov(left) ≈ cov(right)
@test logdetcov(left) ≈ logdetcov(right)
@test length(left) === length(right)
@test size(left) === size(right)
@test entropy(left) ≈ entropy(right)

# Evaluation points: three constant vectors, both means, and one vector drawn with `rand(dims)`
for value in (fill(1.0, dims), fill(-1.0, dims), fill(0.0, dims), mean(left), mean(right), rand(dims))
@test pdf(left, value) ≈ pdf(right, value)
@test logpdf(left, value) ≈ logpdf(right, value)
# Derivatives are compared elementwise with an absolute tolerance of 1e-14
@test all(isapprox.(ForwardDiff.gradient((x) -> logpdf(left, x), value), ForwardDiff.gradient((x) -> logpdf(right, x), value), atol = 1e-14))
@test all(isapprox.(ForwardDiff.hessian((x) -> logpdf(left, x), value), ForwardDiff.hessian((x) -> logpdf(right, x), value), atol = 1e-14))
end

# These methods are not defined for distributions from `Distributions.jl`
if include_extended_methods
@test ndims(left) === ndims(right)
@test invcov(left) ≈ invcov(right)
@test weightedmean(left) ≈ weightedmean(right)
@test precision(left) ≈ precision(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
end
end

types = ReactiveMP.union_types(MultivariateNormalDistributionsFamily{Float64})
etypes = ReactiveMP.union_types(MultivariateNormalDistributionsFamily)
Expand All @@ -92,6 +112,7 @@ import ReactiveMP: convert_eltype
for dim in dims
for type in types
left = convert(type, rand(rng, Float64, dim), Matrix(Diagonal(rand(rng, Float64, dim))))
check_basic_statistics(left, convert(MvNormal, left), dim; include_extended_methods = false)
for type in [types..., etypes...]
right = convert(type, left)
check_basic_statistics(left, right, dim)
Expand Down