Skip to content

Commit 0388e99

Browse files
torfjeldegithub-actions[bot]yebaidevmotion
authored
Updates to init (#489)
* fixed init for transformable distributions * imrpove transformable a bit * fixed tests * added proper testing for init * added coverage for matrix distributions * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * forgot to commit the rand implementation * Update src/test_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed typo * actually fixed it * fixed new testutils model * formatting * bump Bijectors version so we can use output_size and output_length * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Update test/Project.toml Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 4c08ddb commit 0388e99

File tree

5 files changed

+87
-15
lines changed

5 files changed

+87
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2424
AbstractMCMC = "2, 3.0, 4"
2525
AbstractPPL = "0.5.3"
2626
BangBang = "0.3"
27-
Bijectors = "0.12.4"
27+
Bijectors = "0.13"
2828
ChainRulesCore = "0.9.7, 0.10, 1"
2929
ConstructionBase = "1"
3030
Distributions = "0.23.8, 0.24, 0.25"

src/test_utils.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,53 @@ function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)
546546
return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)]
547547
end
548548

549+
@model function demo_assume_matrix_dot_observe_matrix(
550+
x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64}
551+
) where {TV}
552+
n = length(x)
553+
d = n ÷ 2
554+
s ~ reshape(product_distribution(fill(InverseGamma(2, 3), n)), d, 2)
555+
s_vec = vec(s)
556+
m ~ MvNormal(zeros(n), Diagonal(s_vec))
557+
558+
# Dotted observe for `Matrix`.
559+
x .~ MvNormal(m, Diagonal(s_vec))
560+
561+
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
562+
end
563+
function logprior_true(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m)
564+
n = length(model.args.x)
565+
s_vec = vec(s)
566+
return loglikelihood(InverseGamma(2, 3), s_vec) +
567+
logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m)
568+
end
569+
function loglikelihood_true(
570+
model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m
571+
)
572+
return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x)
573+
end
574+
function logprior_true_with_logabsdet_jacobian(
575+
model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m
576+
)
577+
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
578+
end
579+
function varnames(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)})
580+
return [@varname(s), @varname(m)]
581+
end
582+
583+
function Random.rand(
584+
rng::Random.AbstractRNG,
585+
::Type{NamedTuple},
586+
model::Model{typeof(demo_assume_matrix_dot_observe_matrix)},
587+
)
588+
n = length(model.args.x)
589+
s = reshape(rand(rng, InverseGamma(2, 3), n), n ÷ 2, 2)
590+
s_vec = vec(s)
591+
m = rand(rng, MvNormal(zeros(n), Diagonal(s_vec)))
592+
593+
return (s=s, m=m)
594+
end
595+
549596
const DemoModels = Union{
550597
Model{typeof(demo_dot_assume_dot_observe)},
551598
Model{typeof(demo_assume_index_observe)},
@@ -559,6 +606,7 @@ const DemoModels = Union{
559606
Model{typeof(demo_dot_assume_observe_submodel)},
560607
Model{typeof(demo_dot_assume_dot_observe_matrix)},
561608
Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)},
609+
Model{typeof(demo_assume_matrix_dot_observe_matrix)},
562610
}
563611

564612
# We require demo models to have explict impleentations of `rand` since we want
@@ -668,6 +716,7 @@ const DEMO_MODELS = (
668716
demo_dot_assume_observe_submodel(),
669717
demo_dot_assume_dot_observe_matrix(),
670718
demo_dot_assume_matrix_dot_observe_matrix(),
719+
demo_assume_matrix_dot_observe_matrix(),
671720
)
672721

673722
# Model to test `StaticTransformation` with.

src/utils.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -275,26 +275,24 @@ end
275275
randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2
276276
randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) .- 2
277277

278-
const Transformable = Union{
279-
PositiveDistribution,
280-
UnitDistribution,
281-
TransformDistribution,
282-
SimplexDistribution,
283-
PDMatDistribution,
284-
}
285-
istransformable(dist) = false
286-
istransformable(::Transformable) = true
278+
istransformable(dist) = link_transform(dist) !== identity
287279

288280
#################################
289281
# Single-sample initialisations #
290282
#################################
291283

292284
inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng))
293285
function inittrans(rng, dist::MultivariateDistribution)
294-
return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1]))
286+
# Get the length of the unconstrained vector
287+
b = link_transform(dist)
288+
d = Bijectors.output_length(b, length(dist))
289+
return Bijectors.invlink(dist, randrealuni(rng, d))
295290
end
296291
function inittrans(rng, dist::MatrixDistribution)
297-
return Bijectors.invlink(dist, randrealuni(rng, size(dist)...))
292+
# Get the size of the unconstrained vector
293+
b = link_transform(dist)
294+
sz = Bijectors.output_size(b, size(dist))
295+
return Bijectors.invlink(dist, randrealuni(rng, sz...))
298296
end
299297

300298
################################

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2323
[compat]
2424
AbstractMCMC = "2.1, 3.0, 4"
2525
AbstractPPL = "0.5"
26-
Bijectors = "0.11, 0.12"
26+
Bijectors = "0.13"
2727
Distributions = "0.25"
2828
DistributionsAD = "0.6.3"
2929
Documenter = "0.26.1, 0.27"

test/sampler.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,37 @@
2424
@test chains isa Vector{<:VarInfo}
2525
@test length(chains) == N
2626

27-
# Expected value of ``X`` where ``X ~ U[-2, 2]`` is ≈ 0.
28-
@test mean(vi[@varname(m)] for vi in chains) 0 atol = 0.1
27+
# `m` is Gaussian, i.e. no transformation is used, so it
28+
# should have a mean equal to its prior, i.e. 2.
29+
@test mean(vi[@varname(m)] for vi in chains) 2 atol = 0.1
2930

3031
# Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8.
3132
@test mean(vi[@varname(s)] for vi in chains) 1.8 atol = 0.1
3233
end
34+
35+
@testset "init" begin
36+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
37+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
38+
N = 1000
39+
chain_init = sample(model, SampleFromUniform(), N; progress=false)
40+
41+
for vn in keys(first(chain_init))
42+
if AbstractPPL.subsumes(@varname(s), vn)
43+
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
44+
dist = InverseGamma(2, 3)
45+
b = DynamicPPL.link_transform(dist)
46+
@test mean(mean(b(vi[vn])) for vi in chain_init) 0 atol = 0.11
47+
elseif AbstractPPL.subsumes(@varname(m), vn)
48+
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
49+
@test mean(mean(vi[vn]) for vi in chain_init) 0 atol = 0.11
50+
else
51+
error("Unknown variable name: $vn")
52+
end
53+
end
54+
end
55+
end
56+
end
57+
3358
@testset "Initial parameters" begin
3459
# dummy algorithm that just returns initial value and does not perform any sampling
3560
abstract type OnlyInitAlg end

0 commit comments

Comments
 (0)