Skip to content

Commit 4b4b34b

Browse files
committed
Fix and test initialization (#210)
1 parent 1244b98 commit 4b4b34b

File tree

3 files changed

+67
-2
lines changed

3 files changed

+67
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.3"
3+
version = "0.10.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler)
112112
linked = islinked(vi, spl)
113113
linked && invlink!(vi, spl)
114114
theta = vi[spl]
115-
length(theta) == length(init_theta_flat) ||
115+
length(theta) == length(init_theta) ||
116116
error("Provided initial value doesn't match the dimension of the model")
117117

118118
# Update values that are provided.

test/sampler.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,68 @@ Random.seed!(100)
4040
@test mean(vi[@varname(s)] for vi in chains) 1.8 atol = 0.1
4141
end
4242

43+
@testset "Initial parameters" begin
44+
# dummy algorithm that just returns initial value and does not perform any sampling
45+
struct OnlyInitAlg end
46+
function DynamicPPL.initialstep(
47+
rng::Random.AbstractRNG,
48+
model::Model,
49+
::Sampler{OnlyInitAlg},
50+
vi::AbstractVarInfo;
51+
kwargs...,
52+
)
53+
return vi, nothing
54+
end
55+
DynamicPPL.getspace(::OnlyInitAlg) = ()
56+
57+
# model with one variable: initialization p = 0.2
58+
@model function coinflip()
59+
p ~ Beta(1, 1)
60+
10 ~ Binomial(25, p)
61+
end
62+
model = coinflip()
63+
sampler = Sampler(OnlyInitAlg())
64+
lptrue = logpdf(Binomial(25, 0.2), 10)
65+
chain = sample(model, sampler, 1; init_params = 0.2)
66+
@test chain[1].metadata.p.vals == [0.2]
67+
@test getlogp(chain[1]) == lptrue
68+
69+
# parallel sampling
70+
chains = sample(model, sampler, MCMCThreads(), 1, 10; init_params = 0.2)
71+
for c in chains
72+
@test c[1].metadata.p.vals == [0.2]
73+
@test getlogp(c[1]) == lptrue
74+
end
75+
76+
# model with two variables: initialization s = 4, m = -1
77+
@model function twovars()
78+
s ~ InverseGamma(2, 3)
79+
m ~ Normal(0, sqrt(s))
80+
end
81+
model = twovars()
82+
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
83+
chain = sample(model, sampler, 1; init_params = [4, -1])
84+
@test chain[1].metadata.s.vals == [4]
85+
@test chain[1].metadata.m.vals == [-1]
86+
@test getlogp(chain[1]) == lptrue
87+
88+
# parallel sampling
89+
chains = sample(model, sampler, MCMCThreads(), 1, 10; init_params = [4, -1])
90+
for c in chains
91+
@test c[1].metadata.s.vals == [4]
92+
@test c[1].metadata.m.vals == [-1]
93+
@test getlogp(c[1]) == lptrue
94+
end
95+
96+
# set only m = -1
97+
chain = sample(model, sampler, 1; init_params = [missing, -1])
98+
@test !ismissing(chain[1].metadata.s.vals[1])
99+
@test chain[1].metadata.m.vals == [-1]
100+
101+
# parallel sampling
102+
chains = sample(model, sampler, MCMCThreads(), 1, 10; init_params = [missing, -1])
103+
for c in chains
104+
@test !ismissing(c[1].metadata.s.vals[1])
105+
@test c[1].metadata.m.vals == [-1]
106+
end
107+
end

0 commit comments

Comments
 (0)