Skip to content

Commit 2935876

Browse files
committed
add missing file
1 parent 0f58523 commit 2935876

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

src/bijector.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
2+
"""
3+
bijector(model::Model[, sym2ranges = Val(false)])
4+
5+
Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
6+
denoting the dimensionality of the latent variables.
7+
"""
8+
function Bijectors.bijector(
9+
model::DynamicPPL.Model,
10+
(::Val{sym2ranges})=Val(false);
11+
varinfo=DynamicPPL.VarInfo(model),
12+
) where {sym2ranges}
13+
dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)
14+
15+
num_ranges = sum([
16+
length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata)
17+
])
18+
ranges = Vector{UnitRange{Int}}(undef, num_ranges)
19+
idx = 0
20+
range_idx = 1
21+
22+
# ranges might be discontinuous => values are vectors of ranges rather than just ranges
23+
sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}()
24+
for sym in keys(varinfo.metadata)
25+
sym_lookup[sym] = Vector{UnitRange{Int}}()
26+
for r in varinfo.metadata[sym].ranges
27+
ranges[range_idx] = idx .+ r
28+
push!(sym_lookup[sym], ranges[range_idx])
29+
range_idx += 1
30+
end
31+
32+
idx += varinfo.metadata[sym].ranges[end][end]
33+
end
34+
35+
bs = map(tuple(dists...)) do d
36+
b = Bijectors.bijector(d)
37+
if d isa Distributions.UnivariateDistribution
38+
b
39+
else
40+
# Wrap a bijector `f` such that it operates on vectors of length `prod(in_size)`
41+
# and produces a vector of length `prod(Bijectors.output(f, in_size))`.
42+
in_size = size(d)
43+
vec_in_length = prod(in_size)
44+
reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
45+
out_size = Bijectors.output_size(b, in_size)
46+
vec_out_length = prod(out_size)
47+
reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
48+
reshape_outer b reshape_inner
49+
end
50+
end
51+
52+
if sym2ranges
53+
return (
54+
Bijectors.Stacked(bs, ranges),
55+
(; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
56+
)
57+
else
58+
return Bijectors.Stacked(bs, ranges)
59+
end
60+
end

0 commit comments

Comments
 (0)