diff --git a/HISTORY.md b/HISTORY.md index 57ccaecd1..c9822930e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,10 @@ # DynamicPPL Changelog +## 0.38.1 + +Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`. +This patch allows sampling from `ProductNamedTupleDistribution` in DynamicPPL models. + ## 0.38.0 ### Breaking changes diff --git a/Project.toml b/Project.toml index 2fe65fd7b..0d427f40d 100644 --- a/Project.toml +++ b/Project.toml @@ -51,7 +51,7 @@ AbstractMCMC = "5" AbstractPPL = "0.13.1" Accessors = "0.1" BangBang = "0.4.1" -Bijectors = "0.13.18, 0.14, 0.15" +Bijectors = "0.15.11" ChainRulesCore = "1" Chairmarks = "1.3.1" Compat = "4" diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 534c6a7b0..396e1463f 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -61,7 +61,7 @@ struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform) b = Bijectors.bijector(dist) - sz = Bijectors.output_size(b, size(dist)) + sz = Bijectors.output_size(b, dist) y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) b_inv = Bijectors.inverse(b) x = b_inv(y) @@ -166,12 +166,11 @@ function tilde_assume!!( # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) - f = if insert_transformed_value - link_transform(dist) + y, logjac = if insert_transformed_value + with_logabsdet_jacobian(link_transform(dist), x) else - identity + x, zero(LogProbType) end - y, logjac = with_logabsdet_jacobian(f, x) # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo diff --git a/src/utils.jl b/src/utils.jl index b09bfb9fa..b55a2f715 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -354,7 +354,6 @@ Return the transformation from the vector representation of a realization of siz original representation. """ from_vec_transform_for_size(sz::Tuple) = ReshapeTransform(sz) -# TODO(mhauru) Is the below used? If not, this function can be removed. from_vec_transform_for_size(::Tuple{<:Any}) = identity """ @@ -367,6 +366,60 @@ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform() from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist)) +struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}} + dists::T + # The `i`-th input range corresponds to the segment of the input vector + # that belongs to the `i`-th distribution. + input_ranges::Vector{UnitRange} + function ProductNamedTupleUnvecTransform( + d::Distributions.ProductNamedTupleDistribution{names} + ) where {names} + offset = 1 + input_ranges = UnitRange[] + for name in names + this_dist = d.dists[name] + this_name_size = _input_length(from_vec_transform(this_dist)) + push!(input_ranges, offset:(offset + this_name_size - 1)) + offset += this_name_size + end + return new{names,typeof(d.dists)}(d.dists, input_ranges) + end +end + +@generated function (trf::ProductNamedTupleUnvecTransform{names})( + x::AbstractVector +) where {names} + expr = Expr(:tuple) + for (i, name) in enumerate(names) + push!( + expr.args, + :($name = from_vec_transform(trf.dists.$name)(x[trf.input_ranges[$i]])), + ) + end + return expr +end + +function from_vec_transform(dist::Distributions.ProductNamedTupleDistribution) + return ProductNamedTupleUnvecTransform(dist) +end +function Bijectors.with_logabsdet_jacobian(f::ProductNamedTupleUnvecTransform, x) + return f(x), zero(LogProbType) +end + +# This function returns the length of the vector that the function from_vec_transform +# expects. This helps us determine which segment of a concatenated vector belongs to which +# variable. +_input_length(from_vec_trfm::UnwrapSingletonTransform) = 1 +_input_length(from_vec_trfm::ReshapeTransform) = prod(from_vec_trfm.output_size) +function _input_length(trfm::ProductNamedTupleUnvecTransform) + return sum(_input_length ∘ from_vec_transform, values(trfm.dists)) +end +function _input_length( + c::ComposedFunction{<:DynamicPPL.ToChol,<:DynamicPPL.ReshapeTransform} +) + return _input_length(c.inner) +end + """ from_vec_transform(f, size::Tuple) @@ -405,7 +458,9 @@ function from_linked_vec_transform(dist::UnivariateDistribution) sz = Bijectors.output_size(f_combined, size(dist)) return UnwrapSingletonTransform(sz) ∘ f_combined end - +function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution) + return invlink_transform(dist) +end # Specializations that circumvent the `from_vec_transform` machinery. function from_linked_vec_transform(dist::LKJCholesky) return inverse(Bijectors.VecCholeskyBijector(dist.uplo)) diff --git a/test/utils.jl b/test/utils.jl index 081e58d61..bef1c2ba8 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,3 +1,21 @@ +module DynamicPPLUtilsTests + +using Bijectors: Bijectors +using Distributions +using DynamicPPL +using LinearAlgebra: LinearAlgebra +using Test + +isapprox_nested(a::Number, b::Number; kwargs...) = isapprox(a, b; kwargs...) +isapprox_nested(a::AbstractArray, b::AbstractArray; kwargs...) = isapprox(a, b; kwargs...) +function isapprox_nested(a::LinearAlgebra.Cholesky, b::LinearAlgebra.Cholesky; kwargs...) + return isapprox(a.U, b.U; kwargs...) && isapprox(a.L, b.L; kwargs...) +end +function isapprox_nested(a::NamedTuple, b::NamedTuple; kwargs...) + keys(a) == keys(b) || return false + return all(k -> isapprox_nested(a[k], b[k]; kwargs...), keys(a)) +end + @testset "utils.jl" begin @testset "addlogprob!" begin @model function testmodel() @@ -31,35 +49,130 @@ end end + @testset "transformations" begin + function test_transformation( + dist::Distribution; test_bijector_type_stability::Bool=true + ) + unlinked = rand(dist) + unlinked_vec = DynamicPPL.tovec(unlinked) + @test unlinked_vec isa AbstractVector + + from_vec_trfm = DynamicPPL.from_vec_transform(dist) + unlinked_again, logjac = Bijectors.with_logabsdet_jacobian( + from_vec_trfm, unlinked_vec + ) + @test isapprox_nested(unlinked, unlinked_again) + @test iszero(logjac) + # Type stability + @inferred DynamicPPL.from_vec_transform(dist) + @inferred Bijectors.with_logabsdet_jacobian(from_vec_trfm, unlinked_vec) + + # Typically the same as `bijector(dist)`, but technically a different + # function + b = DynamicPPL.link_transform(dist) + @test (b(unlinked); true) + linked, logjac = Bijectors.with_logabsdet_jacobian(b, unlinked) + @test logjac isa Real + + binv = DynamicPPL.invlink_transform(dist) + unlinked_again, logjac_inv = Bijectors.with_logabsdet_jacobian(binv, linked) + @test isapprox_nested(unlinked, unlinked_again) + @test isapprox(logjac, -logjac_inv) + + linked_vec = DynamicPPL.tovec(linked) + @test linked_vec isa AbstractVector + from_linked_vec_trfm = DynamicPPL.from_linked_vec_transform(dist) + unlinked_again_again = from_linked_vec_trfm(linked_vec) + @test isapprox_nested(unlinked, unlinked_again_again) + + # Sometimes the bijector itself is not type stable. In this case there is not + # much we can do in DynamicPPL except skip these tests (it has to be fixed + # upstream in Bijectors.) + if test_bijector_type_stability + @inferred DynamicPPL.from_linked_vec_transform(dist) + @inferred Bijectors.with_logabsdet_jacobian( + from_linked_vec_trfm, linked_vec + ) + end + + # Create a model and check that we can evaluate it with both unlinked and linked + # VarInfo. This relies on the transformations working correctly so is more of an + # 'end to end' test + @model test() = x ~ dist + model = test() + vi_unlinked = VarInfo(model) + vi_linked = DynamicPPL.link!!(VarInfo(model), model) + @test (DynamicPPL.evaluate!!(model, vi_unlinked); true) + @test (DynamicPPL.evaluate!!(model, vi_linked); true) + model_init = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) + @test (DynamicPPL.evaluate!!(model_init, vi_unlinked); true) + @test (DynamicPPL.evaluate!!(model_init, vi_linked); true) + end + + # Unconstrained univariate + test_transformation(Normal()) + # Constrained univariate + test_transformation(LogNormal()) + test_transformation(truncated(Normal(); lower=0)) + test_transformation(Exponential(1.0)) + test_transformation(Uniform(-2, 2)) + test_transformation(Beta(2, 2)) + test_transformation(InverseGamma(2, 3)) + # Discrete univariate + test_transformation(Poisson(3)) + test_transformation(Binomial(10, 0.5)) + # Multivariate + test_transformation(MvNormal(zeros(3), LinearAlgebra.I)) + test_transformation( + product_distribution([Normal(), LogNormal()]); + test_bijector_type_stability=false, + ) + test_transformation(product_distribution([LogNormal(), LogNormal()])) + # Matrixvariate + test_transformation(LKJ(3, 0.5)) + test_transformation(Wishart(7, [1.0 0.0; 0.0 1.0])) + # This is a pathological example: the linked representation is a matrix + test_transformation(product_distribution(fill(Dirichlet(ones(4)), 2, 3))) + # Cholesky + test_transformation(LKJCholesky(3, 0.5)) + # ProductNamedTupleDistribution + d = product_distribution((a=Normal(), b=LogNormal())) + test_transformation(d) + d_nested = product_distribution((x=LKJCholesky(2, 0.5), y=d)) + test_transformation(d_nested) + end + @testset "getargs_dottilde" begin # Some things that are not expressions. - @test getargs_dottilde(:x) === nothing - @test getargs_dottilde(1.0) === nothing - @test getargs_dottilde([1.0, 2.0, 4.0]) === nothing + @test DynamicPPL.getargs_dottilde(:x) === nothing + @test DynamicPPL.getargs_dottilde(1.0) === nothing + @test DynamicPPL.getargs_dottilde([1.0, 2.0, 4.0]) === nothing # Some expressions. - @test getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing - @test getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) - @test getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) - @test getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) - @test getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing - @test getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing - @test getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing + @test DynamicPPL.getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing + @test DynamicPPL.getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) + @test DynamicPPL.getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) + @test DynamicPPL.getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) + @test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing + @test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === + nothing + @test DynamicPPL.getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing end @testset "getargs_tilde" begin # Some things that are not expressions. - @test getargs_tilde(:x) === nothing - @test getargs_tilde(1.0) === nothing - @test getargs_tilde([1.0, 2.0, 4.0]) === nothing + @test DynamicPPL.getargs_tilde(:x) === nothing + @test DynamicPPL.getargs_tilde(1.0) === nothing + @test DynamicPPL.getargs_tilde([1.0, 2.0, 4.0]) === nothing # Some expressions. - @test getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) - @test getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing - @test getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing - @test getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing - @test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing - @test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing + @test DynamicPPL.getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) + @test DynamicPPL.getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing + @test DynamicPPL.getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing + @test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing + @test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === + nothing + @test DynamicPPL.getargs_tilde(:(@~ Normal.(μ, σ))) === nothing end @testset "tovec" begin @@ -97,3 +210,5 @@ @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt end end + +end