Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
include("bijector.jl")

include("debug_utils.jl")
using .DebugUtils
Expand Down
60 changes: 60 additions & 0 deletions src/bijector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

"""
bijector(model::Model[, sym2ranges = Val(false)])

Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
denoting the dimensionality of the latent variables.
"""
function Bijectors.bijector(
model::DynamicPPL.Model,
(::Val{sym2ranges})=Val(false);
varinfo=DynamicPPL.VarInfo(model),
) where {sym2ranges}
dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)

num_ranges = sum([
length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata)
])
ranges = Vector{UnitRange{Int}}(undef, num_ranges)
idx = 0
range_idx = 1

# ranges might be discontinuous => values are vectors of ranges rather than just ranges
sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}()
for sym in keys(varinfo.metadata)
sym_lookup[sym] = Vector{UnitRange{Int}}()
for r in varinfo.metadata[sym].ranges
ranges[range_idx] = idx .+ r
push!(sym_lookup[sym], ranges[range_idx])
range_idx += 1
end

idx += varinfo.metadata[sym].ranges[end][end]
end

bs = map(tuple(dists...)) do d
b = Bijectors.bijector(d)
if d isa Distributions.UnivariateDistribution
b
else
# Wrap a bijector `f` such that it operates on vectors of length `prod(in_size)`
# and produces a vector of length `prod(Bijectors.output(f, in_size))`.
in_size = size(d)
vec_in_length = prod(in_size)
reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
out_size = Bijectors.output_size(b, in_size)
vec_out_length = prod(out_size)
reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
reshape_outer ∘ b ∘ reshape_inner
end
end

if sym2ranges
return (

Check warning on line 53 in src/bijector.jl

View check run for this annotation

Codecov / codecov/patch

src/bijector.jl#L53

Added line #L53 was not covered by tests
Bijectors.Stacked(bs, ranges),
(; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
)
else
return Bijectors.Stacked(bs, ranges)
end
end
28 changes: 28 additions & 0 deletions test/bijector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

@testset "bijector.jl" begin
@testset "bijector" begin
@model function test()
m ~ Normal()
s ~ InverseGamma(3, 3)
return c ~ Dirichlet([1.0, 1.0])
end

m = test()
b = bijector(m)

# m ∈ ℝ, s ∈ ℝ+, c ∈ 2-simplex
# check dimensionalities and ranges
@test b.length_in == 4
@test b.length_out == 3
@test b.ranges_in == [1:1, 2:2, 3:4]
@test b.ranges_out == [1:1, 2:2, 3:3]
@test b.ranges_out == [1:1, 2:2, 3:3]

# check support of mapped variables
binv = inverse(b)
zs = mapslices(binv, randn(b.length_out, 10000); dims=1)

@test all(zs[2, :] .≥ 0)
@test all(sum(zs[3:4, :]; dims=1) .≈ 1.0)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ include("test_util.jl")
include("debug_utils.jl")
include("deprecated.jl")
include("submodels.jl")
include("bijector.jl")
end

if GROUP == "All" || GROUP == "Group2"
Expand Down
Loading