Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# DynamicPPL Changelog

## 0.38.1

Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`.
This patch allows sampling from `ProductNamedTupleDistribution` in DynamicPPL models.

## 0.38.0

### Breaking changes
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.13.1"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Bijectors = "0.15.11"
ChainRulesCore = "1"
Chairmarks = "1.3.1"
Compat = "4"
Expand Down
9 changes: 4 additions & 5 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy
end
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform)
b = Bijectors.bijector(dist)
sz = Bijectors.output_size(b, size(dist))
sz = Bijectors.output_size(b, dist)
y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...))
b_inv = Bijectors.inverse(b)
x = b_inv(y)
Expand Down Expand Up @@ -166,12 +166,11 @@ function tilde_assume!!(
# is_transformed(vi) returns true if vi is nonempty and all variables in vi
# are linked.
insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi)
f = if insert_transformed_value
link_transform(dist)
y, logjac = if insert_transformed_value
with_logabsdet_jacobian(link_transform(dist), x)
else
identity
x, zero(LogProbType)
end
y, logjac = with_logabsdet_jacobian(f, x)
# Add the new value to the VarInfo. `push!!` errors if the value already
# exists, hence the need for setindex!!.
if in_varinfo
Expand Down
59 changes: 57 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ Return the transformation from the vector representation of a realization of siz
original representation.
"""
from_vec_transform_for_size(sz::Tuple) = ReshapeTransform(sz)
# Size tuples with exactly one element need no reshaping, so the vector is passed
# through unchanged via `identity`. This method is reached through
# `from_vec_transform(dist::Distribution)` whenever `size(dist)` is a 1-tuple.
from_vec_transform_for_size(::Tuple{<:Any}) = identity

"""
Expand All @@ -367,6 +366,60 @@ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist))
# Univariate distributions use `UnwrapSingletonTransform`; presumably the scalar is
# stored as a one-element vector — confirm against that transform's definition.
from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform()
# `LKJCholesky`: the vector is first reshaped to `size(dist)`, and the reshaped array is
# then passed to `ToChol(dist.uplo)`.
from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist))

# Transform that maps the flat vector representation of a draw from a
# `ProductNamedTupleDistribution` back to a NamedTuple. It stores the component
# distributions together with the slice of the input vector that belongs to each one.
struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}}
    dists::T
    # The `i`-th input range corresponds to the segment of the input vector
    # that belongs to the `i`-th distribution.
    #
    # NOTE: this is a subpar data structure because it assumes that `dists` and
    # `input_ranges` have the same length (this is enforced in the inner constructor).
    # Ideally the two would be combined into a single NamedTuple, but to keep that type
    # stable the constructor would probably have to be a generated function as well.
    #
    # The element type is concrete (`UnitRange{Int}`) so that indexing with a stored
    # range does not need dynamic dispatch.
    input_ranges::Vector{UnitRange{Int}}

    function ProductNamedTupleUnvecTransform(
        d::Distributions.ProductNamedTupleDistribution{names}
    ) where {names}
        offset = 1
        input_ranges = UnitRange{Int}[]
        for name in names
            this_dist = d.dists[name]
            this_trf = from_vec_transform(this_dist)
            # `from_vec_transform` returns a bare `identity` for distributions whose
            # `size` is a 1-tuple. `identity` carries no size information, so in that
            # case the segment length has to be taken from the distribution itself.
            this_name_size = if this_trf === identity
                prod(size(this_dist))
            else
                _input_length(this_trf)
            end
            push!(input_ranges, offset:(offset + this_name_size - 1))
            offset += this_name_size
        end
        return new{names,typeof(d.dists)}(d.dists, input_ranges)
    end
end

# Calling the transform on a flat vector rebuilds the NamedTuple. The body is generated
# from `names`: it is a NamedTuple literal with one field per component, where each field
# applies that component's own `from_vec_transform` to its slice of `x`.
@generated function (trf::ProductNamedTupleUnvecTransform{names})(
    x::AbstractVector
) where {names}
    fields = Vector{Expr}(undef, length(names))
    for idx in eachindex(fields)
        nm = names[idx]
        # Slice of the input that belongs to the `idx`-th component.
        slice = :(x[trf.input_ranges[$idx]])
        fields[idx] = :($nm = from_vec_transform(trf.dists.$nm)($slice))
    end
    return Expr(:tuple, fields...)
end

# For a `ProductNamedTupleDistribution`, un-vectorising is delegated to a
# `ProductNamedTupleUnvecTransform` built from the distribution.
function from_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
    unvec = ProductNamedTupleUnvecTransform(dist)
    return unvec
end
# Applies `f` to `x` and pairs the result with a log-abs-det-Jacobian of
# `zero(LogProbType)`.
function Bijectors.with_logabsdet_jacobian(f::ProductNamedTupleUnvecTransform, x)
    value = f(x)
    logjac = zero(LogProbType)
    return (value, logjac)
end

# This function returns the length of the vector that the function from_vec_transform
# expects. This helps us determine which segment of a concatenated vector belongs to which
# variable.

# Always consumes exactly one element.
_input_length(from_vec_trfm::UnwrapSingletonTransform) = 1
# Consumes one element per entry of the reshaped output.
_input_length(from_vec_trfm::ReshapeTransform) = prod(from_vec_trfm.output_size)
# The per-component segments were already computed when the transform was constructed, so
# the total is the sum of their lengths. Reading the stored ranges (rather than
# re-deriving every component's transform) avoids calling `_input_length` on transforms
# it has no method for, and `init=0` keeps the sum defined for an empty collection.
function _input_length(trfm::ProductNamedTupleUnvecTransform)
    return sum(length, trfm.input_ranges; init=0)
end
# A `ToChol ∘ ReshapeTransform` composition consumes whatever its inner reshape consumes.
function _input_length(
    c::ComposedFunction{<:DynamicPPL.ToChol,<:DynamicPPL.ReshapeTransform}
)
    return _input_length(c.inner)
end

"""
from_vec_transform(f, size::Tuple)

Expand Down Expand Up @@ -405,7 +458,9 @@ function from_linked_vec_transform(dist::UnivariateDistribution)
sz = Bijectors.output_size(f_combined, size(dist))
return UnwrapSingletonTransform(sz) ∘ f_combined
end

# For a `ProductNamedTupleDistribution` the linked-vector transform is exactly
# `invlink_transform(dist)`.
function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
    inverse_link = invlink_transform(dist)
    return inverse_link
end
# Specializations that circumvent the `from_vec_transform` machinery.
function from_linked_vec_transform(dist::LKJCholesky)
return inverse(Bijectors.VecCholeskyBijector(dist.uplo))
Expand Down
153 changes: 134 additions & 19 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
module DynamicPPLUtilsTests

using Bijectors: Bijectors
using Distributions
using DynamicPPL
using LinearAlgebra: LinearAlgebra
using Test

# Approximate equality that recurses into NamedTuples and understands Cholesky factors.
# All keyword arguments are forwarded to `isapprox`.

# Leaves: plain numbers and arrays are compared with `isapprox` directly.
function isapprox_nested(a::Number, b::Number; kwargs...)
    return isapprox(a, b; kwargs...)
end
function isapprox_nested(a::AbstractArray, b::AbstractArray; kwargs...)
    return isapprox(a, b; kwargs...)
end

# Cholesky factorisations: both the upper and the lower factor have to match.
function isapprox_nested(a::LinearAlgebra.Cholesky, b::LinearAlgebra.Cholesky; kwargs...)
    upper_matches = isapprox(a.U, b.U; kwargs...)
    upper_matches || return false
    return isapprox(a.L, b.L; kwargs...)
end

# NamedTuples: the keys must be identical (including order), and every pair of values
# must be approximately equal.
function isapprox_nested(a::NamedTuple, b::NamedTuple; kwargs...)
    if keys(a) != keys(b)
        return false
    end
    for k in keys(a)
        isapprox_nested(a[k], b[k]; kwargs...) || return false
    end
    return true
end

@testset "utils.jl" begin
@testset "addlogprob!" begin
@model function testmodel()
Expand Down Expand Up @@ -31,35 +49,130 @@
end
end

# Checks the vectorisation and linking transforms for a range of distributions.
@testset "transformations" begin
# Runs every check below for a single distribution. Pass
# `test_bijector_type_stability=false` to skip the `@inferred` checks on the transform
# returned by `from_linked_vec_transform`.
function test_transformation(
dist::Distribution; test_bijector_type_stability::Bool=true
)
# Draw a value and flatten it to a vector.
unlinked = rand(dist)
unlinked_vec = DynamicPPL.tovec(unlinked)
@test unlinked_vec isa AbstractVector

# The vector must map back to the original draw, with a zero log-Jacobian.
from_vec_trfm = DynamicPPL.from_vec_transform(dist)
unlinked_again, logjac = Bijectors.with_logabsdet_jacobian(
from_vec_trfm, unlinked_vec
)
@test isapprox_nested(unlinked, unlinked_again)
@test iszero(logjac)
# Type stability
@inferred DynamicPPL.from_vec_transform(dist)
@inferred Bijectors.with_logabsdet_jacobian(from_vec_trfm, unlinked_vec)

# Typically the same as `bijector(dist)`, but technically a different
# function
b = DynamicPPL.link_transform(dist)
# Passes as long as the call does not throw.
@test (b(unlinked); true)
linked, logjac = Bijectors.with_logabsdet_jacobian(b, unlinked)
@test logjac isa Real

# Inverting the link must recover the draw, and the two log-Jacobians must be
# negatives of each other.
binv = DynamicPPL.invlink_transform(dist)
unlinked_again, logjac_inv = Bijectors.with_logabsdet_jacobian(binv, linked)
@test isapprox_nested(unlinked, unlinked_again)
@test isapprox(logjac, -logjac_inv)

# The flattened linked value must map straight back to the original draw.
linked_vec = DynamicPPL.tovec(linked)
@test linked_vec isa AbstractVector
from_linked_vec_trfm = DynamicPPL.from_linked_vec_transform(dist)
unlinked_again_again = from_linked_vec_trfm(linked_vec)
@test isapprox_nested(unlinked, unlinked_again_again)

# Sometimes the bijector itself is not type stable. In this case there is not
# much we can do in DynamicPPL except skip these tests (it has to be fixed
# upstream in Bijectors.)
if test_bijector_type_stability
@inferred DynamicPPL.from_linked_vec_transform(dist)
@inferred Bijectors.with_logabsdet_jacobian(
from_linked_vec_trfm, linked_vec
)
end

# Create a model and check that we can evaluate it with both unlinked and linked
# VarInfo. This relies on the transformations working correctly so is more of an
# 'end to end' test
@model test() = x ~ dist
model = test()
vi_unlinked = VarInfo(model)
vi_linked = DynamicPPL.link!!(VarInfo(model), model)
# Each of the following only checks that evaluation does not throw.
@test (DynamicPPL.evaluate!!(model, vi_unlinked); true)
@test (DynamicPPL.evaluate!!(model, vi_linked); true)
# Same evaluations again, with `InitContext` as the leaf context.
model_init = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
@test (DynamicPPL.evaluate!!(model_init, vi_unlinked); true)
@test (DynamicPPL.evaluate!!(model_init, vi_linked); true)
end

# Unconstrained univariate
test_transformation(Normal())
# Constrained univariate
test_transformation(LogNormal())
test_transformation(truncated(Normal(); lower=0))
test_transformation(Exponential(1.0))
test_transformation(Uniform(-2, 2))
test_transformation(Beta(2, 2))
test_transformation(InverseGamma(2, 3))
# Discrete univariate
test_transformation(Poisson(3))
test_transformation(Binomial(10, 0.5))
# Multivariate
test_transformation(MvNormal(zeros(3), LinearAlgebra.I))
# The `@inferred` checks on the linked transform are skipped for this mixed product.
test_transformation(
product_distribution([Normal(), LogNormal()]);
test_bijector_type_stability=false,
)
test_transformation(product_distribution([LogNormal(), LogNormal()]))
# Matrixvariate
test_transformation(LKJ(3, 0.5))
test_transformation(Wishart(7, [1.0 0.0; 0.0 1.0]))
# This is a pathological example: the linked representation is a matrix
test_transformation(product_distribution(fill(Dirichlet(ones(4)), 2, 3)))
# Cholesky
test_transformation(LKJCholesky(3, 0.5))
# ProductNamedTupleDistribution
d = product_distribution((a=Normal(), b=LogNormal()))
test_transformation(d)
# Nested case: a product whose `y` component is itself the product `d`.
d_nested = product_distribution((x=LKJCholesky(2, 0.5), y=d))
test_transformation(d_nested)
end

# `getargs_dottilde` is called through the `DynamicPPL.` prefix so that these tests do not
# depend on the name being exported into the test module. Each case is asserted once.
@testset "getargs_dottilde" begin
    # Some things that are not expressions.
    @test DynamicPPL.getargs_dottilde(:x) === nothing
    @test DynamicPPL.getargs_dottilde(1.0) === nothing
    @test DynamicPPL.getargs_dottilde([1.0, 2.0, 4.0]) === nothing

    # Some expressions.
    @test DynamicPPL.getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing
    @test DynamicPPL.getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ)))
    @test DynamicPPL.getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ)))
    @test DynamicPPL.getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ)))
    @test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing
    @test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) ===
        nothing
    @test DynamicPPL.getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing
end

# `getargs_tilde` is called through the `DynamicPPL.` prefix so that these tests do not
# depend on the name being exported into the test module. Each case is asserted once.
@testset "getargs_tilde" begin
    # Some things that are not expressions.
    @test DynamicPPL.getargs_tilde(:x) === nothing
    @test DynamicPPL.getargs_tilde(1.0) === nothing
    @test DynamicPPL.getargs_tilde([1.0, 2.0, 4.0]) === nothing

    # Some expressions.
    @test DynamicPPL.getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ)))
    @test DynamicPPL.getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing
    @test DynamicPPL.getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing
    @test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing
    @test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) ===
        nothing
    @test DynamicPPL.getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
end

@testset "tovec" begin
Expand Down Expand Up @@ -97,3 +210,5 @@
@test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt
end
end

end
Loading