diff --git a/HISTORY.md b/HISTORY.md index a5976f631..4e9bc2d42 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.36.3 + +Moved the `bijector(model)`, where `model` is a `DynamicPPL.Model`, function from the Turing main repo. + ## 0.36.2 Improved docstrings for AD testing utilities. diff --git a/Project.toml b/Project.toml index 860ece8b3..5bef5bcb1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.2" +version = "0.36.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c1c613d08..21f9044cd 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -179,6 +179,7 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("bijector.jl") include("debug_utils.jl") using .DebugUtils diff --git a/src/bijector.jl b/src/bijector.jl new file mode 100644 index 000000000..31fe7cd88 --- /dev/null +++ b/src/bijector.jl @@ -0,0 +1,60 @@ + +""" + bijector(model::Model[, sym2ranges = Val(false)]) + +Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` +denoting the dimensionality of the latent variables. +""" +function Bijectors.bijector( + model::DynamicPPL.Model, + (::Val{sym2ranges})=Val(false); + varinfo=DynamicPPL.VarInfo(model), +) where {sym2ranges} + dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) + + num_ranges = sum([ + length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) + ]) + ranges = Vector{UnitRange{Int}}(undef, num_ranges) + idx = 0 + range_idx = 1 + + # ranges might be discontinuous => values are vectors of ranges rather than just ranges + sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() + for sym in keys(varinfo.metadata) + sym_lookup[sym] = Vector{UnitRange{Int}}() + for r in varinfo.metadata[sym].ranges + ranges[range_idx] = idx .+ r + push!(sym_lookup[sym], ranges[range_idx]) + range_idx += 1 + end + + idx += varinfo.metadata[sym].ranges[end][end] + end + + bs = map(tuple(dists...)) do d + b = Bijectors.bijector(d) + if d isa Distributions.UnivariateDistribution + b + else + # Wrap a bijector `f` such that it operates on vectors of length `prod(in_size)` + # and produces a vector of length `prod(Bijectors.output(f, in_size))`. + in_size = size(d) + vec_in_length = prod(in_size) + reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) + out_size = Bijectors.output_size(b, in_size) + vec_out_length = prod(out_size) + reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) + reshape_outer ∘ b ∘ reshape_inner + end + end + + if sym2ranges + return ( + Bijectors.Stacked(bs, ranges), + (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), + ) + else + return Bijectors.Stacked(bs, ranges) + end +end diff --git a/test/bijector.jl b/test/bijector.jl new file mode 100644 index 000000000..08ea0a94d --- /dev/null +++ b/test/bijector.jl @@ -0,0 +1,28 @@ + +@testset "bijector.jl" begin + @testset "bijector" begin + @model function test() + m ~ Normal() + s ~ InverseGamma(3, 3) + return c ~ Dirichlet([1.0, 1.0]) + end + + m = test() + b = bijector(m) + + # m ∈ ℝ, s ∈ ℝ+, c ∈ 2-simplex + # check dimensionalities and ranges + @test b.length_in == 4 + @test b.length_out == 3 + @test b.ranges_in == [1:1, 2:2, 3:4] + @test b.ranges_out == [1:1, 2:2, 3:3] + @test b.ranges_out == [1:1, 2:2, 3:3] + + # check support of mapped variables + binv = inverse(b) + zs = mapslices(binv, randn(b.length_out, 10000); dims=1) + + @test all(zs[2, :] .≥ 0) + @test all(sum(zs[3:4, :]; dims=1) .≈ 1.0) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 72f33f2d0..06e58738d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -68,6 +68,7 @@ include("test_util.jl") include("debug_utils.jl") include("deprecated.jl") include("submodels.jl") + include("bijector.jl") end if GROUP == "All" || GROUP == "Group2"