diff --git a/Project.toml b/Project.toml
index 583a3cdb0..7aa8ac6c8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ReactiveMP"
 uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3"
 authors = ["Dmitry Bagaev ", "Albert Podusenko ", "Bart van Erp ", "Ismail Senoz "]
-version = "2.3.0"
+version = "2.3.1"
 
 [deps]
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
diff --git a/src/constraints/spec/factorisation_spec.jl b/src/constraints/spec/factorisation_spec.jl
index c2fe0252b..d88921416 100644
--- a/src/constraints/spec/factorisation_spec.jl
+++ b/src/constraints/spec/factorisation_spec.jl
@@ -201,7 +201,7 @@ resolve_factorisation(::UnspecifiedConstraints, any, model, fform, variables) =
 resolve_factorisation(::UnspecifiedConstraints, ::Deterministic, model, fform, variables) = FullFactorisation()
 
 # Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs
-resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, model, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2))
+resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, model, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2),)
 resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, model, fform, ::Tuple{V1, V2}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable} = ((1,), (2,))
 resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, model, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,))
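The one-character fix above is easy to miss: in Julia, redundant parentheses do not create a nested tuple, so `((1, 2))` is just `(1, 2)`, and only the trailing comma produces the expected one-element tuple of factorisation clusters. A quick REPL check (illustrative only, not part of the diff):

```julia
julia> ((1, 2)) === (1, 2)    # redundant parentheses collapse to a plain tuple
true

julia> ((1, 2),)              # the trailing comma creates a tuple of clusters
((1, 2),)

julia> map(length, ((1, 2),)) # a single cluster covering both interfaces
(2,)
```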
diff --git a/src/distributions/categorical.jl b/src/distributions/categorical.jl
index 136ec8ab3..93bd97645 100644
--- a/src/distributions/categorical.jl
+++ b/src/distributions/categorical.jl
@@ -1,4 +1,4 @@
-export Bernoulli
+export Categorical
 
 import Distributions: Categorical, probs
diff --git a/src/distributions/contingency.jl b/src/distributions/contingency.jl
index 240b0e508..5deb9ea8a 100644
--- a/src/distributions/contingency.jl
+++ b/src/distributions/contingency.jl
@@ -2,15 +2,41 @@ export Contingency
 
 using LinearAlgebra
 
+"""
+    Contingency(P, renormalize = Val(true))
+
+The contingency distribution is a multivariate generalization of the categorical distribution. As a bivariate distribution, the
+contingency distribution defines the joint probability over two unit vectors `v1` and `v2`. The parameter `P` encodes a contingency matrix that specifies the probability of co-occurrence.
+
+    v1 ∈ {0, 1}^d1 where Σ_j v1_j = 1
+    v2 ∈ {0, 1}^d2 where Σ_k v2_k = 1
+
+    P ∈ [0, 1]^{d1 × d2}, where Σ_jk P_jk = 1
+
+    f(v1, v2, P) = Contingency(v1, v2 | P) = Π_jk P_jk^{v1_j * v2_k}
+
+A `Contingency` distribution over more than two variables requires higher-order tensors as parameters; these are not implemented in ReactiveMP.
+
+# Arguments:
+- `P`, required, contingency matrix
+- `renormalize`, optional, supports either `Val(true)` or `Val(false)`, specifies whether the matrix `P` must be automatically renormalized. Renormalization does not modify the original `P`; a new matrix is allocated for the renormalized version.
+  If set to `Val(false)`, the contingency matrix `P` **must** be normalized by hand; otherwise the results of related calculations might be wrong.
+"""
 struct Contingency{T, P <: AbstractMatrix{T}} <: ContinuousMatrixDistribution
     p::P
+
+    Contingency{T, P}(A::AbstractMatrix) where {T, P <: AbstractMatrix{T}} = new(A)
 end
 
+Contingency(P::AbstractMatrix) = Contingency(P, Val(true))
+Contingency(P::M, renormalize::Val{true}) where {T, M <: AbstractMatrix{T}} = Contingency{T, M}(P ./ sum(P))
+Contingency(P::M, renormalize::Val{false}) where {T, M <: AbstractMatrix{T}} = Contingency{T, M}(P)
+
 contingency_matrix(distribution::Contingency) = distribution.p
 
 vague(::Type{<:Contingency}, dims::Int) = Contingency(ones(dims, dims) ./ abs2(dims))
 
 function entropy(distribution::Contingency)
     P = contingency_matrix(distribution)
-    -sum(P .* log.(clamp.(P, tiny, Inf)))
+    return -mapreduce((p) -> p * clamplog(p), +, P)
 end
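A minimal usage sketch of the new constructor pair (the values are illustrative):

```julia
using ReactiveMP

P = [1.0 3.0; 2.0 4.0]

d = Contingency(P)               # default is `Val(true)`: allocates a renormalized copy
ReactiveMP.contingency_matrix(d) # == P ./ 10, the original `P` is untouched

# Skips renormalization entirely; the caller guarantees the entries sum to one
d = Contingency(P ./ sum(P), Val(false))
```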
diff --git a/src/distributions/dirichlet.jl b/src/distributions/dirichlet.jl
index 4cdf840fa..71d966c7a 100644
--- a/src/distributions/dirichlet.jl
+++ b/src/distributions/dirichlet.jl
@@ -16,7 +16,8 @@ end
 
 probvec(dist::Dirichlet) = params(dist)[1] # probvec is not normalised
 
-mean(::typeof(log), dist::Dirichlet) = digamma.(probvec(dist)) .- digamma(sum(probvec(dist)))
+mean(::typeof(log), dist::Dirichlet)      = digamma.(probvec(dist)) .- digamma(sum(probvec(dist)))
+mean(::typeof(clamplog), dist::Dirichlet) = digamma.((clamp(p, tiny, typemax(p)) for p in probvec(dist))) .- digamma(sum(probvec(dist)))
 
 # Variate forms promotion
diff --git a/src/distributions/normal.jl b/src/distributions/normal.jl
index b5170de33..e3a3d0210 100644
--- a/src/distributions/normal.jl
+++ b/src/distributions/normal.jl
@@ -204,6 +204,11 @@ function Random.rand(rng::AbstractRNG, dist::UnivariateNormalDistributionsFamily
     return μ + σ * randn(rng, float(T))
 end
 
+function Random.rand(rng::AbstractRNG, dist::UnivariateNormalDistributionsFamily{T}, size::Int64) where {T}
+    container = Vector{T}(undef, size)
+    return rand!(rng, dist, container)
+end
+
 function Random.rand!(
     rng::AbstractRNG,
     dist::UnivariateNormalDistributionsFamily,
@@ -211,7 +216,7 @@ function Random.rand!(
 ) where {T <: Real}
     randn!(rng, container)
     μ, σ = mean_std(dist)
-    @turbo for i in 1:length(container)
+    @turbo for i in eachindex(container)
         container[i] = μ + σ * container[i]
     end
     container
@@ -224,6 +229,11 @@ function Random.rand(rng::AbstractRNG, dist::MultivariateNormalDistributionsFami
     return μ + L * randn(rng, length(μ))
 end
 
+function Random.rand(rng::AbstractRNG, dist::MultivariateNormalDistributionsFamily{T}, size::Int64) where {T}
+    container = Matrix{T}(undef, ndims(dist), size)
+    return rand!(rng, dist, container)
+end
+
 function Random.rand!(
     rng::AbstractRNG,
     dist::MultivariateNormalDistributionsFamily,
@@ -232,7 +242,7 @@ function Random.rand!(
     preallocated = similar(container)
     randn!(rng, reshape(preallocated, length(preallocated)))
     μ, L = mean_std(dist)
-    @views for i in 1:size(preallocated)[2]
+    @views for i in axes(preallocated, 2)
         copyto!(container[:, i], μ)
         mul!(container[:, i], L, preallocated[:, i], 1, 1)
     end
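A short sketch of the new vectorised `rand` methods (the constructors shown are the existing ReactiveMP ones; values are illustrative):

```julia
using Random, ReactiveMP

rng = MersenneTwister(1234)

# Univariate: the new method allocates a Vector{T} and delegates to `rand!`
s = rand(rng, NormalMeanVariance(0.0, 2.0), 1000)
# s isa Vector{Float64} with 1000 draws

# Multivariate: samples are stored column-wise in an ndims(dist) × size matrix
S = rand(rng, MvNormalMeanCovariance(zeros(2), [2.0 0.0; 0.0 3.0]), 1000)
# S isa 2×1000 Matrix{Float64}, one sample per column
```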
diff --git a/src/distributions/pointmass.jl b/src/distributions/pointmass.jl
index cc108e930..0e7e751c8 100644
--- a/src/distributions/pointmass.jl
+++ b/src/distributions/pointmass.jl
@@ -62,6 +62,7 @@ probvec(distribution::PointMass{V}) where {T <: Real, V <: AbstractVector{T}} =
 
 mean(::typeof(inv), distribution::PointMass{V}) where {T <: Real, V <: AbstractVector{T}} = error("mean of inverse of `::PointMass{ <: AbstractVector }` is not defined")
 mean(::typeof(log), distribution::PointMass{V}) where {T <: Real, V <: AbstractVector{T}} = log.(mean(distribution))
+mean(::typeof(clamplog), distribution::PointMass{V}) where {T <: Real, V <: AbstractVector{T}} = clamplog.(mean(distribution))
 mean(::typeof(mirrorlog), distribution::PointMass{V}) where {T <: Real, V <: AbstractVector{T}} = error("mean of mirrorlog of `::PointMass{ <: AbstractVector }` is not defined")
 mean(::typeof(loggamma), distribution::PointMass{V}) where {T <: Real, V <: AbstractVector{T}} = loggamma.(mean(distribution))
 mean(::typeof(logdet), distribution::PointMass{V}) where {T <: Real, V <: AbstractVector{T}} = error("mean of logdet of `::PointMass{ <: AbstractVector }` is not defined")
@@ -91,6 +92,7 @@ probvec(distribution::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} =
 
 mean(::typeof(inv), distribution::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} = cholinv(mean(distribution))
 mean(::typeof(log), distribution::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} = log.(mean(distribution))
+mean(::typeof(clamplog), distribution::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} = clamplog.(mean(distribution))
 mean(::typeof(mirrorlog), distribution::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} = error("mean of mirrorlog of `::PointMass{ <: AbstractMatrix }` is not defined")
 mean(::typeof(loggamma), distribution::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} = loggamma.(mean(distribution))
 mean(::typeof(logdet), distribution::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} = logdet(mean(distribution))
diff --git a/src/helpers.jl b/src/helpers.jl
index f41f67045..75e38bed1 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -264,6 +264,11 @@ end
 
 ## Other helpers
 
+"""
+Same as `log` but clamps the input argument `x` to be in the range `tiny <= x <= typemax(x)` such that `log(0)` does not return `-Inf`.
+"""
+clamplog(x) = log(clamp(x, tiny, typemax(x)))
+
 # We override this function for some specific types
 function is_typeof_equal(left, right)
     _isequal = typeof(left) === typeof(right)
@@ -311,9 +316,9 @@ Float64
 """
 function deep_eltype end
 
-deep_eltype(::Type{T}) where {T <: Number} = T
-deep_eltype(::Type{T}) where {T} = deep_eltype(eltype(T))
-deep_eltype(::T) where {T} = deep_eltype(T)
+deep_eltype(::Type{T}) where {T} = T
+deep_eltype(::Type{T}) where {T <: AbstractArray} = deep_eltype(eltype(T))
+deep_eltype(any) = deep_eltype(typeof(any))
 
 ##
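Both helpers in action (`clamplog` and `deep_eltype` are internal, hence the explicit import; values are illustrative):

```julia
import ReactiveMP: clamplog, deep_eltype

isfinite(clamplog(0.0))   # true: plain `log(0.0)` would return -Inf
clamplog(2.0) == log(2.0) # true: values above `tiny` pass through unchanged

deep_eltype([[1.0, 2.0], [3.0, 4.0]]) # Float64: recurses through nested arrays
deep_eltype(1.0f0)                    # Float32: a non-array falls back to its own type
```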
diff --git a/src/inference.jl b/src/inference.jl
index 328fa2553..7b3f4a743 100644
--- a/src/inference.jl
+++ b/src/inference.jl
@@ -392,7 +392,10 @@ function inference(;
         error("Data is empty. Make sure you used `data` keyword argument with correct value.")
     else
         foreach(filter(pair -> isdata(last(pair)), pairs(vardict))) do pair
-            haskey(data, first(pair)) || error("Data entry $(first(pair)) is missing in `data` dictionary.")
+            varname = first(pair)
+            haskey(data, varname) || error(
+                "Data entry `$(varname)` is missing in the `data` argument. Double check `data = ($(varname) = ???, )`"
+            )
         end
     end
diff --git a/src/nodes/categorical.jl b/src/nodes/categorical.jl
index 362bf77d5..9ebb865b0 100644
--- a/src/nodes/categorical.jl
+++ b/src/nodes/categorical.jl
@@ -1,4 +1,4 @@
 @node Categorical Stochastic [out, p]
 
-@average_energy Categorical (q_out::Categorical, q_p::Any) = -sum(probvec(q_out) .* mean(log, q_p))
+@average_energy Categorical (q_out::Categorical, q_p::Any) = -sum(probvec(q_out) .* mean(clamplog, q_p))
diff --git a/src/nodes/transition.jl b/src/nodes/transition.jl
index 4a198c6ef..343ce812a 100644
--- a/src/nodes/transition.jl
+++ b/src/nodes/transition.jl
@@ -17,13 +17,13 @@ end
 end
 
 @average_energy Transition (q_out_in::Contingency, q_a::PointMass) = begin
-    # `map(d -> log(clamp(d, tiny, huge)), mean(q_a))` is an equivalent of `mean(log, q_a)` with an extra `clamp(el, tiny, huge)` operation
+    # `mean(clamplog, q_a)` is equivalent to `mean(log, q_a)` with an extra `clamp(el, tiny, Inf)` operation
     # The reason is that we don't want to take the log of zeros in the matrix `q_a` (if there are any)
     # The trick here is that if the RHS matrix has zero entries, then the corresponding entries of the `contingency_matrix`
     # should also be zeros (see the corresponding @marginalrule), so in the end `log(tiny) * 0` does not influence the result.
-    return -ReactiveMP.mul_trace(ReactiveMP.contingency_matrix(q_out_in)', map(d -> log(clamp(d, tiny, huge)), mean(q_a)))
+    return -ReactiveMP.mul_trace(ReactiveMP.contingency_matrix(q_out_in)', mean(clamplog, q_a))
 end
 
 @average_energy Transition (q_out::Any, q_in::Any, q_a::PointMass) = begin
-    return -probvec(q_out)' * mean(log, q_a) * probvec(q_in)
+    return -probvec(q_out)' * mean(clamplog, q_a) * probvec(q_in)
 end
diff --git a/src/rules/transition/marginals.jl b/src/rules/transition/marginals.jl
index d5153cf51..9bfafc4dd 100644
--- a/src/rules/transition/marginals.jl
+++ b/src/rules/transition/marginals.jl
@@ -1,17 +1,21 @@
 @marginalrule Transition(:out_in) (m_out::Categorical, m_in::Categorical, q_a::MatrixDirichlet) = begin
-    B = Diagonal(probvec(m_out)) * exp.(mean(log, q_a)) * Diagonal(probvec(m_in))
-    return Contingency(B ./ sum(B))
+    D = map(e -> clamp(exp(e), tiny, huge), mean(log, q_a))
+    B = Diagonal(probvec(m_out)) * D * Diagonal(probvec(m_in))
+    P = map!(Base.Fix2(/, sum(B)), B, B) # in-place version of B ./ sum(B)
+    return Contingency(P, Val(false)) # Matrix `P` has been normalized by hand
 end
 
 @marginalrule Transition(:out_in) (m_out::Categorical, m_in::Categorical, q_a::PointMass) = begin
     B = Diagonal(probvec(m_out)) * mean(q_a) * Diagonal(probvec(m_in))
-    return Contingency(B ./ sum(B))
+    P = map!(Base.Fix2(/, sum(B)), B, B) # in-place version of B ./ sum(B)
+    return Contingency(P, Val(false)) # Matrix `P` has been normalized by hand
 end
 
 @marginalrule Transition(:out_in_a) (m_out::Categorical, m_in::Categorical, m_a::PointMass) = begin
     B = Diagonal(probvec(m_out)) * mean(m_a) * Diagonal(probvec(m_in))
-    return (out_in = Contingency(B ./ sum(B)), a = m_a)
+    P = map!(Base.Fix2(/, sum(B)), B, B) # in-place version of B ./ sum(B)
+    return (out_in = Contingency(P, Val(false)), a = m_a) # Matrix `P` has been normalized by hand
 end
 
 @marginalrule Transition(:out_in_a) (m_out::PointMass, m_in::Categorical, m_a::PointMass, meta::Any) = begin
diff --git a/src/variable.jl b/src/variable.jl
index fbe4f21a0..c6b719913 100644
--- a/src/variable.jl
+++ b/src/variable.jl
@@ -31,6 +31,10 @@ linear_index(::VariableIndividual) = nothing
 linear_index(v::VariableVector) = v.index
 linear_index(v::VariableArray) = LinearIndices(v.size)[v.index]
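The in-place normalization idiom used in the marginal rules above, in isolation. Note that `sum(B)` is evaluated before `map!` starts mutating `B`, so the division is safe:

```julia
B = [1.0 3.0; 2.0 4.0]

map!(Base.Fix2(/, sum(B)), B, B) # Base.Fix2(/, s) is the function x -> x / s

B      # [0.1 0.3; 0.2 0.4]
sum(B) # 1.0 — normalized without allocating a second matrix
```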
+ """ foreach(zip(datavars, data)) do (var, d) update!(var, d) end diff --git a/test/constraints/spec/test_factorisation_spec.jl b/test/constraints/spec/test_factorisation_spec.jl index 196efb82d..e94421bab 100644 --- a/test/constraints/spec/test_factorisation_spec.jl +++ b/test/constraints/spec/test_factorisation_spec.jl @@ -8,7 +8,7 @@ import ReactiveMP: FunctionalIndex import ReactiveMP: CombinedRange, SplittedRange, is_splitted import ReactiveMP: __as_unit_range, __factorisation_specification_resolve_index import ReactiveMP: resolve_factorisation -import ReactiveMP: DefaultConstraints +import ReactiveMP: DefaultConstraints, UnspecifiedConstraints import ReactiveMP: setanonymous!, activate! using GraphPPL # for `@constraints` macro @@ -579,7 +579,8 @@ using GraphPPL # for `@constraints` macro # empty end - for cs in (empty, DefaultConstraints) + # DefaultConstraints are equal to `UnspecifiedConstraints()` for now, but it might change in the future so we test both + for cs in (empty, UnspecifiedConstraints(), DefaultConstraints) let model = FactorGraphModel() d = datavar(model, :d, Float64) c = constvar(model, :c, 1.0) @@ -587,6 +588,8 @@ using GraphPPL # for `@constraints` macro y = randomvar(model, :y) z = randomvar(model, :z) + @test ReactiveMP.resolve_factorisation(cs, model, fform, (x, y)) === ((1, 2),) + @test ReactiveMP.resolve_factorisation(cs, model, fform, (y, x)) === ((1, 2),) @test ReactiveMP.resolve_factorisation(cs, model, fform, (d, d)) === ((1,), (2,)) @test ReactiveMP.resolve_factorisation(cs, model, fform, (c, c)) === ((1,), (2,)) @test ReactiveMP.resolve_factorisation(cs, model, fform, (d, x)) === ((1,), (2,)) diff --git a/test/distributions/test_contingency.jl b/test/distributions/test_contingency.jl index 308b4efdb..799b3cdbf 100644 --- a/test/distributions/test_contingency.jl +++ b/test/distributions/test_contingency.jl @@ -16,8 +16,12 @@ using Random end @testset "contingency_matrix" begin - @test ReactiveMP.contingency_matrix(Contingency(ones(3, 3))) == ones(3, 3) - @test ReactiveMP.contingency_matrix(Contingency(ones(4, 4))) == ones(4, 4) + @test ReactiveMP.contingency_matrix(Contingency(ones(3, 3))) == ones(3, 3) ./ 9 + @test ReactiveMP.contingency_matrix(Contingency(ones(3, 3), Val(true))) == ones(3, 3) ./ 9 + @test ReactiveMP.contingency_matrix(Contingency(ones(3, 3), Val(false))) == ones(3, 3) # Matrix is wrong, but just to test that `false` is working + @test ReactiveMP.contingency_matrix(Contingency(ones(4, 4))) == ones(4, 4) ./ 16 + @test ReactiveMP.contingency_matrix(Contingency(ones(4, 4), Val(true))) == ones(4, 4) ./ 16 + @test ReactiveMP.contingency_matrix(Contingency(ones(4, 4), Val(false))) == ones(4, 4) end @testset "vague" begin @@ -35,9 +39,14 @@ using Random end @testset "entropy" begin - @test entropy(Contingency([0.1 0.9; 0.9 0.1])) ≈ 0.6501659467828964 - @test entropy(Contingency([0.2 0.8; 0.8 0.2])) ≈ 1.0008048470763757 - @test entropy(Contingency([0.45 0.75; 0.55 0.25])) ≈ 1.2504739583323967 + @test entropy(Contingency([0.7 0.1; 0.1 0.1])) ≈ 0.9404479886553263 + @test entropy(Contingency(10.0 * [0.7 0.1; 0.1 0.1])) ≈ 0.9404479886553263 + @test entropy(Contingency([0.07 0.41; 0.31 0.21])) ≈ 1.242506182893139 + @test entropy(Contingency(10.0 * [0.07 0.41; 0.31 0.21])) ≈ 1.242506182893139 + @test entropy(Contingency([0.09 0.00; 0.00 0.91])) ≈ 0.30253782309749805 + @test entropy(Contingency(10.0 * [0.09 0.00; 0.00 0.91])) ≈ 0.30253782309749805 + @test !isnan(entropy(Contingency([0.0 1.0; 1.0 0.0]))) + @test !isinf(entropy(Contingency([0.0 1.0; 
diff --git a/test/constraints/spec/test_factorisation_spec.jl b/test/constraints/spec/test_factorisation_spec.jl
index 196efb82d..e94421bab 100644
--- a/test/constraints/spec/test_factorisation_spec.jl
+++ b/test/constraints/spec/test_factorisation_spec.jl
@@ -8,7 +8,7 @@ import ReactiveMP: FunctionalIndex
 import ReactiveMP: CombinedRange, SplittedRange, is_splitted
 import ReactiveMP: __as_unit_range, __factorisation_specification_resolve_index
 import ReactiveMP: resolve_factorisation
-import ReactiveMP: DefaultConstraints
+import ReactiveMP: DefaultConstraints, UnspecifiedConstraints
 import ReactiveMP: setanonymous!, activate!
 
 using GraphPPL # for `@constraints` macro
@@ -579,7 +579,8 @@ using GraphPPL # for `@constraints` macro
         # empty
     end
 
-    for cs in (empty, DefaultConstraints)
+    # `DefaultConstraints` are equal to `UnspecifiedConstraints()` for now, but this might change in the future, so we test both
+    for cs in (empty, UnspecifiedConstraints(), DefaultConstraints)
         let model = FactorGraphModel()
             d = datavar(model, :d, Float64)
             c = constvar(model, :c, 1.0)
@@ -587,6 +588,8 @@ using GraphPPL # for `@constraints` macro
             y = randomvar(model, :y)
             z = randomvar(model, :z)
 
+            @test ReactiveMP.resolve_factorisation(cs, model, fform, (x, y)) === ((1, 2),)
+            @test ReactiveMP.resolve_factorisation(cs, model, fform, (y, x)) === ((1, 2),)
             @test ReactiveMP.resolve_factorisation(cs, model, fform, (d, d)) === ((1,), (2,))
             @test ReactiveMP.resolve_factorisation(cs, model, fform, (c, c)) === ((1,), (2,))
             @test ReactiveMP.resolve_factorisation(cs, model, fform, (d, x)) === ((1,), (2,))
diff --git a/test/distributions/test_contingency.jl b/test/distributions/test_contingency.jl
index 308b4efdb..799b3cdbf 100644
--- a/test/distributions/test_contingency.jl
+++ b/test/distributions/test_contingency.jl
@@ -16,8 +16,12 @@ using Random
     end
 
     @testset "contingency_matrix" begin
-        @test ReactiveMP.contingency_matrix(Contingency(ones(3, 3))) == ones(3, 3)
-        @test ReactiveMP.contingency_matrix(Contingency(ones(4, 4))) == ones(4, 4)
+        @test ReactiveMP.contingency_matrix(Contingency(ones(3, 3))) == ones(3, 3) ./ 9
+        @test ReactiveMP.contingency_matrix(Contingency(ones(3, 3), Val(true))) == ones(3, 3) ./ 9
+        @test ReactiveMP.contingency_matrix(Contingency(ones(3, 3), Val(false))) == ones(3, 3) # The matrix is deliberately unnormalized, just to test that `Val(false)` works
+        @test ReactiveMP.contingency_matrix(Contingency(ones(4, 4))) == ones(4, 4) ./ 16
+        @test ReactiveMP.contingency_matrix(Contingency(ones(4, 4), Val(true))) == ones(4, 4) ./ 16
+        @test ReactiveMP.contingency_matrix(Contingency(ones(4, 4), Val(false))) == ones(4, 4)
     end
 
     @testset "vague" begin
@@ -35,9 +39,14 @@ end
 
     @testset "entropy" begin
-        @test entropy(Contingency([0.1 0.9; 0.9 0.1])) ≈ 0.6501659467828964
-        @test entropy(Contingency([0.2 0.8; 0.8 0.2])) ≈ 1.0008048470763757
-        @test entropy(Contingency([0.45 0.75; 0.55 0.25])) ≈ 1.2504739583323967
+        @test entropy(Contingency([0.7 0.1; 0.1 0.1])) ≈ 0.9404479886553263
+        @test entropy(Contingency(10.0 * [0.7 0.1; 0.1 0.1])) ≈ 0.9404479886553263
+        @test entropy(Contingency([0.07 0.41; 0.31 0.21])) ≈ 1.242506182893139
+        @test entropy(Contingency(10.0 * [0.07 0.41; 0.31 0.21])) ≈ 1.242506182893139
+        @test entropy(Contingency([0.09 0.00; 0.00 0.91])) ≈ 0.30253782309749805
+        @test entropy(Contingency(10.0 * [0.09 0.00; 0.00 0.91])) ≈ 0.30253782309749805
+        @test !isnan(entropy(Contingency([0.0 1.0; 1.0 0.0])))
+        @test !isinf(entropy(Contingency([0.0 1.0; 1.0 0.0])))
     end
 end
diff --git a/test/distributions/test_normal.jl b/test/distributions/test_normal.jl
index 66e70cc82..ddbfb3107 100644
--- a/test/distributions/test_normal.jl
+++ b/test/distributions/test_normal.jl
@@ -129,14 +129,9 @@ using Distributions
         @test promote_variate_type(Multivariate, MvNormalWeightedMeanPrecision) === MvNormalWeightedMeanPrecision
     end
 
-    @testset "Sampling" begin
+    @testset "Sampling univariate" begin
         rng = MersenneTwister(1234)
 
-        univariate_types = [
-            ReactiveMP.union_types(UnivariateNormalDistributionsFamily{Float64})...,
-            ReactiveMP.union_types(UnivariateNormalDistributionsFamily{Float32})...
-        ]
-
         for T in (Float32, Float64)
             let # NormalMeanVariance
                 μ, v = 10randn(rng), 10rand(rng)
@@ -175,6 +170,57 @@ using Distributions
             end
         end
     end
+
+    @testset "Sampling multivariate" begin
+        rng = MersenneTwister(1234)
+
+        for n in (2, 3), T in (Float64,), nsamples in (10_000,)
+            let # MvNormalMeanCovariance
+                μ = randn(rng, n)
+                L = randn(rng, n, n)
+                Σ = L * L'
+
+                d = convert(MvNormalMeanCovariance{T}, μ, Σ)
+
+                @test typeof(rand(d)) <: Vector{T}
+
+                samples = SampleList(Val((n,)), rand(rng, d, nsamples), fill(1 / nsamples, nsamples))
+
+                @test isapprox(mean(samples), mean(d), atol = n * 0.5)
+                @test isapprox(cov(samples), cov(d), atol = n * 0.5)
+            end
+
+            let # MvNormalMeanPrecision
+                μ = randn(rng, n)
+                L = randn(rng, n, n)
+                W = L * L'
+
+                d = convert(MvNormalMeanPrecision{T}, μ, W)
+
+                @test typeof(rand(d)) <: Vector{T}
+
+                samples = SampleList(Val((n,)), rand(rng, d, nsamples), fill(T(1 / nsamples), nsamples))
+
+                @test isapprox(mean(samples), mean(d), atol = n * 0.5)
+                @test isapprox(cov(samples), cov(d), atol = n * 0.5)
+            end
+
+            let # MvNormalWeightedMeanPrecision
+                ξ = randn(rng, n)
+                L = randn(rng, n, n)
+                W = L * L'
+
+                d = convert(MvNormalWeightedMeanPrecision{T}, ξ, W)
+
+                @test typeof(rand(d)) <: Vector{T}
+
+                samples = SampleList(Val((n,)), rand(rng, d, nsamples), fill(T(1 / nsamples), nsamples))
+
+                @test isapprox(mean(samples), mean(d), atol = n * 0.5)
+                @test isapprox(cov(samples), cov(d), atol = n * 0.5)
+            end
+        end
+    end
 end
 end
diff --git a/test/distributions/test_wishart_inverse.jl b/test/distributions/test_wishart_inverse.jl
index 3df160aa6..d0dd23278 100644
--- a/test/distributions/test_wishart_inverse.jl
+++ b/test/distributions/test_wishart_inverse.jl
@@ -1,4 +1,4 @@
-module MatrixDirichletTest
+module InverseWishartTest
 
 using Test
 using ReactiveMP
diff --git a/test/test_helpers.jl b/test/test_helpers.jl
index f8e14bcdd..136cb4f2c 100644
--- a/test/test_helpers.jl
+++ b/test/test_helpers.jl
@@ -4,7 +4,7 @@ using Test
 using ReactiveMP
 
 import ReactiveMP: SkipIndexIterator, skipindex
-import ReactiveMP: deep_eltype
+import ReactiveMP: clamplog, deep_eltype
 import ReactiveMP: InfCountingReal, ∞
 import ReactiveMP: FunctionalIndex
 
@@ -16,6 +16,11 @@ import ReactiveMP: FunctionalIndex
         @test collect(skipindex(s, 1)) == [3]
     end
 
+    @testset "clamplog" begin
+        @test !isnan(clamplog(0.0)) && !isinf(clamplog(0.0))
+        @test clamplog(tiny + 1.0) === log(tiny + 1.0)
+    end
+
     @testset "deep_eltype" begin
         for type in [Float32, Float64, Complex{Float64}, BigFloat]
             @test deep_eltype(type) === type